refactor math attn to attn fn
samsja committed Jan 6, 2025
1 parent 4715633 commit 5a76acb
Showing 4 changed files with 19 additions and 22 deletions.
6 changes: 3 additions & 3 deletions src/zeroband/models/llama/__init__.py
@@ -7,7 +7,7 @@
 # Llama 2 is licensed under the LLAMA 2 Community License,
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

-from zeroband.models.llama.model import ModelArgs, Transformer
+from zeroband.models.llama.model import AttnFnType, ModelArgs, Transformer

 __all__ = ["Transformer"]


@@ -85,7 +85,7 @@ def get_model(
     type_model: str,
     vocab_size: int,
     seq_length: int,
-    math_attn: bool,
+    attn_fn: AttnFnType,
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""


@@ -98,6 +98,6 @@

     config.vocab_size = vocab_size
     config.max_seq_len = seq_length
-    config.math_attn = math_attn
+    config.attn_fn = attn_fn

     return Transformer(config), config
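
For reference, a minimal sketch (not taken from the repository) of what this signature change buys: with attn_fn typed as a Literal alias instead of a bool, a type checker rejects unsupported attention names at the call site. build_model below is a hypothetical stand-in, not the project's get_model.

from typing import Literal, TypeAlias

AttnFnType: TypeAlias = Literal["flex", "math"]


def build_model(attn_fn: AttnFnType) -> str:
    # The Literal narrows the accepted values to exactly "flex" and "math".
    return f"model built with {attn_fn} attention"


build_model("math")   # ok
build_model("flex")   # ok
# build_model("sdpa")  # flagged by mypy/pyright: not a member of AttnFnType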
11 changes: 7 additions & 4 deletions src/zeroband/models/llama/model.py
@@ -13,7 +13,7 @@

 import contextlib
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Literal, Optional, Tuple, TypeAlias

 import torch
 import torch.nn.functional as F


@@ -41,6 +41,9 @@ def flex_attention_compiled(
     return _flex_attention_compiled(q, k, v, block_mask=block_mask)


+AttnFnType: TypeAlias = Literal["flex", "math"]
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096


@@ -60,7 +63,7 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "fused_rmsnorm"

-    math_attn: bool = False  # slow for testing
+    attn_fn: AttnFnType = "flex"  # slow for testing


 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:


@@ -226,7 +229,7 @@ def __init__(self, model_args: ModelArgs):
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)

-        self.math_attn = model_args.math_attn
+        self.attn_fn = model_args.attn_fn

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):


@@ -277,7 +280,7 @@ def forward(
         return self.wo(output)

     def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
-        with sdpa_kernel(SDPBackend.MATH) if self.math_attn else contextlib.nullcontext():
+        with sdpa_kernel(SDPBackend.MATH) if self.attn_fn == "math" else contextlib.nullcontext():
             output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
             output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         return output
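
A minimal, self-contained sketch of the backend-selection pattern used in _sdpa_attention above, assuming PyTorch 2.3+ where torch.nn.attention.sdpa_kernel is available; the tensor shapes are made up for illustration and not taken from the model.

import contextlib

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def sdpa(xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, attn_fn: str = "flex") -> torch.Tensor:
    # Pin SDPA to the reference math backend only when "math" is requested;
    # otherwise leave kernel selection to PyTorch's defaults.
    ctx = sdpa_kernel(SDPBackend.MATH) if attn_fn == "math" else contextlib.nullcontext()
    with ctx:
        return F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)


q = k = v = torch.randn(2, 8, 16, 64)  # (bs, n_heads, seqlen, head_dim), illustrative sizes
out_math = sdpa(q, k, v, attn_fn="math")
out_default = sdpa(q, k, v)  # both paths should agree up to small numerical differences

The math backend is the slow reference implementation, which is why the diff keeps it behind the "math" option for testing only.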
15 changes: 3 additions & 12 deletions src/zeroband/train.py
@@ -1,7 +1,6 @@
 import os
 from typing import Literal
 import time
-import warnings
 from pydantic import model_validator
 from multiprocessing.process import _children


@@ -19,7 +18,7 @@
 from zeroband.diloco import Diloco, DilocoConfig
 from zeroband.comms import ElasticDeviceMesh
 from zeroband.loss import cross_entropy_max_z_loss
-from zeroband.models.llama.model import create_block_mask_from_seqlens
+from zeroband.models.llama.model import AttnFnType, create_block_mask_from_seqlens

 from zeroband.utils import (
     FakeTokenizer,


@@ -74,16 +73,8 @@ class TrainConfig(BaseConfig):
     memory_profiler: MemoryProfilerConfig | None = None

     sequence_packing: bool = True
-    attn_fn: Literal["flash", "sdpa"] | None = None
-
-    math_attn: bool = False  # slow
-
-    @model_validator(mode="after")
-    def validate_attn_fn(self):
-        if self.attn_fn is not None:
-            warnings.warn("attn_fn argument is deprecated")
-
-        return self
+    attn_fn: AttnFnType = "flex"


 class MonitorConfig(BaseConfig):


@@ -200,7 +191,7 @@ def train(config: Config):
         config.type_model,
         vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
         seq_length=config.data.seq_length,
-        math_attn=config.train.math_attn,
+        attn_fn=config.train.attn_fn,
     )

     model = model.to(world_info.local_rank)
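
A simplified sketch of how the new config field behaves, assuming BaseConfig is a pydantic v2 model (Cfg below uses pydantic.BaseModel directly as a stand-in): a value outside the Literal is rejected at config-parse time, which replaces the old deprecation-warning validator.

from typing import Literal

from pydantic import BaseModel, ValidationError


class Cfg(BaseModel):
    attn_fn: Literal["flex", "math"] = "flex"


print(Cfg().attn_fn)                # flex (default)
print(Cfg(attn_fn="math").attn_fn)  # math

try:
    Cfg(attn_fn="sdpa")
except ValidationError as err:
    print("rejected:", err.errors()[0]["type"])  # literal_error in pydantic v2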
9 changes: 6 additions & 3 deletions tests/test_torchrun/test_train.py
@@ -144,7 +144,8 @@ def test_ckpt(tmp_path: Path):
             "20",
             "--train.log_model_hash",
             "--no-train.sequence_packing",
-            "--train.math_attn",
+            "--train.attn_fn",
+            "math",
         ],
         diloco=True,
     )


@@ -164,7 +165,8 @@ def test_ckpt(tmp_path: Path):
             "20",
             "--train.log_model_hash",
             "--no-train.sequence_packing",
-            "--train.math_attn",
+            "--train.attn_fn",
+            "math",
         ],
         diloco=True,
     )


@@ -184,7 +186,8 @@ def test_ckpt(tmp_path: Path):
 #             "20",
 #             "--train.log_model_hash",
 #             "--no-train.sequence_packing",
-#             "--train.math_attn",
+#             "--train.attn_fn",
+#             "math",
 #         ],
 #         diloco=True,
 #     )
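
Note that on the command line the option is no longer a boolean switch: the tests now pass a value ("--train.attn_fn math") where they previously toggled "--train.math_attn", and omitting the flag keeps the "flex" default.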
