use shared function for log modelhash
samsja committed Dec 19, 2024
1 parent 3387a27 commit 79e9f17
Showing 1 changed file with 49 additions and 47 deletions.
96 changes: 49 additions & 47 deletions src/zeroband/train.py
@@ -32,7 +32,7 @@
 )
 from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader, DataConfig
-from zeroband.utils.metric_logger import WandbMetricLogger, DummyMetricLogger
+from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger
 from zeroband.utils.monitor import HttpMonitor
 from zeroband.models.llama import get_model
 from zeroband.utils.profiler import MemoryProfiler
@@ -127,6 +127,40 @@ def validate_live_recovery_rank_src(self):
         return self
 
 
+def log_hash_training_state(
+    config: Config,
+    model: torch.nn.Module,
+    inner_optimizer: torch.optim.Optimizer,
+    diloco: Diloco | None,
+    metric_logger: MetricLogger,
+    id: str = "",
+):
+    """Log the hash of the model and optimizer. This function is slow"""
+    if config.train.log_model_hash:
+        inner_model_hash = get_module_signature(model)
+        inner_optimizer_hash = get_optimizer_signature(inner_optimizer)
+
+        logger.debug(f"inner diloco model {id} : {inner_model_hash}")
+        logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}")
+
+        if world_info.rank == 0:
+            metric_logger.log(
{"inner_model_hash_{id}": inner_model_hash, "inner_optimizer_hash_{id}": inner_optimizer_hash}
+            )
+
+        if config.diloco is not None and diloco is not None:
+            outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
+            outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu)
+
+            logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
+            logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
+
+            if world_info.rank == 0:
+                metric_logger.log(
+                    {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+                )
+
+
 def train(config: Config):
     # batch_size is the total batch size for all GPUs
     assert config.optim.batch_size % world_info.local_world_size == 0
@@ -244,6 +278,16 @@ def train(config: Config):
         diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,
     )
 
+    if world_info.rank == 0:
+        logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
+        metric_logger = logger_cls(
+            project=config.project,
+            config={"config": config.model_dump(), "world_info": world_info.json()},
+            resume=config.wandb_resume,
+        )
+    else:
+        metric_logger = None
+
     if config.train.torch_compile:
         # we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY
         model = torch.compile(model)
@@ -256,21 +300,7 @@
             skip_dataloader=config.ckpt.skip_dataloader,
             data_path=config.ckpt.data_path,
         )
-        if config.train.log_model_hash:
-            logger.info(f"model hash: {get_module_signature(model)}")
-            logger.info(f"optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
-            if config.diloco is not None:
-                logger.info(f"outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                logger.info(f"outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
-
-    if world_info.rank == 0:
-        logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
-        metric_logger = logger_cls(
-            project=config.project,
-            config={"config": config.model_dump(), "world_info": world_info.json()},
-            resume=config.wandb_resume,
-        )
+        log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="resume")
 
     if config.train.memory_monitor:
         gpu_mem_monitor = GPUMemoryMonitor()
@@ -305,15 +335,6 @@ def train(config: Config):
             maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to()
             if maybe_dest_rank is not None:
                 logger.info(f"Start live recovery to rank {maybe_dest_rank}")
-                if config.train.log_model_hash:
-                    logger.info(
-                        f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}"
-                    )
-                    logger.info(
-                        f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}"
-                    )
-                    logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
                 ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True)
 
                 elastic_device_mesh.live_recovery.reset()
@@ -328,13 +349,7 @@
 
                 ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)
 
-                if config.train.log_model_hash:
-                    logger.info(
-                        f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}"
-                    )
-                    logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
-                    logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
+                log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="live_reco_recv")
                 need_live_recovery = False
 
             if config.ckpt.remote_data_load:
@@ -461,21 +476,14 @@ def train(config: Config):
                 memory_profiler.step()
 
         if config.diloco is not None:
-            if config.train.log_model_hash:
-                logger.debug("Pre diloco model: %s", get_module_signature(model))
-
             if world_info.rank == 0 and config.monitor is not None:
                 monitor.set_stage("outer_loop")
 
             time_start_inner = time.perf_counter()
             diloco.step(model=model, flag=training_progress.outer_step)
             diloco_time = time.perf_counter() - time_start_inner
 
-            if config.train.log_model_hash:
-                logger.debug("inner diloco model: %s", get_module_signature(model))
-                logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                logger.debug(f"outer diloco model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
+            log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="outer_step")
 
             training_progress.outer_step += 1
 
@@ -488,13 +496,7 @@
 
             do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0
             ckpt_manager.save(remote=do_remote)
-            if config.train.log_model_hash:
-                logger.debug("Post saved model: %s", get_module_signature(model))
-                logger.debug("Post saved optimizer: %s", get_optimizer_signature(inner_optimizer))
-
-                if config.diloco is not None:
-                    logger.debug("Post saved outer model: %s", get_tensor_list_signature(diloco.param_list_cpu))
-                    logger.debug("optimizer hash: %s", get_optimizer_signature(diloco.outer_optimizer))
+            log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="ckpt save")
 
         if config.diloco:
             tokens_per_second = (
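
The refactor above replaces the ad-hoc logger.info / logger.debug hash dumps at each checkpoint boundary (resume, live-recovery receive, outer step, checkpoint save) with one shared helper. Below is a minimal, self-contained sketch of the same pattern; tensor_signature and PrintMetricLogger are illustrative stand-ins for zeroband's get_module_signature / get_optimizer_signature helpers and MetricLogger interface, not the actual implementations, and the config/diloco plumbing is omitted.

import hashlib
import logging

import torch

logger = logging.getLogger(__name__)


def tensor_signature(tensors) -> str:
    """Hash an iterable of tensors into a short hex digest (illustrative stand-in)."""
    h = hashlib.sha256()
    for t in tensors:
        h.update(t.detach().cpu().numpy().tobytes())
    return h.hexdigest()[:16]


class PrintMetricLogger:
    """Stand-in for a metric logger exposing a .log(dict) method."""

    def log(self, metrics: dict):
        print(metrics)


def log_hash_training_state(model, optimizer, metric_logger, id: str = ""):
    """Log model and optimizer hashes under one tag, mirroring the shared helper."""
    model_hash = tensor_signature(model.parameters())
    optim_hash = tensor_signature(p for g in optimizer.param_groups for p in g["params"])
    logger.debug(f"model hash {id}: {model_hash}")
    logger.debug(f"optimizer hash {id}: {optim_hash}")
    metric_logger.log({f"model_hash_{id}": model_hash, f"optimizer_hash_{id}": optim_hash})


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    metric_logger = PrintMetricLogger()
    # one call per checkpoint boundary, tagged like the calls added in this commit
    for tag in ("resume", "live_reco_recv", "outer_step", "ckpt save"):
        log_hash_training_state(model, optimizer, metric_logger, id=tag)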
