clean up and add ckpt tests #179

Merged 3 commits on Dec 23, 2024
9 changes: 7 additions & 2 deletions src/zeroband/data.py
@@ -46,18 +46,23 @@ def __init__(self, seq_len: int, vocab_size: int):
         self.seq_len = seq_len
         self.vocab_size = vocab_size
         assert vocab_size > 3, "Vocab size must be greater than 3"
+        self.step = 0

     def __iter__(self) -> Generator[dict[str, Any], Any, None]:
         while True:
             len_ = random.randint(1, self.seq_len)
             input_ids = torch.randint(3, self.vocab_size, (len_,)).tolist()
+            self.step += 1
             yield {"input_ids": input_ids}

     def state_dict(self):
-        return {}
+        return {"step": self.step}

     def load_state_dict(self, state_dict):
-        pass
+        self.step = state_dict["step"]
+        itera = iter(self)
+        for _ in range(self.step):
+            next(itera)


 class BatchOutput(TypedDict):
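For illustration, the hunk above follows the usual stateful-dataset checkpoint contract: count the samples yielded, report the count from state_dict(), and fast-forward by replaying the iterator in load_state_dict(). Below is a self-contained sketch of that round trip; the class name, import choices, and sizes are placeholders, not the repo's actual API, and only the method behavior mirrors the diff.

    # Sketch only: "FakeSeqDataset" is a placeholder name, not the repo's class.
    import random
    from typing import Any, Generator

    import torch
    from torch.utils.data import IterableDataset


    class FakeSeqDataset(IterableDataset):
        def __init__(self, seq_len: int, vocab_size: int):
            self.seq_len = seq_len
            self.vocab_size = vocab_size
            self.step = 0

        def __iter__(self) -> Generator[dict[str, Any], Any, None]:
            while True:
                len_ = random.randint(1, self.seq_len)
                input_ids = torch.randint(3, self.vocab_size, (len_,)).tolist()
                self.step += 1
                yield {"input_ids": input_ids}

        def state_dict(self) -> dict[str, int]:
            return {"step": self.step}

        def load_state_dict(self, state_dict: dict[str, int]) -> None:
            # Fast-forward by replaying the samples that were already consumed.
            self.step = state_dict["step"]
            it = iter(self)
            for _ in range(state_dict["step"]):
                next(it)


    # Round trip: consume a few samples, checkpoint, resume a fresh instance.
    ds = FakeSeqDataset(seq_len=64, vocab_size=128)
    it = iter(ds)
    for _ in range(5):
        next(it)
    ckpt = ds.state_dict()  # {"step": 5}

    resumed = FakeSeqDataset(seq_len=64, vocab_size=128)
    resumed.load_state_dict(ckpt)  # replays 5 samples before yielding new ones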
2 changes: 2 additions & 0 deletions src/zeroband/models/llama/__init__.py
@@ -85,6 +85,7 @@ def get_model(
     type_model: str,
     vocab_size: int,
     seq_length: int,
+    math_attn: bool,
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""

@@ -97,5 +98,6 @@

     config.vocab_size = vocab_size
     config.max_seq_len = seq_length
+    config.math_attn = math_attn

     return Transformer(config), config
9 changes: 8 additions & 1 deletion src/zeroband/models/llama/model.py
@@ -11,6 +11,7 @@
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.


+import contextlib
 from dataclasses import dataclass
 from typing import Optional, Tuple

@@ -20,6 +21,7 @@
 from zeroband.models.norms import build_norm

 from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
+from torch.nn.attention import SDPBackend, sdpa_kernel

 _flex_attention_compiled = torch.compile(flex_attention, dynamic=False)

@@ -58,6 +60,8 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "fused_rmsnorm"

+    math_attn: bool = False  # slow for testing
+

 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
     """
@@ -222,6 +226,8 @@ def __init__(self, model_args: ModelArgs):
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)

+        self.math_attn = model_args.math_attn
+
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
@@ -271,7 +277,8 @@ def forward(
         return self.wo(output)

     def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        with sdpa_kernel(SDPBackend.MATH) if self.math_attn else contextlib.nullcontext():
+            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
         output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         return output
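For illustration, the change above routes scaled_dot_product_attention through the math SDP backend whenever math_attn is set, using the sdpa_kernel context manager imported in the hunk. A minimal standalone sketch of that mechanism follows; the shapes and the local math_attn variable are arbitrary stand-ins, not the repo's code.

    import contextlib

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    math_attn = True  # stand-in for the new ModelArgs.math_attn flag

    # When math_attn is set, force the reference ("math") kernel; otherwise
    # leave backend selection untouched via a null context, as the diff does.
    with sdpa_kernel(SDPBackend.MATH) if math_attn else contextlib.nullcontext():
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    print(out.shape)  # torch.Size([1, 8, 16, 64])

The diff's own comments mark this path as slow and intended for testing rather than production training.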
128 changes: 68 additions & 60 deletions src/zeroband/train.py
@@ -2,7 +2,6 @@
 from typing import Literal
 import time
 import warnings
-import psutil
 from pydantic import model_validator
 from multiprocessing.process import _children

@@ -24,15 +23,14 @@

 from zeroband.utils import (
     FakeTokenizer,
-    GPUMemoryMonitor,
     PerfCounter,
     get_module_signature,
     get_optimizer_signature,
     get_tensor_list_signature,
 )
 from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader, DataConfig
-from zeroband.utils.metric_logger import WandbMetricLogger, DummyMetricLogger
+from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger
 from zeroband.utils.monitor import HttpMonitor
 from zeroband.models.llama import get_model
 from zeroband.utils.profiler import MemoryProfiler
@@ -73,12 +71,13 @@ class TrainConfig(BaseConfig):

     log_model_hash: bool = False

-    memory_monitor: bool = False
     memory_profiler: MemoryProfilerConfig | None = None

     sequence_packing: bool = True
     attn_fn: Literal["flash", "sdpa"] | None = None

+    math_attn: bool = False  # slow
Review comment (Member):
What do you think about putting this as an option in attn_fn instead?

Reply (Collaborator, PR author):

    attn_fn: Literal["flash", "sdpa"] | None = None

    @model_validator(mode="after")
    def validate_attn_fn(self):
        if self.attn_fn is not None:
            warnings.warn("attn_fn argument is deprecated")

        return self

hmm attn_fn is not used anymore. I just kept it to avoid conflict with old code.

Reply (Collaborator, PR author):
pr to remove attn_fn #180

     @model_validator(mode="after")
     def validate_attn_fn(self):
         if self.attn_fn is not None:

@@ -127,6 +126,43 @@ def validate_live_recovery_rank_src(self):
         return self

+def log_hash_training_state(
+    config: Config,
+    model: torch.nn.Module,
+    inner_optimizer: torch.optim.Optimizer,
+    diloco: Diloco | None,
+    metric_logger: MetricLogger,
+    step: int,
+    id: str = "",
+):
+    """Log the hash of the model and optimizer. This function is slow"""
+    if config.train.log_model_hash:
+        inner_model_hash = get_module_signature(model)
+        inner_optimizer_hash = get_optimizer_signature(inner_optimizer)
+
+        logger.debug(f"inner diloco model {id} : {inner_model_hash}")
+        logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}")
+
+        metrics = {
+            "step": step,
+            f"inner_model_hash_{id}": inner_model_hash,
+            f"inner_optimizer_hash_{id}": inner_optimizer_hash,
+        }
+
+        if config.diloco is not None and diloco is not None:
+            outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
+            outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu)
+
+            logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
+            logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
+
+            metrics.update(
+                {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+            )
+        if world_info.rank == 0:
+            metric_logger.log(metrics)
+
+
 def train(config: Config):
     # batch_size is the total batch size for all GPUs
     assert config.optim.batch_size % world_info.local_world_size == 0
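For illustration, log_hash_training_state exists so the new checkpoint tests can assert state equality cheaply: hash the model and optimizer before saving and again after resuming, then compare strings. Below is a rough standalone sketch of that idea with a hand-rolled digest; the repo's get_module_signature and get_optimizer_signature may be implemented quite differently.

    import hashlib
    import io

    import torch


    def module_signature(module: torch.nn.Module) -> str:
        """Digest every parameter/buffer of a module into a short hex string."""
        h = hashlib.sha256()
        for name, tensor in sorted(module.state_dict().items()):
            h.update(name.encode())
            h.update(tensor.detach().cpu().numpy().tobytes())
        return h.hexdigest()[:16]


    model = torch.nn.Linear(4, 4)
    before = module_signature(model)

    # Save and reload through an in-memory checkpoint, then compare signatures.
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    buf.seek(0)
    model.load_state_dict(torch.load(buf, weights_only=True))

    assert module_signature(model) == before  # resume reproduced the weights exactly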
@@ -164,6 +200,7 @@ def train(config: Config):
         config.type_model,
         vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
         seq_length=config.data.seq_length,
+        math_attn=config.train.math_attn,
Review comment (Member):
What do you think about passing attn_fn instead? Would also allow sdpa to be specified.
     )

     model = model.to(world_info.local_rank)

@@ -244,6 +281,16 @@ def train(config: Config):
         diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,
     )

+    if world_info.rank == 0:
+        logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
+        metric_logger = logger_cls(
+            project=config.project,
+            config={"config": config.model_dump(), "world_info": world_info.json()},
+            resume=config.wandb_resume,
+        )
+    else:
+        metric_logger = None
+
     if config.train.torch_compile:
         # we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY
         model = torch.compile(model)
@@ -256,24 +303,10 @@ def train(config: Config):
             skip_dataloader=config.ckpt.skip_dataloader,
             data_path=config.ckpt.data_path,
         )
-        if config.train.log_model_hash:
-            logger.info(f"model hash: {get_module_signature(model)}")
-            logger.info(f"optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
-            if config.diloco is not None:
-                logger.info(f"outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                logger.info(f"outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
-
-    if world_info.rank == 0:
-        logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
-        metric_logger = logger_cls(
-            project=config.project,
-            config={"config": config.model_dump(), "world_info": world_info.json()},
-            resume=config.wandb_resume,
+        log_hash_training_state(
+            config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
         )

-    if config.train.memory_monitor:
-        gpu_mem_monitor = GPUMemoryMonitor()
     if config.train.memory_profiler is not None:
         memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir)
@@ -305,15 +338,6 @@ def train(config: Config):
            maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to()
            if maybe_dest_rank is not None:
                logger.info(f"Start live recovery to rank {maybe_dest_rank}")
-                if config.train.log_model_hash:
-                    logger.info(
-                        f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}"
-                    )
-                    logger.info(
-                        f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}"
-                    )
-                    logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
                ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True)

                elastic_device_mesh.live_recovery.reset()
@@ -328,13 +352,15 @@ def train(config: Config):

                ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)

-                if config.train.log_model_hash:
-                    logger.info(
-                        f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}"
-                    )
-                    logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
-                    logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
+                log_hash_training_state(
+                    config,
+                    model,
+                    inner_optimizer,
+                    diloco,
+                    metric_logger,
+                    step=training_progress.step,
+                    id="live_reco_recv",
+                )
                need_live_recovery = False

            if config.ckpt.remote_data_load:
@@ -416,7 +442,6 @@ def train(config: Config):
            # we count the total tokens with respect to all diloco workers
            # might need to tweak this as some worker might fail to join the all reduce later
            training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size()
-            remaining_cpu_ram = psutil.virtual_memory().available / (1024 * 1024 * 1024)

            metrics = {
                "Loss": loss_batch.item(),
@@ -425,16 +450,11 @@ def train(config: Config):
                "Perplexity": torch.exp(loss_batch).item(),
                "total_tokens": training_progress.total_tokens,
                "time": time.time(),
-                "remaining_cpu_ram": remaining_cpu_ram,
            }

            if config.optim.z_loss:
                metrics["z_loss"] = z_loss_batch.item()

-            if config.train.memory_monitor:
-                peak_gpu_stats = gpu_mem_monitor.get_peak_stats()
-                metrics.update(peak_gpu_stats)
-
            log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}"

            tokens_per_second = perf_counter.get_tokens_per_second()
@@ -461,21 +481,16 @@ def train(config: Config):
                memory_profiler.step()

        if config.diloco is not None:
-            if config.train.log_model_hash:
-                logger.debug("Pre diloco model: %s", get_module_signature(model))
-
            if world_info.rank == 0 and config.monitor is not None:
                monitor.set_stage("outer_loop")

            time_start_inner = time.perf_counter()
            diloco.step(model=model, flag=training_progress.outer_step)
            diloco_time = time.perf_counter() - time_start_inner

-            if config.train.log_model_hash:
-                logger.debug("inner diloco model: %s", get_module_signature(model))
-                logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                logger.debug(f"outer diloco model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
+            log_hash_training_state(
+                config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="outer_step"
+            )

            training_progress.outer_step += 1

@@ -488,13 +503,9 @@ def train(config: Config):

            do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0
            ckpt_manager.save(remote=do_remote)
-            if config.train.log_model_hash:
-                logger.debug("Post saved model: %s", get_module_signature(model))
-                logger.debug("Post saved optimizer: %s", get_optimizer_signature(inner_optimizer))
-
-                if config.diloco is not None:
-                    logger.debug("Post saved outer model: %s", get_tensor_list_signature(diloco.param_list_cpu))
-                    logger.debug("optimizer hash: %s", get_optimizer_signature(diloco.outer_optimizer))
+            log_hash_training_state(
+                config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save"
+            )

        if config.diloco:
            tokens_per_second = (
@@ -517,9 +528,6 @@ def train(config: Config):
                }
            )

-            if config.train.memory_monitor:
-                logger.info(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}")
-
        if training_progress.step >= config.optim.total_steps:
            # we only allow to break outisde of the inner loop.
            # This avoid ending the training in the middle of a the inner loop