-
Notifications
You must be signed in to change notification settings - Fork 61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
clean up and add ckpt tests #179
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ | |
from typing import Literal | ||
import time | ||
import warnings | ||
import psutil | ||
from pydantic import model_validator | ||
from multiprocessing.process import _children | ||
|
||
|
@@ -24,15 +23,14 @@ | |
|
||
from zeroband.utils import ( | ||
FakeTokenizer, | ||
GPUMemoryMonitor, | ||
PerfCounter, | ||
get_module_signature, | ||
get_optimizer_signature, | ||
get_tensor_list_signature, | ||
) | ||
from zeroband.utils.activation_ckpt import apply_ac_ckpt | ||
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader, DataConfig | ||
from zeroband.utils.metric_logger import WandbMetricLogger, DummyMetricLogger | ||
from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger | ||
from zeroband.utils.monitor import HttpMonitor | ||
from zeroband.models.llama import get_model | ||
from zeroband.utils.profiler import MemoryProfiler | ||
|
@@ -73,12 +71,13 @@ class TrainConfig(BaseConfig): | |
|
||
log_model_hash: bool = False | ||
|
||
memory_monitor: bool = False | ||
memory_profiler: MemoryProfilerConfig | None = None | ||
|
||
sequence_packing: bool = True | ||
attn_fn: Literal["flash", "sdpa"] | None = None | ||
|
||
math_attn: bool = False # slow | ||
|
||
@model_validator(mode="after") | ||
def validate_attn_fn(self): | ||
if self.attn_fn is not None: | ||
|
@@ -127,6 +126,43 @@ def validate_live_recovery_rank_src(self): | |
return self | ||
|
||
|
||
def log_hash_training_state( | ||
config: Config, | ||
model: torch.nn.Module, | ||
inner_optimizer: torch.optim.Optimizer, | ||
diloco: Diloco | None, | ||
metric_logger: MetricLogger, | ||
step: int, | ||
id: str = "", | ||
): | ||
"""Log the hash of the model and optimizer. This function is slow""" | ||
if config.train.log_model_hash: | ||
inner_model_hash = get_module_signature(model) | ||
inner_optimizer_hash = get_optimizer_signature(inner_optimizer) | ||
|
||
logger.debug(f"inner diloco model {id} : {inner_model_hash}") | ||
logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}") | ||
|
||
metrics = { | ||
"step": step, | ||
f"inner_model_hash_{id}": inner_model_hash, | ||
f"inner_optimizer_hash_{id}": inner_optimizer_hash, | ||
} | ||
|
||
if config.diloco is not None and diloco is not None: | ||
outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer) | ||
outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu) | ||
|
||
logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}") | ||
logger.debug(f"outer diloco model hash {id} : {outer_model_hash}") | ||
|
||
metrics.update( | ||
{f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash} | ||
) | ||
if world_info.rank == 0: | ||
metric_logger.log(metrics) | ||
|
||
|
||
def train(config: Config): | ||
# batch_size is the total batch size for all GPUs | ||
assert config.optim.batch_size % world_info.local_world_size == 0 | ||
|
@@ -164,6 +200,7 @@ def train(config: Config): | |
config.type_model, | ||
vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE, | ||
seq_length=config.data.seq_length, | ||
math_attn=config.train.math_attn, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think about passing attn_fn instead? Would also allow sdpa to be specified |
||
) | ||
|
||
model = model.to(world_info.local_rank) | ||
|
@@ -244,6 +281,16 @@ def train(config: Config): | |
diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, | ||
) | ||
|
||
if world_info.rank == 0: | ||
logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger | ||
metric_logger = logger_cls( | ||
project=config.project, | ||
config={"config": config.model_dump(), "world_info": world_info.json()}, | ||
resume=config.wandb_resume, | ||
) | ||
else: | ||
metric_logger = None | ||
|
||
if config.train.torch_compile: | ||
# we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY | ||
model = torch.compile(model) | ||
|
@@ -256,24 +303,10 @@ def train(config: Config): | |
skip_dataloader=config.ckpt.skip_dataloader, | ||
data_path=config.ckpt.data_path, | ||
) | ||
if config.train.log_model_hash: | ||
logger.info(f"model hash: {get_module_signature(model)}") | ||
logger.info(f"optimizer hash: {get_optimizer_signature(inner_optimizer)}") | ||
|
||
if config.diloco is not None: | ||
logger.info(f"outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}") | ||
logger.info(f"outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}") | ||
|
||
if world_info.rank == 0: | ||
logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger | ||
metric_logger = logger_cls( | ||
project=config.project, | ||
config={"config": config.model_dump(), "world_info": world_info.json()}, | ||
resume=config.wandb_resume, | ||
log_hash_training_state( | ||
config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume" | ||
) | ||
|
||
if config.train.memory_monitor: | ||
gpu_mem_monitor = GPUMemoryMonitor() | ||
if config.train.memory_profiler is not None: | ||
memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir) | ||
|
||
|
@@ -305,15 +338,6 @@ def train(config: Config): | |
maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to() | ||
if maybe_dest_rank is not None: | ||
logger.info(f"Start live recovery to rank {maybe_dest_rank}") | ||
if config.train.log_model_hash: | ||
logger.info( | ||
f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}" | ||
) | ||
logger.info( | ||
f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}" | ||
) | ||
logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}") | ||
|
||
ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True) | ||
|
||
elastic_device_mesh.live_recovery.reset() | ||
|
@@ -328,13 +352,15 @@ def train(config: Config): | |
|
||
ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg) | ||
|
||
if config.train.log_model_hash: | ||
logger.info( | ||
f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}" | ||
) | ||
logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}") | ||
logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}") | ||
|
||
log_hash_training_state( | ||
config, | ||
model, | ||
inner_optimizer, | ||
diloco, | ||
metric_logger, | ||
step=training_progress.step, | ||
id="live_reco_recv", | ||
) | ||
need_live_recovery = False | ||
|
||
if config.ckpt.remote_data_load: | ||
|
@@ -416,7 +442,6 @@ def train(config: Config): | |
# we count the total tokens with respect to all diloco workers | ||
# might need to tweak this as some worker might fail to join the all reduce later | ||
training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size() | ||
remaining_cpu_ram = psutil.virtual_memory().available / (1024 * 1024 * 1024) | ||
|
||
metrics = { | ||
"Loss": loss_batch.item(), | ||
|
@@ -425,16 +450,11 @@ def train(config: Config): | |
"Perplexity": torch.exp(loss_batch).item(), | ||
"total_tokens": training_progress.total_tokens, | ||
"time": time.time(), | ||
"remaining_cpu_ram": remaining_cpu_ram, | ||
} | ||
|
||
if config.optim.z_loss: | ||
metrics["z_loss"] = z_loss_batch.item() | ||
|
||
if config.train.memory_monitor: | ||
peak_gpu_stats = gpu_mem_monitor.get_peak_stats() | ||
metrics.update(peak_gpu_stats) | ||
|
||
log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}" | ||
|
||
tokens_per_second = perf_counter.get_tokens_per_second() | ||
|
@@ -461,21 +481,16 @@ def train(config: Config): | |
memory_profiler.step() | ||
|
||
if config.diloco is not None: | ||
if config.train.log_model_hash: | ||
logger.debug("Pre diloco model: %s", get_module_signature(model)) | ||
|
||
if world_info.rank == 0 and config.monitor is not None: | ||
monitor.set_stage("outer_loop") | ||
|
||
time_start_inner = time.perf_counter() | ||
diloco.step(model=model, flag=training_progress.outer_step) | ||
diloco_time = time.perf_counter() - time_start_inner | ||
|
||
if config.train.log_model_hash: | ||
logger.debug("inner diloco model: %s", get_module_signature(model)) | ||
logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}") | ||
logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}") | ||
logger.debug(f"outer diloco model hash: {get_tensor_list_signature(diloco.param_list_cpu)}") | ||
log_hash_training_state( | ||
config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="outer_step" | ||
) | ||
|
||
training_progress.outer_step += 1 | ||
|
||
|
@@ -488,13 +503,9 @@ def train(config: Config): | |
|
||
do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0 | ||
ckpt_manager.save(remote=do_remote) | ||
if config.train.log_model_hash: | ||
logger.debug("Post saved model: %s", get_module_signature(model)) | ||
logger.debug("Post saved optimizer: %s", get_optimizer_signature(inner_optimizer)) | ||
|
||
if config.diloco is not None: | ||
logger.debug("Post saved outer model: %s", get_tensor_list_signature(diloco.param_list_cpu)) | ||
logger.debug("optimizer hash: %s", get_optimizer_signature(diloco.outer_optimizer)) | ||
log_hash_training_state( | ||
config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save" | ||
) | ||
|
||
if config.diloco: | ||
tokens_per_second = ( | ||
|
@@ -517,9 +528,6 @@ def train(config: Config): | |
} | ||
) | ||
|
||
if config.train.memory_monitor: | ||
logger.info(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}") | ||
|
||
if training_progress.step >= config.optim.total_steps: | ||
# we only allow to break outisde of the inner loop. | ||
# This avoid ending the training in the middle of a the inner loop | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about putting this as an option in
attn_fn
instead?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm attn_fn is not used anymore. I just kept it to avoid conflict with old code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pr to remove attn_fn #180