diff --git a/configs/10B/H100.toml b/configs/10B/H100.toml
index d58e6098..d743cc8a 100644
--- a/configs/10B/H100.toml
+++ b/configs/10B/H100.toml
@@ -11,14 +11,16 @@ sched_type = "wsd-sqrt"
 batch_size = 128 #1M tokens bs
 warmup_steps = 1000
 total_steps = 1_000_000_000_000
-lr = 7.5e-5
-adam_betas1 = 0.9
-adam_betas2 = 0.95
-weight_decay = 0.1
 z_loss = true
 
+[optim.optim]
+lr = 7.5e-5
+betas1 = 0.9
+betas2 = 0.95
+weight_decay = 0.1
+
 [data]
 seq_length = 8192
 dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular,/data/datasets/dclm-baseline-1.0-parquet,/data/datasets/open-web-math"
diff --git a/configs/10B/H100_cooldown.toml b/configs/10B/H100_cooldown.toml
index 9132b1e8..c443e0ed 100644
--- a/configs/10B/H100_cooldown.toml
+++ b/configs/10B/H100_cooldown.toml
@@ -12,14 +12,15 @@ batch_size = 128 #1M tokens bs
 warmup_steps = 1000
 stable_steps = 74700
 total_steps = 90400
-lr = 7.5e-5
-
-adam_betas1 = 0.9
-adam_betas2 = 0.95
-weight_decay = 0.1
 z_loss = true
 
+[optim.optim]
+lr = 7.5e-5
+betas1 = 0.9
+betas2 = 0.95
+weight_decay = 0.1
+
 [data]
 seq_length = 8192
 dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular"
diff --git a/configs/10B/H100_simple.toml b/configs/10B/H100_simple.toml
index dafa64b1..6e8ca505 100644
--- a/configs/10B/H100_simple.toml
+++ b/configs/10B/H100_simple.toml
@@ -7,17 +7,18 @@ ac_ckpt = true
 [optim]
 sched_type = "wsd-sqrt"
-batch_size = 128 #1M tokens bs
+batch_size = 128 #1M tokens bs
 warmup_steps = 1000
 total_steps = 1_000_000_000_000
-lr = 7.5e-5
-
-adam_betas1 = 0.9
-adam_betas2 = 0.95
-weight_decay = 0.1
 z_loss = true
 
+[optim.optim]
+lr = 7.5e-5
+betas1 = 0.9
+betas2 = 0.95
+weight_decay = 0.1
+
 [data]
 seq_length = 8192
 num_workers = 4
diff --git a/configs/13B/H100.toml b/configs/13B/H100.toml
index 176a9a12..4bfc3e05 100644
--- a/configs/13B/H100.toml
+++ b/configs/13B/H100.toml
@@ -9,6 +9,8 @@ ac_ckpt = true
 batch_size = 1024 #2M tokens bs
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 3e-4
 
 [data]
diff --git a/configs/150M/3090.toml b/configs/150M/3090.toml
index 761d1b66..a304abd8 100644
--- a/configs/150M/3090.toml
+++ b/configs/150M/3090.toml
@@ -10,4 +10,8 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 1000
 total_steps = 88_000
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
+
diff --git a/configs/150M/A40.toml b/configs/150M/A40.toml
index c82f2df4..ddbef1a5 100644
--- a/configs/150M/A40.toml
+++ b/configs/150M/A40.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 1000
 total_steps = 88_000
-lr = 4e-4
\ No newline at end of file
+
+[optim.optim]
+lr = 4e-4
+
diff --git a/configs/150M/H100.toml b/configs/150M/H100.toml
index b15c1750..a6339181 100644
--- a/configs/150M/H100.toml
+++ b/configs/150M/H100.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 1000
 total_steps = 88_000
