diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py index ff142d26..7e55aafb 100644 --- a/src/zeroband/models/llama/__init__.py +++ b/src/zeroband/models/llama/__init__.py @@ -7,7 +7,7 @@ # Llama 2 is licensed under the LLAMA 2 Community License, # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from zeroband.models.llama.model import ModelArgs, Transformer +from zeroband.models.llama.model import AttnFnType, ModelArgs, Transformer __all__ = ["Transformer"] @@ -85,7 +85,7 @@ def get_model( type_model: str, vocab_size: int, seq_length: int, - math_attn: bool, + attn_fn: AttnFnType, ) -> tuple[Transformer, ModelArgs]: """get the transformer model""" @@ -98,6 +98,6 @@ def get_model( config.vocab_size = vocab_size config.max_seq_len = seq_length - config.math_attn = math_attn + config.attn_fn = attn_fn return Transformer(config), config diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py index c1a63403..cb767790 100644 --- a/src/zeroband/models/llama/model.py +++ b/src/zeroband/models/llama/model.py @@ -13,7 +13,7 @@ import contextlib from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple, TypeAlias import torch import torch.nn.functional as F @@ -41,6 +41,9 @@ def flex_attention_compiled( return _flex_attention_compiled(q, k, v, block_mask=block_mask) +AttnFnType: TypeAlias = Literal["flex", "math"] + + @dataclass class ModelArgs: dim: int = 4096 @@ -60,7 +63,7 @@ class ModelArgs: depth_init: bool = True norm_type: str = "fused_rmsnorm" - math_attn: bool = False # slow for testing + attn_fn: AttnFnType = "flex" # slow for testing def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: @@ -226,7 +229,7 @@ def __init__(self, model_args: ModelArgs): self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False) - self.math_attn = model_args.math_attn + self.attn_fn = model_args.attn_fn def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -277,7 +280,7 @@ def forward( return self.wo(output) def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor: - with sdpa_kernel(SDPBackend.MATH) if self.math_attn else contextlib.nullcontext(): + with sdpa_kernel(SDPBackend.MATH) if self.attn_fn == "math" else contextlib.nullcontext(): output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) return output diff --git a/src/zeroband/train.py b/src/zeroband/train.py index a7f6d486..9f4a5f5c 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -1,7 +1,6 @@ import os from typing import Literal import time -import warnings from pydantic import model_validator from multiprocessing.process import _children @@ -19,7 +18,7 @@ from zeroband.diloco import Diloco, DilocoConfig from zeroband.comms import ElasticDeviceMesh from zeroband.loss import cross_entropy_max_z_loss -from zeroband.models.llama.model import create_block_mask_from_seqlens +from zeroband.models.llama.model import AttnFnType, create_block_mask_from_seqlens from zeroband.utils import ( FakeTokenizer, @@ -74,16 +73,8 @@ class TrainConfig(BaseConfig): memory_profiler: MemoryProfilerConfig | None = None sequence_packing: bool = True - attn_fn: Literal["flash", "sdpa"] | None = None - math_attn: bool = False # slow - - @model_validator(mode="after") - def validate_attn_fn(self): - if self.attn_fn is not None: - warnings.warn("attn_fn argument is deprecated") - - return self + attn_fn: AttnFnType = "flex" class MonitorConfig(BaseConfig): @@ -200,7 +191,7 @@ def train(config: Config): config.type_model, vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE, seq_length=config.data.seq_length, - math_attn=config.train.math_attn, + attn_fn=config.train.attn_fn, ) model = model.to(world_info.local_rank) diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index dc36701a..aa04b0c5 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -144,7 +144,8 @@ def test_ckpt(tmp_path: Path): "20", "--train.log_model_hash", "--no-train.sequence_packing", - "--train.math_attn", + "--train.attn_fn", + "math", ], diloco=True, ) @@ -164,7 +165,8 @@ def test_ckpt(tmp_path: Path): "20", "--train.log_model_hash", "--no-train.sequence_packing", - "--train.math_attn", + "--train.attn_fn", + "math", ], diloco=True, ) @@ -184,7 +186,8 @@ def test_ckpt(tmp_path: Path): # "20", # "--train.log_model_hash", # "--no-train.sequence_packing", - # "--train.math_attn", + # "--train.attn_fn", + # "math", # ], # diloco=True, # )