diff --git a/pyproject.toml b/pyproject.toml index f5b1a711..c963ea3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,9 @@ dependencies = [ "transformers>=4.44.2", "datasets>=3.0.0", "pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c", - "einops" + "einops", + "torchdata>=0.8.0", + "fsspec[gcs]>=2024.3.1", ] [project.optional-dependencies] diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py new file mode 100644 index 00000000..9bc0f003 --- /dev/null +++ b/src/zeroband/checkpoint.py @@ -0,0 +1,225 @@ +from dataclasses import dataclass +import multiprocessing +import os +import time +from typing import Any +from fsspec.generic import rsync as rsync_fsspec +import torch +from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR +from torchdata.stateful_dataloader import StatefulDataLoader +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint.state_dict import ( + set_optimizer_state_dict, + set_model_state_dict, + get_model_state_dict, + get_optimizer_state_dict, + StateDictOptions, +) +from torch.distributed.checkpoint.stateful import Stateful +from zeroband.utils.logging import get_logger +import warnings +import logging + +from zeroband.utils.world_info import get_world_info + +## code inspired by torchtitan https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py + + +@dataclass +class TrainingProgress(Stateful): + total_tokens: int + outer_step: int + step: int + + def state_dict(self) -> dict[str, Any]: + return {"total_tokens": self.total_tokens, "outer_step": self.outer_step, "step": self.step} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self.total_tokens = state_dict["total_tokens"] + self.outer_step = state_dict["outer_step"] + self.step = state_dict["step"] + + +class ModelWrapper(Stateful): + def __init__(self, model: nn.Module) -> None: + self.model = model + + def state_dict(self) -> dict[str, Any]: + return get_model_state_dict(self.model) + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + set_model_state_dict(model=self.model, model_state_dict=state_dict, options=StateDictOptions(strict=False)) + + +class OptimizerWrapper(Stateful): + def __init__( + self, + model: nn.Module, + optim: torch.optim.Optimizer, + ) -> None: + self.model = model + self.optim = optim + + def state_dict(self) -> dict[str, Any]: + return get_optimizer_state_dict( + model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True) + ) + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + set_optimizer_state_dict( + model=self.model, optimizers=self.optim, optim_state_dict=state_dict, options=StateDictOptions(strict=False) + ) + + +class CkptManager: + """Its name CkptManager because I (sami) always misstyped chekcpoint. + + Checkpoint are saved in a folder with the following structure: + ckpt_path/ + step_0/ + _0_0.pt + _1_0.pt + ... + step_1/ + ... + """ + + def __init__( + self, + model: nn.Module, + optimizer: Optimizer, + scheduler: LambdaLR, + dataloader: StatefulDataLoader, + training_progress: TrainingProgress, + diloco_offloaded_param_list: list[nn.Parameter] | None, + diloco_offloaded_optimizer: Optimizer | None, + ): + self.model = ModelWrapper(model) + self.optimizer = OptimizerWrapper(model, optimizer) + self.scheduler = scheduler + self.dataloader = dataloader + self.training_progress = training_progress + + # states can only be stateful object, hence we need to wrap Model and Optimizer + self.states: dict[str, Stateful] = { + "model": self.model, + "optimizer": self.optimizer, + "scheduler": self.scheduler, + # "dataloader": self.dataloader, # ignoring dataloader for now as each rank has its own dataloader + "training_progress": self.training_progress, + } + + assert (diloco_offloaded_param_list is None) == ( + diloco_offloaded_optimizer is None + ), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values" + + self.diloco_offloaded_optimizer = diloco_offloaded_optimizer # he we don't use Wrapper because it failed + # which might make the ckpt less generic in term of loading from different number of device. FSDP ckpt seems to be a mess tho + self.diloco_offloaded_param_list = diloco_offloaded_param_list + + if diloco_offloaded_optimizer is not None: + # even if the diloco_offloaded target the cpu list model, we still use the gpu model to load and save state. + # main reason is that we actually don't a cpu model but just a list of cpu parameters. + self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer + + self._logger = get_logger() + + self.async_save_process: list[multiprocessing.Process] = [] + + def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None: + """ + Each rank will save the right shard of the model and optimizer. + + Saving is done inplace + """ + + time_start = time.perf_counter() + world_info = get_world_info() + + ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}") + if self.diloco_offloaded_optimizer: + # here we save model and offloaded optimizer on each diloco rank even tho they are the same + # this is done for two reasons: + # * if the nodes don't share a filesystem nor a remote path, they still save all of the data + # * its easier to implement and avoid race condition on the shared data. + ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}") + + catch_warning = self._logger.getEffectiveLevel() <= logging.INFO + + with warnings.catch_warnings(): + # pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907 + # we can ignore it if we are not logging in DEBUG mode + if catch_warning: + warnings.simplefilter("ignore") + + dcp.save(self.states, checkpoint_id=ckpt_path) + + ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk + with open(os.path.join(ckpt_path, f"__{world_info.local_rank}_0.pt"), "wb") as f: + torch.save({"data_loader": self.dataloader.state_dict()}, f) + + self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds") + + if remote_ckpt_path is not None: + self._async_save_remote(ckpt_path, remote_ckpt_path) + + def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str): + """asyncronously rsync a ckpt folder to a remote location. Using fsspec to handle remote cloud storage without to install + specific libraries (e.g. s3fs) + """ + + def rsync(): + time_start = time.perf_counter() + self._logger.info(f"start pushing {ckpt_path} to {remote_ckpt_path} asynchronously") + rsync_fsspec(ckpt_path, destination=remote_ckpt_path) + self._logger.info( + f"finish pushing {ckpt_path} to {remote_ckpt_path} in {time.perf_counter() - time_start} seconds" + ) + + processes = multiprocessing.Process(target=rsync, daemon=True) + processes.start() + + self.async_save_process.append(processes) + + def wait_async_save_process(self): + """ + wait for all async save process to finish + """ + for process in self.async_save_process: + process.join() + + def _del__(self): + self.wait_async_save_process() + + def load(self, resume_ckpt_path: str) -> None: + """ + loading should be done after fsdp wrap and optimizer init. + Each rank will load the right shard of the model and optimizer. + All rank will load the global states (scheduler, step, total_tokens, dataloader). + + `resume_ckpt_path` should point to a specific step and not to the base ckpt folder. Example: `ckpt_path/step_100` + + Loading is done inplace + """ + time_start = time.perf_counter() + + world_info = get_world_info() + if self.diloco_offloaded_param_list is not None: + resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}") + + self.states = dcp.load(self.states, checkpoint_id=resume_ckpt_path) + + # since we don't load the param list from the state dict as its the same as the model one we just copy + if self.diloco_offloaded_param_list is not None: + for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()): + param_offloaded.data.copy_(param_model.data) + + ## the next part is a fix so that each rank save a different dataloader rank. It not efficient because it reads the state two times from disk + with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f: + rank_state_dict = torch.load(f) + + self.dataloader.load_state_dict(rank_state_dict["data_loader"]) + + self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds") diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 61a1a986..1a093d1c 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -4,6 +4,7 @@ import torch from torch.utils.data import DataLoader from torch.utils.data import IterableDataset +from torchdata.stateful_dataloader import StatefulDataLoader from datasets import load_dataset from datasets.distributed import split_dataset_by_node @@ -79,7 +80,7 @@ def tokenize_function(data): data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100) - return DataLoader( + return StatefulDataLoader( train_dataset, collate_fn=data_collator, batch_size=batch_size, diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 0d93d57a..9879a17e 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -26,6 +26,7 @@ from zeroband.models.llama import get_model from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger +from zeroband.checkpoint import CkptManager, TrainingProgress class DataConfig(BaseConfig): @@ -53,6 +54,13 @@ class TrainConfig(BaseConfig): log_model_hash: bool = False +class CkptConfig(BaseConfig): + path: str + interval: int + + remote_path: str | None = None # could be a s3 path + + class Config(BaseConfig): # main config name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "13B", "26B", "70B"] = "150M" @@ -67,6 +75,9 @@ class Config(BaseConfig): optim: OptimConfig = OptimConfig() train: TrainConfig + ckpt: CkptConfig | None = None + resume: str | None = None + def train(config: Config): sharding_strategy = get_sharding_strategy(config.train.sharding_strategy) @@ -78,6 +89,11 @@ def train(config: Config): assert batch_size % config.train.micro_bs == 0 gradient_accumulation_steps = batch_size // config.train.micro_bs + if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None: + assert ( + config.ckpt.interval % config.diloco.inner_steps == 0 + ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step" + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) tokenizer.pad_token = "" # todo(sami): remove padding tokens once we have context stuffing @@ -128,10 +144,7 @@ def train(config: Config): use_orig_params=True, process_group=elastic_device_mesh.local_pg if config.diloco is not None else None, ) - - if config.train.torch_compile: - model = torch.compile(model) - logger.debug("model compiled and fsdped") + logger.debug("model fsdped") # Setup optimizers inner_optimizer = torch.optim.AdamW( @@ -150,6 +163,27 @@ def train(config: Config): num_training_steps=config.optim.total_steps, ) + training_progress = TrainingProgress(total_tokens=0, outer_step=0, step=0) + + ckpt_manager = CkptManager( + model=model, + optimizer=inner_optimizer, + scheduler=scheduler, + dataloader=train_dataloader, + training_progress=training_progress, + diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, + diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, + ) + + if config.train.torch_compile: + # we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY + model = torch.compile(model) + logger.debug("model compiled") + + if config.resume is not None: + # all is inplace + ckpt_manager.load(resume_ckpt_path=config.resume) + model.train() if world_info.rank == 0: @@ -158,7 +192,6 @@ def train(config: Config): train_dataloader_iterator = iter(train_dataloader) - outer_step = 0 num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1 perf_counter = PerfCounter(window_size=10) @@ -166,9 +199,9 @@ def train(config: Config): while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs - logger.info(f"outer_step step: {outer_step}") + logger.info(f"outer_step step: {training_progress.outer_step}") - for inner_step in range(num_inner_steps): + for _inner_step in range(num_inner_steps): loss_batch = 0 for grad_acc_step in range(gradient_accumulation_steps): @@ -195,22 +228,30 @@ def train(config: Config): inner_optimizer.zero_grad() # logging - real_step = outer_step * num_inner_steps + inner_step + 1 # add + 1 because inner_step start at 0 + training_progress.step += 1 inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0] dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) # syncing loss across all data parallel rank within a nodes - perf_counter.count_tokens(config.data.seq_length * config.optim.batch_size) + new_tokens = config.data.seq_length * config.optim.batch_size + perf_counter.count_tokens(new_tokens) + + if config.diloco is not None: + training_progress.total_tokens += new_tokens + else: + # we count the total tokens with respect to all diloco workers + # might need to tweak this as some worker might fail to join the all reduce later + training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size() metrics = { "Loss": loss_batch.item(), - "step": real_step, + "step": training_progress.step, "inner_lr": inner_lr, "Perplexity": torch.exp(loss_batch).item(), - "total_tokens": real_step * config.optim.batch_size * config.data.seq_length, + "total_tokens": training_progress.total_tokens, } - log = f"step: {real_step}, loss: {loss_batch.item():.4f}" + log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}" tokens_per_second = perf_counter.get_tokens_per_second() @@ -239,9 +280,17 @@ def train(config: Config): with FSDP.summon_full_params(model): logger.debug("Post diloco model: %s", get_module_signature(model)) - outer_step += 1 + training_progress.outer_step += 1 - if real_step >= config.optim.total_steps: + if ( + config.ckpt is not None + and training_progress.step > 0 + and training_progress.step % config.ckpt.interval == 0 + ): + # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway + ckpt_manager.save(config.ckpt.path, config.ckpt.remote_path) + + if training_progress.step >= config.optim.total_steps: # we only allow to break outisde of the inner loop. # This avoid ending the training in the middle of a the inner loop # Since ckpt strategy and all reduce is done at the outer loop level. @@ -250,6 +299,9 @@ def train(config: Config): if world_info.rank == 0: metric_logger.finish() + ckpt_manager.wait_async_save_process() + logger.info("Training finished, exiting ...") + if __name__ == "__main__": # Allow eager fallback during production so that that the training runs dont die diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 9b73f328..fcca5da2 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -27,6 +27,10 @@ def __init__(self): def __repr__(self): return f"WorldInfo(world_size={self.world_size}, rank={self.rank}, local_rank={self.local_rank}, local_world_size={self.local_world_size}, nnodes={self.nnodes}, global_unique_id={self.global_unique_id}, global_addr={self.global_addr}, global_port={self.global_port}, global_world_size={self.global_world_size}, global_rank={self.global_rank})" + @property + def diloco_rank(self): + return self.global_rank + def get_world_info() -> WorldInfo: """ diff --git a/uv.lock b/uv.lock index f0ec766e..84348844 100644 --- a/uv.lock +++ b/uv.lock @@ -140,6 +140,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/21/5b6702a7f963e95456c0de2d495f67bf5fd62840ac655dc451586d23d39a/attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2", size = 63001 }, ] +[[package]] +name = "cachetools" +version = "5.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/38/a0f315319737ecf45b4319a8cd1f3a908e29d9277b46942263292115eee7/cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a", size = 27661 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/07/14f8ad37f2d12a5ce41206c21820d8cb6561b728e51fad4530dff0552a67/cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292", size = 9524 }, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -258,6 +267,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/52/45dab187f03d48c765b94db0464f5c10431756e47ae4cc6a8029a7d57a36/datasets-3.0.0-py3-none-any.whl", hash = "sha256:c23fefb6c953dcb1cd5f6deb6c502729c733ef98791e0c3f2d80c7ca2d9a01dd", size = 474265 }, ] +[[package]] +name = "decorator" +version = "5.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/0c/8d907af351aa16b42caae42f9d6aa37b900c67308052d10fdce809f8d952/decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330", size = 35016 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 }, +] + [[package]] name = "dill" version = "0.3.8" @@ -379,10 +397,31 @@ wheels = [ ] [package.optional-dependencies] +gcs = [ + { name = "gcsfs" }, +] http = [ { name = "aiohttp" }, ] +[[package]] +name = "gcsfs" +version = "2024.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "decorator" }, + { name = "fsspec" }, + { name = "google-auth" }, + { name = "google-auth-oauthlib" }, + { name = "google-cloud-storage" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7f/b1/c5ae16ad1d499f0cf10e3306f717eadae30dba64ec29236077b8fe661e7c/gcsfs-2024.6.1.tar.gz", hash = "sha256:e8858c7a893b2265e9bfce2fe270a024a2e348c74c23528801db388fc0224ed7", size = 79259 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/65/f467159d42a2ce4191f1f6ff319e75b5a14ab2cc080062dbf5821a80244c/gcsfs-2024.6.1-py2.py3-none-any.whl", hash = "sha256:13fd18095425e54e248870594fd155812723966b1bda3b102b3a5c44ec436a03", size = 34866 }, +] + [[package]] name = "gitdb" version = "4.0.11" @@ -407,6 +446,129 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/bd/cc3a402a6439c15c3d4294333e13042b915bbeab54edc457c723931fed3f/GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff", size = 207337 }, ] +[[package]] +name = "google-api-core" +version = "2.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/5c/31c1742a53b79c8a0c4757b5fae2e8ab9c519cbd7b98c587d4294e1d2d16/google_api_core-2.20.0.tar.gz", hash = "sha256:f74dff1889ba291a4b76c5079df0711810e2d9da81abfdc99957bc961c1eb28f", size = 152583 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/dc/6143f67acf5f30717c9e1b1c48fc04c0f59b869be046e6639d3f171640ae/google_api_core-2.20.0-py3-none-any.whl", hash = "sha256:ef0591ef03c30bb83f79b3d0575c3f31219001fc9c5cf37024d08310aeffed8a", size = 142162 }, +] + +[[package]] +name = "google-auth" +version = "2.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/37/c854a8b1b1020cf042db3d67577c6f84cd1e8ff6515e4f5498ae9e444ea5/google_auth-2.35.0.tar.gz", hash = "sha256:f4c64ed4e01e8e8b646ef34c018f8bf3338df0c8e37d8b3bba40e7f574a3278a", size = 267223 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/1f/3a72917afcb0d5cd842cbccb81bf7a8a7b45b4c66d8dc4556ccb3b016bfc/google_auth-2.35.0-py2.py3-none-any.whl", hash = "sha256:25df55f327ef021de8be50bad0dfd4a916ad0de96da86cd05661c9297723ad3f", size = 208968 }, +] + +[[package]] +name = "google-auth-oauthlib" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "requests-oauthlib" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cc/0f/1772edb8d75ecf6280f1c7f51cbcebe274e8b17878b382f63738fd96cee5/google_auth_oauthlib-1.2.1.tar.gz", hash = "sha256:afd0cad092a2eaa53cd8e8298557d6de1034c6cb4a740500b5357b648af97263", size = 24970 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/8e/22a28dfbd218033e4eeaf3a0533b2b54852b6530da0c0fe934f0cc494b29/google_auth_oauthlib-1.2.1-py2.py3-none-any.whl", hash = "sha256:2d58a27262d55aa1b87678c3ba7142a080098cbc2024f903c62355deb235d91f", size = 24930 }, +] + +[[package]] +name = "google-cloud-core" +version = "2.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b8/1f/9d1e0ba6919668608570418a9a51e47070ac15aeff64261fb092d8be94c0/google-cloud-core-2.4.1.tar.gz", hash = "sha256:9b7749272a812bde58fff28868d0c5e2f585b82f37e09a1f6ed2d4d10f134073", size = 35587 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/0f/2e2061e3fbcb9d535d5da3f58cc8de4947df1786fe6a1355960feb05a681/google_cloud_core-2.4.1-py2.py3-none-any.whl", hash = "sha256:a9e6a4422b9ac5c29f79a0ede9485473338e2ce78d91f2370c01e730eab22e61", size = 29233 }, +] + +[[package]] +name = "google-cloud-storage" +version = "2.18.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, + { name = "google-cloud-core" }, + { name = "google-crc32c" }, + { name = "google-resumable-media" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/b7/1554cdeb55d9626a4b8720746cba8119af35527b12e1780164f9ba0f659a/google_cloud_storage-2.18.2.tar.gz", hash = "sha256:aaf7acd70cdad9f274d29332673fcab98708d0e1f4dceb5a5356aaef06af4d99", size = 5532864 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/da/95db7bd4f0bd1644378ac1702c565c0210b004754d925a74f526a710c087/google_cloud_storage-2.18.2-py2.py3-none-any.whl", hash = "sha256:97a4d45c368b7d401ed48c4fdfe86e1e1cb96401c9e199e419d289e2c0370166", size = 130466 }, +] + +[[package]] +name = "google-crc32c" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/72/c3298da1a3773102359c5a78f20dae8925f5ea876e37354415f68594a6fb/google_crc32c-1.6.0.tar.gz", hash = "sha256:6eceb6ad197656a1ff49ebfbbfa870678c75be4344feb35ac1edf694309413dc", size = 14472 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/be/d7846cb50e17bf72a70ea2d8159478ac5de0f1170b10cac279f50079e78d/google_crc32c-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5bcc90b34df28a4b38653c36bb5ada35671ad105c99cfe915fb5bed7ad6924aa", size = 30267 }, + { url = "https://files.pythonhosted.org/packages/84/3b/29cadae166132e4991087a49dc88906a1d3d5ec22b80f63bc4bc7b6e0431/google_crc32c-1.6.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d9e9913f7bd69e093b81da4535ce27af842e7bf371cde42d1ae9e9bd382dc0e9", size = 30113 }, + { url = "https://files.pythonhosted.org/packages/18/a9/49a7b2c4b7cc69d15778a820734f9beb647b1b4cf1a629ca43e3d3a54c70/google_crc32c-1.6.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a184243544811e4a50d345838a883733461e67578959ac59964e43cca2c791e7", size = 37702 }, + { url = "https://files.pythonhosted.org/packages/4b/aa/52538cceddefc7c2d66c6bd59dfe67a50f65a4952f441f91049e4188eb57/google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:236c87a46cdf06384f614e9092b82c05f81bd34b80248021f729396a78e55d7e", size = 32847 }, + { url = "https://files.pythonhosted.org/packages/b1/2c/1928413d3faae74ae0d7bdba648cf36ed6b03328c562b47046af016b7249/google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebab974b1687509e5c973b5c4b8b146683e101e102e17a86bd196ecaa4d099fc", size = 37844 }, + { url = "https://files.pythonhosted.org/packages/d6/f4/f62fa405e442b37c5676973b759dd6e56cd8d58a5c78662912456526f716/google_crc32c-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:50cf2a96da226dcbff8671233ecf37bf6e95de98b2a2ebadbfdf455e6d05df42", size = 33444 }, + { url = "https://files.pythonhosted.org/packages/7d/14/ab47972ac79b6e7b03c8be3a7ef44b530a60e69555668dbbf08fc5692a98/google_crc32c-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f7a1fc29803712f80879b0806cb83ab24ce62fc8daf0569f2204a0cfd7f68ed4", size = 30267 }, + { url = "https://files.pythonhosted.org/packages/54/7d/738cb0d25ee55629e7d07da686decf03864a366e5e863091a97b7bd2b8aa/google_crc32c-1.6.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:40b05ab32a5067525670880eb5d169529089a26fe35dce8891127aeddc1950e8", size = 30112 }, + { url = "https://files.pythonhosted.org/packages/3e/6d/33ca50cbdeec09c31bb5dac277c90994edee975662a4c890bda7ffac90ef/google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e4b426c3702f3cd23b933436487eb34e01e00327fac20c9aebb68ccf34117d", size = 32861 }, + { url = "https://files.pythonhosted.org/packages/67/1e/4870896fc81ec77b1b5ebae7fdd680d5a4d40e19a4b6d724032f996ca77a/google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51c4f54dd8c6dfeb58d1df5e4f7f97df8abf17a36626a217f169893d1d7f3e9f", size = 32490 }, + { url = "https://files.pythonhosted.org/packages/00/9c/f5f5af3ddaa7a639d915f8f58b09bbb8d1db90ecd0459b62cd430eb9a4b6/google_crc32c-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:bb8b3c75bd157010459b15222c3fd30577042a7060e29d42dabce449c087f2b3", size = 33446 }, + { url = "https://files.pythonhosted.org/packages/cf/41/65a91657d6a8123c6c12f9aac72127b6ac76dda9e2ba1834026a842eb77c/google_crc32c-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ed767bf4ba90104c1216b68111613f0d5926fb3780660ea1198fc469af410e9d", size = 30268 }, + { url = "https://files.pythonhosted.org/packages/59/d0/ee743a267c7d5c4bb8bd865f7d4c039505f1c8a4b439df047fdc17be9769/google_crc32c-1.6.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:62f6d4a29fea082ac4a3c9be5e415218255cf11684ac6ef5488eea0c9132689b", size = 30113 }, + { url = "https://files.pythonhosted.org/packages/25/53/e5e449c368dd26ade5fb2bb209e046d4309ed0623be65b13f0ce026cb520/google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c87d98c7c4a69066fd31701c4e10d178a648c2cac3452e62c6b24dc51f9fcc00", size = 32995 }, + { url = "https://files.pythonhosted.org/packages/52/12/9bf6042d5b0ac8c25afed562fb78e51b0641474097e4139e858b45de40a5/google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5e7d2445d1a958c266bfa5d04c39932dc54093fa391736dbfdb0f1929c1fb3", size = 32614 }, + { url = "https://files.pythonhosted.org/packages/76/29/fc20f5ec36eac1eea0d0b2de4118c774c5f59c513f2a8630d4db6991f3e0/google_crc32c-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7aec8e88a3583515f9e0957fe4f5f6d8d4997e36d0f61624e70469771584c760", size = 33445 }, + { url = "https://files.pythonhosted.org/packages/e7/ff/ed48d136b65ddc61f5aef6261c58cd817c8cd60640b16680e5419fb17018/google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48abd62ca76a2cbe034542ed1b6aee851b6f28aaca4e6551b5599b6f3ef175cc", size = 28057 }, + { url = "https://files.pythonhosted.org/packages/14/fb/54deefe679b7d1c1cc81d83396fcf28ad1a66d213bddeb275a8d28665918/google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e311c64008f1f1379158158bb3f0c8d72635b9eb4f9545f8cf990c5668e59d", size = 27866 }, +] + +[[package]] +name = "google-resumable-media" +version = "2.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-crc32c" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/5a/0efdc02665dca14e0837b62c8a1a93132c264bd02054a15abb2218afe0ae/google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0", size = 2163099 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/35/b8d3baf8c46695858cb9d8835a53baa1eeb9906ddaf2f728a5f5b640fd1e/google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa", size = 81251 }, +] + +[[package]] +name = "googleapis-common-protos" +version = "1.65.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/53/3b/1599ceafa875ffb951480c8c74f4b77646a6b80e80970698f2aa93c216ce/googleapis_common_protos-1.65.0.tar.gz", hash = "sha256:334a29d07cddc3aa01dee4988f9afd9b2916ee2ff49d6b757155dc0d197852c0", size = 113657 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/08/49bfe7cf737952cc1a9c43e80cc258ed45dad7f183c5b8276fc94cb3862d/googleapis_common_protos-1.65.0-py2.py3-none-any.whl", hash = "sha256:2972e6c496f435b92590fd54045060867f3fe9be2c82ab148fc8885035479a63", size = 220890 }, +] + [[package]] name = "huggingface-hub" version = "0.24.6" @@ -818,6 +980,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] +[[package]] +name = "oauthlib" +version = "3.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/fa/fbf4001037904031639e6bfbfc02badfc7e12f137a8afa254df6c4c8a670/oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918", size = 177352 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/80/cab10959dc1faead58dc8384a781dfbf93cb4d33d50988f7a69f1b7c9bbe/oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", size = 151688 }, +] + [[package]] name = "packaging" version = "24.1" @@ -896,6 +1067,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/92/caae8c86e94681b42c246f0bca35c059a2f0529e5b92619f6aba4cf7e7b6/pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f", size = 204643 }, ] +[[package]] +name = "proto-plus" +version = "1.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/fc/e9a65cd52c1330d8d23af6013651a0bc50b6d76bcbdf91fae7cd19c68f29/proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445", size = 55942 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/6f/db31f0711c0402aa477257205ce7d29e86a75cb52cd19f7afb585f75cda0/proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12", size = 50080 }, +] + [[package]] name = "protobuf" version = "5.28.2" @@ -959,6 +1142,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/49/baafe2a964f663413be3bd1cf5c45ed98c5e42e804e2328e18f4570027c1/pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7", size = 25099235 }, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd", size = 181537 }, +] + [[package]] name = "pydantic" version = "2.9.1" @@ -1224,6 +1428,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, ] +[[package]] +name = "requests-oauthlib" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "oauthlib" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179 }, +] + [[package]] name = "rich" version = "13.8.1" @@ -1237,6 +1454,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/11/dadb85e2bd6b1f1ae56669c3e1f0410797f9605d752d68fb47b77f525b31/rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06", size = 241608 }, ] +[[package]] +name = "rsa" +version = "4.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, +] + [[package]] name = "ruff" version = "0.6.4" @@ -1525,6 +1754,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/30/8b6f77ea4ce84f015ee024b8dfef0dac289396254e8bfd493906d4cbb848/torch-2.4.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:72b484d5b6cec1a735bf3fa5a1c4883d01748698c5e9cfdbeb4ffab7c7987e0d", size = 62123443 }, ] +[[package]] +name = "torchdata" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, + { name = "torch" }, + { name = "urllib3" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/8a/3251c64214ab09d1c1756677f36e78f8cf0ce9dabb3a21386e78ef50540e/torchdata-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:082e27b2acb1768cb6a30ddd2f8d9c68e407164ce207194bf8bfa616d621a801", size = 4904801 }, + { url = "https://files.pythonhosted.org/packages/da/90/058fe345dfac8b50d2d0fdb421ce04c78c88b06a5f220dd8d64d424ccdbe/torchdata-0.8.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:44f7875a62f3fab86e2f8e5af92c4929f8f7390aa17bd697fdd0965723bc1e98", size = 2691733 }, + { url = "https://files.pythonhosted.org/packages/2f/54/d6f64a6e210ee50b68220d3b5564ffdda8bcc8d26c02a39a8a587caffe2f/torchdata-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:07e985d69c5692dda9181a8ef3e14c7f08b0460226f7cd4cf1c1bb0e6975700f", size = 1341187 }, + { url = "https://files.pythonhosted.org/packages/82/aa/4da6c725b03fb51c5a10405803308afd43970e66aad45e8cca872786ba1b/torchdata-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1635cecf4226fec8539c5d06ba764a48c41363ea0bbea09407ab379828527a8b", size = 4904783 }, + { url = "https://files.pythonhosted.org/packages/64/e8/c691e8e73dc6cbb09ba84ffb0341a6466d3184ff422cda07ebade3b929ef/torchdata-0.8.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:2d63d3fdcc68cf912c81709c8704b9cf435ba89bceed41a365e7362eb5740394", size = 2691483 }, + { url = "https://files.pythonhosted.org/packages/2c/f6/438a82c2f8d69114ef943c0b58f69f66ea5249bd7b2e4799d44f185f7797/torchdata-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8db8a7cb946e82983517cff94317f1898128cbfe4f48821d0c3509c0cdafa4c9", size = 1341021 }, + { url = "https://files.pythonhosted.org/packages/3e/f7/2d1cd02ebcca73ff151dd94b0a08d30808574d944a360470b52a89f0be4e/torchdata-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4e3f7efac3d8a4bd4efcb1427869c04043d0a0a019f9aa1eb381bd6c6b321e62", size = 4905186 }, + { url = "https://files.pythonhosted.org/packages/ea/94/d9ac51405d4259094dfa0a1dc3fa4ed2efe057d194873c9f1ba1881b06c9/torchdata-0.8.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:bb878e243e58526a5b3ac54583f7c029ad643a34ade798800c1878c83f1c36ee", size = 2691660 }, + { url = "https://files.pythonhosted.org/packages/d2/c4/623f7237c69606d202870bc9e44a8ed9070cc3eb1ac03f02c457083aa746/torchdata-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a43fc7e8d3ae2632859f15d5439cd97b83af559fecd8963e5f09e08f93b81e2", size = 1341201 }, +] + [[package]] name = "tqdm" version = "4.66.5" @@ -1793,10 +2043,12 @@ source = { editable = "." } dependencies = [ { name = "datasets" }, { name = "einops" }, + { name = "fsspec", extra = ["gcs"] }, { name = "numpy" }, { name = "pydantic-config" }, { name = "setuptools" }, { name = "torch" }, + { name = "torchdata" }, { name = "transformers" }, ] @@ -1816,10 +2068,12 @@ dev = [ requires-dist = [ { name = "datasets", specifier = ">=3.0.0" }, { name = "einops" }, + { name = "fsspec", extras = ["gcs"], specifier = ">=2024.3.1" }, { name = "numpy" }, { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c" }, { name = "setuptools" }, { name = "torch", specifier = "==2.4.1" }, + { name = "torchdata", specifier = ">=0.8.0" }, { name = "transformers", specifier = ">=4.44.2" }, { name = "wandb", marker = "extra == 'all'" }, ]