-lr = 4e-4
\ No newline at end of file
+
+[optim.optim]
+lr = 4e-4
+
diff --git a/configs/150M_short/3090.toml b/configs/150M_short/3090.toml
index 4792bc1b..a468b64c 100644
--- a/configs/150M_short/3090.toml
+++ b/configs/150M_short/3090.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 500
 total_steps = 8192
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
diff --git a/configs/150M_short/A40.toml b/configs/150M_short/A40.toml
index 17aa7aca..80756de5 100644
--- a/configs/150M_short/A40.toml
+++ b/configs/150M_short/A40.toml
@@ -6,8 +6,12 @@ type_model = "llama2"
 micro_bs = 32 # change this base on the gpu
 reshard_after_forward = true
+
 [optim]
 batch_size = 512
 warmup_steps = 500
 total_steps = 8192
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
diff --git a/configs/150M_short/H100.toml b/configs/150M_short/H100.toml
index af7582e0..f7a7223d 100644
--- a/configs/150M_short/H100.toml
+++ b/configs/150M_short/H100.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
 batch_size = 512
 warmup_steps = 500
 total_steps = 8192
-lr = 4e-4
\ No newline at end of file
+
+
+[optim.optim]
+lr = 4e-4
diff --git a/configs/1B/H100.toml b/configs/1B/H100.toml
index 4d5a325e..de9cef75 100644
--- a/configs/1B/H100.toml
+++ b/configs/1B/H100.toml
@@ -10,4 +10,6 @@ reshard_after_forward = true
 batch_size = 1024
 warmup_steps = 1000
 total_steps = 8192
-lr = 7e-4
\ No newline at end of file
+
+[optim.optim]
+lr = 7e-4
diff --git a/configs/1B_diloco/H100.toml b/configs/1B_diloco/H100.toml
deleted file mode 100644
index 19d3259f..00000000
--- a/configs/1B_diloco/H100.toml
+++ /dev/null
@@ -1,25 +0,0 @@
-name_model = "1B"
-project = "debug_1B_zero_band"
-type_model = "llama2"
-
-[train]
-micro_bs = 16
-
-[optim]
-batch_size = 2048
-warmup_steps = 1000
-total_steps = 88_000
-lr = 4e-4
-
-z_loss = true
-
-
-[diloco]
-inner_steps = 50
-compression = "uint8"
-
-[ckpt]
-interval = 50
-topk = 3
-path = "outputs_1b_diloco_50"
-
diff --git a/configs/7B/H100.toml b/configs/7B/H100.toml
index f701ef7c..7ea3dc65 100644
--- a/configs/7B/H100.toml
+++ b/configs/7B/H100.toml
@@ -9,6 +9,8 @@ micro_bs = 1
 batch_size = 1024 #2M tokens bs
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 3e-4
 
 [data]
diff --git a/configs/7B_diloco/H100.toml b/configs/7B_diloco/H100.toml
index ceeccc43..b6a84d2c 100644
--- a/configs/7B_diloco/H100.toml
+++ b/configs/7B_diloco/H100.toml
@@ -9,6 +9,8 @@ micro_bs = 1
 batch_size = 1024 #2M tokens bs
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 3e-4
 
 [data]
diff --git a/configs/test.toml b/configs/test.toml
index 46abc536..d9f9726d 100644
--- a/configs/test.toml
+++ b/configs/test.toml
@@ -15,4 +15,6 @@ num_workers = 1
 batch_size = 128
 warmup_steps = 1000
 total_steps = 88_000
+
+[optim.optim]
 lr = 4e-4
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index e4e0786e..60bab574 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "pyarrow",
     "toposolve",
     "psutil",
+    "torch-shampoo @ git+https://github.com/facebookresearch/optimizers.git@main",
 ]
 
 [project.optional-dependencies]
diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index b240c4b5..4acd9da3 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -38,6 +38,7 @@
     send_state_dict,
     send_tensor_and_state_dict,
 )
+from distributed_shampoo import DistributedShampoo
 
 from zeroband.utils.world_info import get_world_info
 
