Commit

add soap optimizer (#186)
* add soap optimizer

* add soap optimizer

* wip add ckpt shampoo

* fix import issue

* fix diloco not defined

* fix ruff
samsja authored Jan 10, 2025
1 parent a22f434 commit d958229
Showing 23 changed files with 170 additions and 65 deletions.
10 changes: 6 additions & 4 deletions configs/10B/H100.toml
@@ -11,14 +11,16 @@ sched_type = "wsd-sqrt"
batch_size = 128 #1M tokens bs
warmup_steps = 1000
total_steps = 1_000_000_000_000
lr = 7.5e-5

adam_betas1 = 0.9
adam_betas2 = 0.95
weight_decay = 0.1

z_loss = true

[optim.optim]
lr = 7.5e-5
betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1

[data]
seq_length = 8192
dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular,/data/datasets/dclm-baseline-1.0-parquet,/data/datasets/open-web-math"
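
This pattern repeats across the config files below: the optimizer hyperparameters move out of the flat [optim] table into a nested [optim.optim] table, which is what lets a config choose the new SOAP optimizer instead of AdamW. None of the configs touched by this commit actually switch to SOAP; a hypothetical SOAP block would look roughly like the sketch below, with field names and defaults taken from SoapConfig in src/zeroband/optimizers.py (added later in this diff) and assuming pydantic maps the TOML keys straight onto those fields.

[optim.optim]
type = "soap"
lr = 4e-4
weight_decay = 1e-5
betas1 = 0.9
betas2 = 0.95
max_preconditioner_dim = 8192
precondition_frequency = 100
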
11 changes: 6 additions & 5 deletions configs/10B/H100_cooldown.toml
@@ -12,14 +12,15 @@ batch_size = 128 #1M tokens bs
warmup_steps = 1000
stable_steps = 74700
total_steps = 90400
lr = 7.5e-5

adam_betas1 = 0.9
adam_betas2 = 0.95
weight_decay = 0.1

z_loss = true

[optim.optim]
lr = 7.5e-5
betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1

[data]
seq_length = 8192
dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular"
13 changes: 7 additions & 6 deletions configs/10B/H100_simple.toml
@@ -7,17 +7,18 @@ ac_ckpt = true

[optim]
sched_type = "wsd-sqrt"
batch_size = 128 #1M tokens bs
batch_size = 128 #1M tokens bs
warmup_steps = 1000
total_steps = 1_000_000_000_000
lr = 7.5e-5

adam_betas1 = 0.9
adam_betas2 = 0.95
weight_decay = 0.1

z_loss = true

[optim.optim]
lr = 7.5e-5
betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1

[data]
seq_length = 8192
num_workers = 4
2 changes: 2 additions & 0 deletions configs/13B/H100.toml
@@ -9,6 +9,8 @@ ac_ckpt = true
batch_size = 1024 #2M tokens bs
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 3e-4

[data]
6 changes: 5 additions & 1 deletion configs/150M/3090.toml
@@ -10,4 +10,8 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4


[optim.optim]
lr = 4e-4

5 changes: 4 additions & 1 deletion configs/150M/A40.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4

[optim.optim]
lr = 4e-4

5 changes: 4 additions & 1 deletion configs/150M/H100.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4

[optim.optim]
lr = 4e-4

5 changes: 4 additions & 1 deletion configs/150M_short/3090.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 500
total_steps = 8192
lr = 4e-4


[optim.optim]
lr = 4e-4
6 changes: 5 additions & 1 deletion configs/150M_short/A40.toml
@@ -6,8 +6,12 @@ type_model = "llama2"
micro_bs = 32 # change this base on the gpu
reshard_after_forward = true


[optim]
batch_size = 512
warmup_steps = 500
total_steps = 8192
lr = 4e-4


[optim.optim]
lr = 4e-4
5 changes: 4 additions & 1 deletion configs/150M_short/H100.toml
@@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 500
total_steps = 8192
lr = 4e-4


[optim.optim]
lr = 4e-4
4 changes: 3 additions & 1 deletion configs/1B/H100.toml
@@ -10,4 +10,6 @@ reshard_after_forward = true
batch_size = 1024
warmup_steps = 1000
total_steps = 8192
lr = 7e-4

[optim.optim]
lr = 7e-4
25 changes: 0 additions & 25 deletions configs/1B_diloco/H100.toml

This file was deleted.

2 changes: 2 additions & 0 deletions configs/7B/H100.toml
@@ -9,6 +9,8 @@ micro_bs = 1
batch_size = 1024 #2M tokens bs
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 3e-4

[data]
2 changes: 2 additions & 0 deletions configs/7B_diloco/H100.toml
@@ -9,6 +9,8 @@ micro_bs = 1
batch_size = 1024 #2M tokens bs
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 3e-4

[data]
2 changes: 2 additions & 0 deletions configs/test.toml
@@ -15,4 +15,6 @@ num_workers = 1
batch_size = 128
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 4e-4
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
"pyarrow",
"toposolve",
"psutil",
"torch-shampoo @ git+https://github.com/facebookresearch/optimizers.git@main",
]

[project.optional-dependencies]
25 changes: 16 additions & 9 deletions src/zeroband/checkpoint.py
@@ -38,6 +38,7 @@
    send_state_dict,
    send_tensor_and_state_dict,
)
from distributed_shampoo import DistributedShampoo

from zeroband.utils.world_info import get_world_info

@@ -80,17 +81,23 @@ def __init__(
        self.optim = optim

    def state_dict(self) -> dict[str, Any]:
        return get_optimizer_state_dict(
            model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
        )
        if isinstance(self.optim, DistributedShampoo):
            return self.optim.distributed_state_dict(key_to_param=self.model.named_parameters())
        else:
            return get_optimizer_state_dict(
                model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
            )

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        set_optimizer_state_dict(
            model=self.model,
            optimizers=self.optim,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        if isinstance(self.optim, DistributedShampoo):
            self.optim.load_distributed_state_dict(state_dict, key_to_param=self.model.named_parameters())
        else:
            set_optimizer_state_dict(
                model=self.model,
                optimizers=self.optim,
                optim_state_dict=state_dict,
                options=StateDictOptions(flatten_optimizer_state_dict=True),
            )


def cast_dtensor_to_tensor(state_dict: dict[str, Any]) -> dict[str, Any]:
3 changes: 3 additions & 0 deletions src/zeroband/config.py
@@ -7,9 +7,12 @@
from zeroband.data import DataConfig
from zeroband.diloco import DilocoConfig
from zeroband.models.llama.model import AttnFnType
from zeroband.optimizers import OptimizersConfig, AdamConfig


class OptimConfig(BaseConfig):
    optim: OptimizersConfig = AdamConfig()

    lr: float = 4e-4
    weight_decay: float = 0.1
    adam_betas1: float = 0.9
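
OptimConfig now nests an OptimizersConfig under optim, defaulting to AdamConfig, while the pre-existing flat fields (lr, weight_decay, adam_betas1, ...) are left untouched by this diff. The type literal on each optimizer config class is what tells the parser which member of the AdamConfig | SoapConfig union a given [optim.optim] table describes. A small sketch of the resulting objects, constructed directly with illustrative values:

from zeroband.optimizers import AdamConfig, SoapConfig

# Default choice: AdamW hyperparameters, mirroring the updated configs/*.toml files.
adam_cfg = AdamConfig(lr=4e-4)
assert adam_cfg.type == "adam"

# Opting into SOAP is just a matter of selecting the other config in [optim.optim].
soap_cfg = SoapConfig(lr=4e-4, precondition_frequency=100)
assert soap_cfg.type == "soap"
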
60 changes: 60 additions & 0 deletions src/zeroband/optimizers.py
@@ -0,0 +1,60 @@
from typing import Literal, TypeAlias
from pydantic_config import BaseConfig
import torch
from distributed_shampoo import (
    DefaultEigenvalueCorrectedShampooConfig,
    DistributedShampoo,
    FullyShardShampooConfig,
    ShampooPT2CompileConfig,
)

class AdamConfig(BaseConfig):
    type: Literal["adam"] = "adam"  # the literal is used to distinguish between the different optimizers configuration in the union type
    lr: float = 4e-4
    weight_decay: float = 0.1
    betas1: float = 0.9
    betas2: float = 0.95

class SoapConfig(BaseConfig):
    type: Literal["soap"] = "soap"
    lr: float = 4e-4
    weight_decay: float = 1e-05
    betas1: float = 0.9
    betas2: float = 0.95

    max_preconditioner_dim: int = 8192
    precondition_frequency: int = 100


OptimizersConfig: TypeAlias = AdamConfig | SoapConfig


def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
    if isinstance(config, AdamConfig):
        return torch.optim.AdamW(
            params,
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=(config.betas1, config.betas2),
        )
    elif isinstance(config, SoapConfig):
        return DistributedShampoo(
            params,
            lr=config.lr,
            betas=(config.betas1, config.betas2),
            epsilon=1e-12,
            weight_decay=config.weight_decay,
            max_preconditioner_dim=config.max_preconditioner_dim,
            precondition_frequency=config.precondition_frequency,
            use_decoupled_weight_decay=True,
            # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
            # less expensive and might thereby allow for a smaller `precondition_frequency`.
            preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
            distributed_config=FullyShardShampooConfig(),
            shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False),
        )
    else:
        raise ValueError(f"Unknown optimizer {config}")


__all__ = ["OptimizersConfig", "get_optimizer"]
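
Using the new module comes down to building one of the config objects and handing it to get_optimizer, which is what train.py now does with config.optim.optim. A minimal sketch (the toy model is a placeholder, not from the repo; the SOAP branch is configured for FSDP via FullyShardShampooConfig, so it is only indicated here rather than run):

import torch

from zeroband.optimizers import AdamConfig, SoapConfig, get_optimizer

model = torch.nn.Linear(16, 16)  # placeholder model

# Adam path: returns a plain torch.optim.AdamW.
adam_opt = get_optimizer(list(model.parameters()), AdamConfig(lr=4e-4))

# SOAP path: returns a DistributedShampoo instance with eigenvalue-corrected
# preconditioning; it expects an FSDP-sharded model, so it is left commented out.
soap_cfg = SoapConfig(lr=4e-4, precondition_frequency=100)
# soap_opt = get_optimizer(fsdp_model.parameters(), soap_cfg)  # fsdp_model is hypothetical
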
9 changes: 3 additions & 6 deletions src/zeroband/train.py
@@ -16,8 +16,10 @@
from zeroband.diloco import Diloco
from zeroband.comms import ElasticDeviceMesh
from zeroband.loss import cross_entropy_max_z_loss

from zeroband.models.llama.model import create_block_mask_from_seqlens
from zeroband.config import Config #, MemoryProfilerConfig
from zeroband.optimizers import get_optimizer

from zeroband.utils import (
    FakeTokenizer,
@@ -162,12 +164,7 @@ def train(config: Config):
    logger.debug("model fsdped")

    # Setup optimizers
    inner_optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optim.lr,
        weight_decay=config.optim.weight_decay,
        betas=(config.optim.adam_betas1, config.optim.adam_betas2),
    )
    inner_optimizer = get_optimizer(model.parameters(), config.optim.optim)

    diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None

4 changes: 4 additions & 0 deletions src/zeroband/utils/__init__.py
@@ -4,6 +4,7 @@
import torch
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed._tensor.api import DTensor
from distributed_shampoo import DistributedShampoo


__all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]
@@ -138,6 +139,9 @@ def get_optimizer_signature(optimizer: torch.optim.Optimizer, compress: bool = T
    Get the optimizer signature
    """

    if isinstance(optimizer, DistributedShampoo):
        return "mocked signature because shampoo does not support state_dict()"

    def unwrap_tensor(state_dict: dict) -> dict:
        new_dict = {}
        for key, value in state_dict.items():