Skip to content

Commit

Permalink
add ckpt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Dec 19, 2024
1 parent 1a3b439 commit 65e7b74
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 15 deletions.
9 changes: 7 additions & 2 deletions src/zeroband/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,23 @@ def __init__(self, seq_len: int, vocab_size: int):
self.seq_len = seq_len
self.vocab_size = vocab_size
assert vocab_size > 3, "Vocab size must be greater than 3"
self.step = 0

def __iter__(self) -> Generator[dict[str, Any], Any, None]:
    """Yield an endless stream of random token samples.

    Each sample is a dict with a single ``"input_ids"`` key holding a list
    of between 1 and ``self.seq_len`` integers drawn from
    ``[3, self.vocab_size)``. ``self.step`` is incremented once per sample,
    right before the sample is handed to the consumer.
    """
    while True:
        # Draw order matters for reproducibility: Python's RNG picks the
        # length first, then torch's RNG fills in the token ids.
        sample_len = random.randint(1, self.seq_len)
        tokens = torch.randint(3, self.vocab_size, (sample_len,))
        self.step += 1
        yield {"input_ids": tokens.tolist()}

def state_dict(self) -> dict[str, int]:
    """Return the state needed to resume this dataset.

    Returns:
        A dict with a single ``"step"`` entry holding ``self.step``, the
        number of samples counted so far.
    """
    # Only the step counter is saved; ``load_state_dict`` reads it back
    # under the same "step" key.
    return {"step": self.step}

def load_state_dict(self, state_dict):
    """Restore the dataset position saved by ``state_dict``.

    The saved number of samples is replayed through ``iter(self)`` and
    discarded, so the random draws made while producing them are consumed
    again.

    Args:
        state_dict: Mapping with a ``"step"`` entry, the number of samples
            that had been produced when the state was saved.

    Raises:
        KeyError: If ``state_dict`` has no ``"step"`` entry.
    """
    target_step = state_dict["step"]

    # Iterating increments ``self.step`` once per sample, so start from zero
    # and let the replay bring the counter back to ``target_step``. Assigning
    # the saved value first would leave the counter at twice that value.
    self.step = 0
    replay = iter(self)
    for _ in range(target_step):
        next(replay)
    # NOTE(review): the replay only reproduces the original samples if the
    # RNGs are seeded the same way as in the run that saved the state;
    # confirm against the caller.


class BatchOutput(TypedDict):
Expand Down
2 changes: 2 additions & 0 deletions src/zeroband/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def get_model(
type_model: str,
vocab_size: int,
seq_length: int,
math_attn: bool,
) -> tuple[Transformer, ModelArgs]:
"""get the transformer model"""

Expand All @@ -97,5 +98,6 @@ def get_model(

config.vocab_size = vocab_size
config.max_seq_len = seq_length
config.math_attn = math_attn

return Transformer(config), config
9 changes: 8 additions & 1 deletion src/zeroband/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


import contextlib
from dataclasses import dataclass
from typing import Optional, Tuple

Expand All @@ -20,6 +21,7 @@
from zeroband.models.norms import build_norm

from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
from torch.nn.attention import SDPBackend, sdpa_kernel

_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

Expand Down Expand Up @@ -58,6 +60,8 @@ class ModelArgs:
depth_init: bool = True
norm_type: str = "fused_rmsnorm"

math_attn: bool = False # slow for testing


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""
Expand Down Expand Up @@ -222,6 +226,8 @@ def __init__(self, model_args: ModelArgs):
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)

self.math_attn = model_args.math_attn

def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
Expand Down Expand Up @@ -271,7 +277,8 @@ def forward(
return self.wo(output)

def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
with sdpa_kernel(SDPBackend.MATH) if self.math_attn else contextlib.nullcontext():
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
return output

Expand Down
44 changes: 32 additions & 12 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class TrainConfig(BaseConfig):
sequence_packing: bool = True
attn_fn: Literal["flash", "sdpa"] | None = None

math_attn: bool = False # slow

@model_validator(mode="after")
def validate_attn_fn(self):
if self.attn_fn is not None:
Expand Down Expand Up @@ -133,6 +135,7 @@ def log_hash_training_state(
inner_optimizer: torch.optim.Optimizer,
diloco: Diloco | None,
metric_logger: MetricLogger,
step: int,
id: str = "",
):
"""Log the hash of the model and optimizer. This function is slow"""
Expand All @@ -143,10 +146,11 @@ def log_hash_training_state(
logger.debug(f"inner diloco model {id} : {inner_model_hash}")
logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}")

if world_info.rank == 0:
metric_logger.log(
{"inner_model_hash_{id}": inner_model_hash, "inner_optimizer_hash_{id}": inner_optimizer_hash}
)
metrics = {
"step": step,
f"inner_model_hash_{id}": inner_model_hash,
f"inner_optimizer_hash_{id}": inner_optimizer_hash,
}

if config.diloco is not None and diloco is not None:
outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
Expand All @@ -155,10 +159,11 @@ def log_hash_training_state(
logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")

if world_info.rank == 0:
metric_logger.log(
{f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
)
metrics.update(
{f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
)
if world_info.rank == 0:
metric_logger.log(metrics)


def train(config: Config):
Expand Down Expand Up @@ -198,6 +203,7 @@ def train(config: Config):
config.type_model,
vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
seq_length=config.data.seq_length,
math_attn=config.train.math_attn,
)

model = model.to(world_info.local_rank)
Expand Down Expand Up @@ -300,7 +306,9 @@ def train(config: Config):
skip_dataloader=config.ckpt.skip_dataloader,
data_path=config.ckpt.data_path,
)
log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="resume")
log_hash_training_state(
config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
)

if config.train.memory_monitor:
gpu_mem_monitor = GPUMemoryMonitor()
Expand Down Expand Up @@ -349,7 +357,15 @@ def train(config: Config):

ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)