@@ -80,17 +81,23 @@ def __init__(
         self.optim = optim
 
     def state_dict(self) -> dict[str, Any]:
-        return get_optimizer_state_dict(
-            model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
-        )
+        if isinstance(self.optim, DistributedShampoo):
+            return self.optim.distributed_state_dict(key_to_param=self.model.named_parameters())
+        else:
+            return get_optimizer_state_dict(
+                model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
+            )
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-        set_optimizer_state_dict(
-            model=self.model,
-            optimizers=self.optim,
-            optim_state_dict=state_dict,
-            options=StateDictOptions(flatten_optimizer_state_dict=True),
-        )
+        if isinstance(self.optim, DistributedShampoo):
+            self.optim.load_distributed_state_dict(state_dict, key_to_param=self.model.named_parameters())
+        else:
+            set_optimizer_state_dict(
+                model=self.model,
+                optimizers=self.optim,
+                optim_state_dict=state_dict,
+                options=StateDictOptions(flatten_optimizer_state_dict=True),
+            )
 
 
 def cast_dtensor_to_tensor(state_dict: dict[str, Any]) -> dict[str, Any]:
diff --git a/src/zeroband/config.py b/src/zeroband/config.py
index 17745a4a..07d4b7e0 100644
--- a/src/zeroband/config.py
+++ b/src/zeroband/config.py
@@ -7,9 +7,12 @@
 from zeroband.data import DataConfig
 from zeroband.diloco import DilocoConfig
 from zeroband.models.llama.model import AttnFnType
+from zeroband.optimizers import OptimizersConfig, AdamConfig
 
 
 class OptimConfig(BaseConfig):
+    optim: OptimizersConfig = AdamConfig()
+
     lr: float = 4e-4
     weight_decay: float = 0.1
     adam_betas1: float = 0.9
diff --git a/src/zeroband/optimizers.py b/src/zeroband/optimizers.py
new file mode 100644
index 00000000..8ef1985f
--- /dev/null
+++ b/src/zeroband/optimizers.py
@@ -0,0 +1,60 @@
+from typing import Literal, TypeAlias
+from pydantic_config import BaseConfig
+import torch
+from distributed_shampoo import (
+    DefaultEigenvalueCorrectedShampooConfig,
+    DistributedShampoo,
+    FullyShardShampooConfig,
+    ShampooPT2CompileConfig,
+)
+
+class AdamConfig(BaseConfig):
+    type: Literal["adam"] = "adam"  # the literal is used to distinguish between the different optimizer configurations in the union type
+    lr: float = 4e-4
+    weight_decay: float = 0.1
+    betas1: float = 0.9
+    betas2: float = 0.95
+
+class SoapConfig(BaseConfig):
+    type: Literal["soap"] = "soap"
+    lr: float = 4e-4
+    weight_decay: float = 1e-05
+    betas1: float = 0.9
+    betas2: float = 0.95
+
+    max_preconditioner_dim: int = 8192
+    precondition_frequency: int = 100
+
+
+OptimizersConfig: TypeAlias = AdamConfig | SoapConfig
+
+
+def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
+    if isinstance(config, AdamConfig):
+        return torch.optim.AdamW(
+            params,
+            lr=config.lr,
+            weight_decay=config.weight_decay,
+            betas=(config.betas1, config.betas2),
+        )
+    elif isinstance(config, SoapConfig):
+        return DistributedShampoo(
+            params,
+            lr=config.lr,
+            betas=(config.betas1, config.betas2),
+            epsilon=1e-12,
+            weight_decay=config.weight_decay,
+            max_preconditioner_dim=config.max_preconditioner_dim,
+            precondition_frequency=config.precondition_frequency,
+            use_decoupled_weight_decay=True,
+            # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
+            # less expensive and might thereby allow for a smaller `precondition_frequency`.
+            preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
+            distributed_config=FullyShardShampooConfig(),
+            shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False),
+        )
+    else:
+        raise ValueError(f"Unknown optimizer {config}")
+
+
+__all__ = ["OptimizersConfig", "get_optimizer"]
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index ad86608e..87f4635d 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -16,8 +16,10 @@
 from zeroband.diloco import Diloco
 from zeroband.comms import ElasticDeviceMesh
 from zeroband.loss import cross_entropy_max_z_loss
