diff --git a/src/zeroband/data.py b/src/zeroband/data.py index ed90f61c..9dc7c42d 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -46,18 +46,23 @@ def __init__(self, seq_len: int, vocab_size: int): self.seq_len = seq_len self.vocab_size = vocab_size assert vocab_size > 3, "Vocab size must be greater than 3" + self.step = 0 def __iter__(self) -> Generator[dict[str, Any], Any, None]: while True: len_ = random.randint(1, self.seq_len) input_ids = torch.randint(3, self.vocab_size, (len_,)).tolist() + self.step += 1 yield {"input_ids": input_ids} def state_dict(self): - return {} + return {"step": self.step} def load_state_dict(self, state_dict): - pass + self.step = state_dict["step"] + itera = iter(self) + for _ in range(self.step): + next(itera) class BatchOutput(TypedDict): diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py index 30b54963..ff142d26 100644 --- a/src/zeroband/models/llama/__init__.py +++ b/src/zeroband/models/llama/__init__.py @@ -85,6 +85,7 @@ def get_model( type_model: str, vocab_size: int, seq_length: int, + math_attn: bool, ) -> tuple[Transformer, ModelArgs]: """get the transformer model""" @@ -97,5 +98,6 @@ def get_model( config.vocab_size = vocab_size config.max_seq_len = seq_length + config.math_attn = math_attn return Transformer(config), config diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py index 234b8fee..c1a63403 100644 --- a/src/zeroband/models/llama/model.py +++ b/src/zeroband/models/llama/model.py @@ -11,6 +11,7 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. +import contextlib from dataclasses import dataclass from typing import Optional, Tuple @@ -20,6 +21,7 @@ from zeroband.models.norms import build_norm from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE +from torch.nn.attention import SDPBackend, sdpa_kernel _flex_attention_compiled = torch.compile(flex_attention, dynamic=False) @@ -58,6 +60,8 @@ class ModelArgs: depth_init: bool = True norm_type: str = "fused_rmsnorm" + math_attn: bool = False # slow for testing + def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: """ @@ -222,6 +226,8 @@ def __init__(self, model_args: ModelArgs): self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False) + self.math_attn = model_args.math_attn + def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) @@ -271,7 +277,8 @@ def forward( return self.wo(output) def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor: - output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + with sdpa_kernel(SDPBackend.MATH) if self.math_attn else contextlib.nullcontext(): + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) return output diff --git a/src/zeroband/train.py b/src/zeroband/train.py index d54391e3..15d0cd93 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -79,6 +79,8 @@ class TrainConfig(BaseConfig): sequence_packing: bool = True attn_fn: Literal["flash", "sdpa"] | None = None + math_attn: bool = False # slow + @model_validator(mode="after") def validate_attn_fn(self): if self.attn_fn is not None: @@ -133,6 +135,7 @@ def log_hash_training_state( inner_optimizer: torch.optim.Optimizer, diloco: Diloco | None, metric_logger: MetricLogger, + step: int, id: str = "", ): """Log the hash of the model and optimizer. This function is slow""" @@ -143,10 +146,11 @@ def log_hash_training_state( logger.debug(f"inner diloco model {id} : {inner_model_hash}") logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}") - if world_info.rank == 0: - metric_logger.log( - {"inner_model_hash_{id}": inner_model_hash, "inner_optimizer_hash_{id}": inner_optimizer_hash} - ) + metrics = { + "step": step, + f"inner_model_hash_{id}": inner_model_hash, + f"inner_optimizer_hash_{id}": inner_optimizer_hash, + } if config.diloco is not None and diloco is not None: outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer) @@ -155,10 +159,11 @@ def log_hash_training_state( logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}") logger.debug(f"outer diloco model hash {id} : {outer_model_hash}") - if world_info.rank == 0: - metric_logger.log( - {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash} - ) + metrics.update( + {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash} + ) + if world_info.rank == 0: + metric_logger.log(metrics) def train(config: Config): @@ -198,6 +203,7 @@ def train(config: Config): config.type_model, vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE, seq_length=config.data.seq_length, + math_attn=config.train.math_attn, ) model = model.to(world_info.local_rank) @@ -300,7 +306,9 @@ def train(config: Config): skip_dataloader=config.ckpt.skip_dataloader, data_path=config.ckpt.data_path, ) - log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="resume") + log_hash_training_state( + config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume" + ) if config.train.memory_monitor: gpu_mem_monitor = GPUMemoryMonitor() @@ -349,7 +357,15 @@ def train(config: Config): ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg) - log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="live_reco_recv") + log_hash_training_state( + config, + model, + inner_optimizer, + diloco, + metric_logger, + step=training_progress.step, + id="live_reco_recv", + ) need_live_recovery = False if config.ckpt.remote_data_load: @@ -483,7 +499,9 @@ def train(config: Config): diloco.step(model=model, flag=training_progress.outer_step) diloco_time = time.perf_counter() - time_start_inner - log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="outer_step") + log_hash_training_state( + config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="outer_step" + ) training_progress.outer_step += 1 @@ -496,7 +514,9 @@ def train(config: Config): do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0 ckpt_manager.save(remote=do_remote) - log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="ckpt save") + log_hash_training_state( + config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save" + ) if config.diloco: tokens_per_second = ( diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index e5703fe3..2cadc041 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -1,5 +1,7 @@ import copy import os +from pathlib import Path +import pickle import subprocess import pytest import socket @@ -112,3 +114,139 @@ def test_packing(packing: bool): num_gpus = [2, 1] packing_arg = "--train.sequence_packing" if packing else "--no-train.sequence_packing" _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg]) + + +def test_ckpt(tmp_path: Path): + num_gpus = [1, 2] + v1_file = tmp_path / "v1.log" + v2_file = tmp_path / "v2.log" + v3_file = tmp_path / "v3.log" + + v1_ckpt = tmp_path / "v1_ckpt" + v2_ckpt = tmp_path / "v2_ckpt" + v3_ckpt = tmp_path / "v3_ckpt" + + os.mkdir(v1_ckpt) + os.mkdir(v2_ckpt) + os.mkdir(v3_ckpt) + + _test_multi_gpu( + num_gpus, + "debug/diloco.toml", + extra_args=[ + "--project", + str(v1_file), + "--ckpt.path", + str(v1_ckpt), + "--ckpt.interval", + "5", + "--optim.total_steps", + "20", + "--train.log_model_hash", + "--no-train.sequence_packing", + "--train.math_attn", + ], + diloco=True, + ) + _test_multi_gpu( + num_gpus, + "debug/diloco.toml", + extra_args=[ + "--project", + str(v2_file), + "--ckpt.path", + str(v2_ckpt), + "--ckpt.interval", + "5", + "--ckpt.resume", + str(v1_ckpt / "step_5"), + "--optim.total_steps", + "20", + "--train.log_model_hash", + "--no-train.sequence_packing", + "--train.math_attn", + ], + diloco=True, + ) + _test_multi_gpu( + num_gpus, + "debug/diloco.toml", + extra_args=[ + "--project", + str(v3_file), + "--ckpt.path", + str(v3_ckpt), + "--ckpt.interval", + "5", + "--ckpt.resume", + str(v2_ckpt / "step_10"), + "--optim.total_steps", + "20", + "--train.log_model_hash", + "--no-train.sequence_packing", + "--train.math_attn", + ], + diloco=True, + ) + + key_to_round = ["Perplexity", "Loss"] + digit_to_round = [0, 2] + + def read_logs(path: Path): + with path.open("rb") as f: + data = pickle.load(f) + + filtered_data = {} + for entry in data: + step = entry.pop("step") + + # Round perplexity and loss + for key, digit in zip(key_to_round, digit_to_round): + if key in entry: + entry[key] = round(entry[key], digit) + + if step in filtered_data: + filtered_data[step].update(entry) + else: + filtered_data[step] = entry + + return filtered_data + + v1_data = read_logs(v1_file) + v2_data = read_logs(v2_file) + v3_data = read_logs(v3_file) + + ## check that loading from v1 to v2 worked + + # first check that the hash of saving is the same as the hash of loading + assert v1_data[5]["inner_model_hash_save"] == v2_data[5]["inner_model_hash_resume"] + assert v1_data[5]["inner_optimizer_hash_save"] == v2_data[5]["inner_optimizer_hash_resume"] + assert v1_data[5]["outer_optimizer_hash_save"] == v2_data[5]["outer_optimizer_hash_resume"] + assert v1_data[5]["outer_model_hash_save"] == v2_data[5]["outer_model_hash_resume"] + + # then we check that the loss and lr value are the same after loading the ckpt + for step, data_v2 in v2_data.items(): + if step == 5: + continue # not testing 5 as ts the one were we restarted from + + data_v1 = v1_data[step] + assert data_v1["Loss"] == data_v2["Loss"] + assert data_v1["inner_lr"] == data_v2["inner_lr"] + assert data_v1["total_tokens"] == data_v2["total_tokens"] + + ## check that the second loading is working + ## why ? We had bugs where ckpt was working but not when the training was resuming + + assert v2_data[10]["inner_model_hash_save"] == v3_data[10]["inner_model_hash_resume"] + assert v2_data[10]["inner_optimizer_hash_save"] == v3_data[10]["inner_optimizer_hash_resume"] + assert v2_data[10]["outer_optimizer_hash_save"] == v3_data[10]["outer_optimizer_hash_resume"] + assert v2_data[10]["outer_model_hash_save"] == v3_data[10]["outer_model_hash_resume"] + + for step, data_v3 in v3_data.items(): + if step == 10: + continue # not testing 10 as ts the one were we restarted from + + data_v2 = v2_data[step] + assert data_v2["Loss"] == data_v3["Loss"] + assert data_v2["inner_lr"] == data_v3["inner_lr"] + assert data_v2["total_tokens"] == data_v3["total_tokens"]