log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="live_reco_recv")
log_hash_training_state(
config,
model,
inner_optimizer,
diloco,
metric_logger,
step=training_progress.step,
id="live_reco_recv",
)
need_live_recovery = False

if config.ckpt.remote_data_load:
Expand Down Expand Up @@ -483,7 +499,9 @@ def train(config: Config):
diloco.step(model=model, flag=training_progress.outer_step)
diloco_time = time.perf_counter() - time_start_inner

log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="outer_step")
log_hash_training_state(
config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="outer_step"
)

training_progress.outer_step += 1

Expand All @@ -496,7 +514,9 @@ def train(config: Config):

do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0
ckpt_manager.save(remote=do_remote)
log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="ckpt save")
log_hash_training_state(
config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save"
)

if config.diloco:
tokens_per_second = (
Expand Down
138 changes: 138 additions & 0 deletions tests/test_torchrun/test_train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import os
from pathlib import Path
import pickle
import subprocess
import pytest
import socket
Expand Down Expand Up @@ -112,3 +114,139 @@ def test_packing(packing: bool):
num_gpus = [2, 1]
packing_arg = "--train.sequence_packing" if packing else "--no-train.sequence_packing"
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])


def test_ckpt(tmp_path: Path):
    """Check that training resumes exactly from a checkpoint, twice in a row.

    Three runs are launched through ``_test_multi_gpu``:

    * ``v1`` trains from scratch,
    * ``v2`` resumes from the ``step_5`` checkpoint written by ``v1``,
    * ``v3`` resumes from the ``step_10`` checkpoint written by ``v2``.

    For each resume, the hashes logged at save time must equal the hashes
    logged on resume, and every other step logged by the resumed run must
    report the same loss, inner learning rate and token count as the run it
    resumed from. The second resume exists because there have been bugs
    where checkpointing worked but resuming a resumed training did not.

    Args:
        tmp_path: pytest-provided temporary directory holding the per-run
            log files and checkpoint directories.
    """
    num_gpus = [1, 2]
    runs = ("v1", "v2", "v3")
    log_files = {run: tmp_path / f"{run}.log" for run in runs}
    ckpt_dirs = {run: tmp_path / f"{run}_ckpt" for run in runs}
    for ckpt_dir in ckpt_dirs.values():
        ckpt_dir.mkdir()

    def launch(run: str, resume_from: Path | None = None) -> None:
        """Run one training job, optionally resuming from ``resume_from``."""
        resume_args = [] if resume_from is None else ["--ckpt.resume", str(resume_from)]
        # NOTE(review): math attention and disabled sequence packing are
        # presumably there to make the runs comparable step by step; confirm.
        _test_multi_gpu(
            num_gpus,
            "debug/diloco.toml",
            extra_args=[
                "--project",
                str(log_files[run]),
                "--ckpt.path",
                str(ckpt_dirs[run]),
                "--ckpt.interval",
                "5",
                *resume_args,
                "--optim.total_steps",
                "20",
                "--train.log_model_hash",
                "--no-train.sequence_packing",
                "--train.math_attn",
            ],
            diloco=True,
        )

    # Number of decimals kept per metric before comparing runs.
    rounding = {"Perplexity": 0, "Loss": 2}

    def read_logs(path: Path) -> dict:
        """Load the pickled log entries at ``path`` and merge them per step."""
        # pickle is acceptable here only because the file is produced by the
        # runs launched in this test, never by an external source.
        with path.open("rb") as f:
            entries = pickle.load(f)

        per_step = {}
        for entry in entries:
            step = entry.pop("step")
            for key, digits in rounding.items():
                if key in entry:
                    entry[key] = round(entry[key], digits)
            # Several entries can share a step (metrics, hashes); merge them.
            per_step.setdefault(step, {}).update(entry)
        return per_step

    hash_names = ("inner_model_hash", "inner_optimizer_hash", "outer_optimizer_hash", "outer_model_hash")
    metric_names = ("Loss", "inner_lr", "total_tokens")

    def assert_resumed_correctly(saved: dict, resumed: dict, resume_step: int) -> None:
        """Compare a resumed run against the run that wrote its checkpoint."""
        # First check that the hashes at save time match the hashes at load time.
        for name in hash_names:
            assert (
                saved[resume_step][f"{name}_save"] == resumed[resume_step][f"{name}_resume"]
            ), f"{name} differs between save and resume at step {resume_step}"

        # Then check that loss, lr and token count agree after loading the ckpt.
        for step, resumed_metrics in resumed.items():
            if step == resume_step:
                continue  # not tested: it is the step we restarted from

            saved_metrics = saved[step]
            for name in metric_names:
                assert (
                    saved_metrics[name] == resumed_metrics[name]
                ), f"{name} differs at step {step} after resuming from step {resume_step}"

    launch("v1")
    launch("v2", resume_from=ckpt_dirs["v1"] / "step_5")
    launch("v3", resume_from=ckpt_dirs["v2"] / "step_10")

    v1_data, v2_data, v3_data = (read_logs(log_files[run]) for run in runs)

    # Check that loading from v1 into v2 worked.
    assert_resumed_correctly(v1_data, v2_data, resume_step=5)
    # Check that the second load (a resume of an already-resumed run) works too.
    assert_resumed_correctly(v2_data, v3_data, resume_step=10)

0 comments on commit 65e7b74

Please sign in to comment.