+
 from zeroband.models.llama.model import create_block_mask_from_seqlens
 from zeroband.config import Config #, MemoryProfilerConfig
+from zeroband.optimizers import get_optimizer
 
 from zeroband.utils import (
     FakeTokenizer,
@@ -162,12 +164,7 @@ def train(config: Config):
     logger.debug("model fsdped")
 
     # Setup optimizers
-    inner_optimizer = torch.optim.AdamW(
-        model.parameters(),
-        lr=config.optim.lr,
-        weight_decay=config.optim.weight_decay,
-        betas=(config.optim.adam_betas1, config.optim.adam_betas2),
-    )
+    inner_optimizer = get_optimizer(model.parameters(), config.optim.optim)
 
     diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None
 
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index c0ea3699..f6b1c915 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -4,6 +4,7 @@
 import torch
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed._tensor.api import DTensor
+from distributed_shampoo import DistributedShampoo
 
 __all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]
 
@@ -138,6 +139,9 @@ def get_optimizer_signature(optimizer: torch.optim.Optimizer, compress: bool = T
     Get the optimizer signature
     """
 
+    if isinstance(optimizer, DistributedShampoo):
+        return "mocked signature because shampoo does not support state_dict()"
+
     def unwrap_tensor(state_dict: dict) -> dict:
         new_dict = {}
         for key, value in state_dict.items():
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index aa04b0c5..7b33a620 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -116,7 +116,19 @@ def test_packing(packing: bool):
     _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])
 
 
-def test_ckpt(tmp_path: Path):
+@pytest.mark.parametrize("diloco", [False, True])
+def test_soap(diloco: bool):
+    num_gpus = [1, 2] if diloco else [2, 1]
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml" if diloco else "debug/normal.toml",
+        extra_args=["--optim.optim.precondition_frequency", "1"],
+        diloco=diloco,
+    )
+
+
+@pytest.mark.parametrize("soap", [False, True])
+def test_ckpt(tmp_path: Path, soap: bool):
     num_gpus = [1, 2]
     v1_file = tmp_path / "v1.log"
     v2_file = tmp_path / "v2.log"
@@ -146,7 +158,8 @@ def test_ckpt(tmp_path: Path):
             "--no-train.sequence_packing",
             "--train.attn_fn",
             "math",
-        ],
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     _test_multi_gpu(
@@ -167,7 +180,8 @@ def test_ckpt(tmp_path: Path):
             "--no-train.sequence_packing",
             "--train.attn_fn",
             "math",
-        ],
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     # _test_multi_gpu(
diff --git a/uv.lock b/uv.lock
index b94a8576..cc6fdf7a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2413,6 +2413,14 @@ wheels = [
"https://files.pythonhosted.org/packages/69/72/20cb30f3b39a9face296491a86adb6ff8f1a47a897e4d14667e6cf89d5c3/torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7", size = 906393265 }, ] +[[package]] +name = "torch-shampoo" +version = "1.0.0" +source = { git = "https://github.com/facebookresearch/optimizers.git?rev=main#c51e4e6c0a9a6e93163441a9b32bb65cc1c736a8" } +dependencies = [ + { name = "torch" }, +] + [[package]] name = "torchdata" version = "0.8.0" @@ -2749,6 +2757,7 @@ dependencies = [ { name = "setuptools" }, { name = "toposolve" }, { name = "torch" }, + { name = "torch-shampoo" }, { name = "torchdata" }, { name = "transformers" }, { name = "zstandard" }, @@ -2788,6 +2797,7 @@ requires-dist = [ { name = "setuptools" }, { name = "toposolve" }, { name = "torch", specifier = "==2.5.1" }, + { name = "torch-shampoo", git = "https://github.com/facebookresearch/optimizers.git?rev=main" }, { name = "torchdata", specifier = ">=0.8.0" }, { name = "transformers", specifier = ">=4.44.2" }, { name = "wandb", marker = "extra == 'all'" },