diff --git a/dolomite_engine/checkpointing.py b/dolomite_engine/checkpointing.py index ade1f3c5..0286d966 100644 --- a/dolomite_engine/checkpointing.py +++ b/dolomite_engine/checkpointing.py @@ -39,7 +39,6 @@ def ensure_last_checkpoint_is_saved() -> None: - global _FUTURE if _FUTURE is not None: _FUTURE.result() @@ -194,6 +193,7 @@ def save_checkpoint( save_args(args, save_path, mode=Mode.training) + global _FUTURE _FUTURE = dcp.async_save( { "state": _Saver( diff --git a/dolomite_engine/finetune.py b/dolomite_engine/finetune.py index 24f9e33c..24d1fec5 100644 --- a/dolomite_engine/finetune.py +++ b/dolomite_engine/finetune.py @@ -1,8 +1,9 @@ -from contextlib import nullcontext +from contextlib import AbstractContextManager, nullcontext import torch -from torch.distributed.pipelining.schedules import _PipelineSchedule from torch.distributed.tensor.parallel import loss_parallel +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR from transformers import set_seed from .arguments import TrainingArgs, get_args @@ -11,16 +12,127 @@ from .data import ResumableDataLoader, custom_iterator, get_finetuning_dataloader, get_next_batch from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training from .enums import DatasetSplit, Mode, TuningMethod -from .model_wrapper import get_model_container +from .model_wrapper import ModelWrapper, get_model_container from .optimization import get_optimizer_container, get_scheduler_container -from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics, train_step -from .utils import ExperimentsTracker, MetricsTrackingDict, ProcessGroupManager, init_distributed, setup_tf32 +from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics +from .utils import ( + ExperimentsTracker, + MetricsTrackingDict, + ProcessGroupManager, + StepTracker, + init_distributed, + is_torchao_available, + setup_tf32, +) + + +if is_torchao_available(): + from .distributed import FP8Manager + + +def train_step_without_pipeline_parallel( + model: ModelWrapper, + optimizer: Optimizer, + lr_scheduler: LambdaLR, + train_dataloader: ResumableDataLoader, + gradient_clipping: float, + forward_context: AbstractContextManager, + backward_context: AbstractContextManager, + sync_every_gradient_accumulation_step: bool, +) -> MetricsTrackingDict: + """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary + + Args: + model (ModelWrapper): model + optimizer (Optimizer): optimizer + lr_scheduler (LamdaLR): learning rate scheduler + train_dataloader (ResumableDataLoader): training dataloader + gradient_accumulation_steps (int): gradient accumulation steps + gradient_clipping (float): gradient clipping value + forward_context (AbstractContextManager): a context that is used for every model forward call + backward_context (AbstractContextManager): a context that is used for every model backward call + sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step + + Returns: + MetricsTrackingDict: metrics to track + """ + + fsdp_algorithm = 2 if hasattr(model, "set_requires_gradient_sync") else 1 + + no_sync = nullcontext + if not sync_every_gradient_accumulation_step: + if fsdp_algorithm == 1: + no_sync = model.no_sync + else: + model.set_requires_gradient_sync(False) + + metrics_tracker = MetricsTrackingDict({}) + grad_norm = None + optimizer.zero_grad() + + gradient_accumulation_steps = 
StepTracker.get_gradient_accumulation_steps() + + # note the effect of gradient accumulation division is already in the lm_loss_multiplier + batches = [get_next_batch(train_dataloader) for _ in range(gradient_accumulation_steps)] + lm_loss_multiplier = 1 / sum([(batch["labels"] != -100).sum() for batch in batches]) + + with no_sync(): + for batch in batches[:-1]: + with forward_context(): + loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) + + # compute gradients + with backward_context(): + loss_micro_step_dict["loss"].backward() + + with torch.inference_mode(): + metrics_tracker = metrics_tracker + loss_micro_step_dict + + if fsdp_algorithm == 2: + model.set_requires_gradient_sync(True) + + batch = batches[-1] + with forward_context(): + loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) + + # compute gradients + with backward_context(): + loss_micro_step_dict["loss"].backward() + + with torch.inference_mode(): + metrics_tracker = metrics_tracker + loss_micro_step_dict + + if gradient_clipping is not None: + if fsdp_algorithm == 1: + grad_norm = model.clip_grad_norm_(gradient_clipping) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping) + + if is_torchao_available(): + FP8Manager.sync_float8_amax_and_scale_history([model]) + + optimizer.step() + lr_scheduler.step() + + if is_torchao_available(): + FP8Manager.precompute_float8_dynamic_scale_for_fsdp([model]) + + with torch.inference_mode(): + metrics_tracker["grad_norm"] = ( + torch.tensor(0, device=torch.cuda.current_device()) if grad_norm is None else grad_norm + ) + + for key in metrics_tracker: + metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) + + metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) + + return metrics_tracker def train( args: TrainingArgs, model_container: ModelContainer, - pipeline_schedule: _PipelineSchedule, optimizer_container: OptimizerContainer, lr_scheduler_container: LRSchedulerContainer, train_dataloader: ResumableDataLoader, @@ -43,7 +155,6 @@ def train( """ num_training_steps = args.training_parameters.num_training_steps - gradient_accumulation_steps = args.training_parameters.gradient_accumulation_steps gradient_clipping = args.training_parameters.gradient_clipping eval_during_training = args.training_parameters.eval_during_training @@ -73,20 +184,15 @@ def train( while global_step < num_training_steps: global_step += 1 - loss_step_dict = train_step( - model_container=model_container, - pipeline_schedule=pipeline_schedule, - optimizer_container=optimizer_container, - lr_scheduler_container=lr_scheduler_container, + loss_step_dict = train_step_without_pipeline_parallel( + model=model_container[0], + optimizer=optimizer_container[0], + lr_scheduler=lr_scheduler_container[0], train_dataloader=train_dataloader_infinite, - gradient_accumulation_steps=gradient_accumulation_steps, gradient_clipping=gradient_clipping, forward_context=forward_context, backward_context=backward_context, sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, - is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1, - local_batch_size=None, - sequence_length=None, ) metrics_tracker = metrics_tracker + loss_step_dict @@ -124,7 +230,7 @@ def train( ensure_last_checkpoint_is_saved() if torch_profiler is not None: - torch_profiler.__exit__() + torch_profiler.__exit__(None, None, None) @torch.no_grad() @@ -165,13 +271,15 @@ def evaluate( metrics_tracker = 
MetricsTrackingDict({}) val_dataloader = custom_iterator(val_dataloader, infinite=False) + loss_tokens = 0 for _ in range(num_steps): batch = get_next_batch(val_dataloader) + loss_tokens += (batch["labels"] != -100).sum() loss_step_dict = model_container[0](batch) metrics_tracker = metrics_tracker + loss_step_dict - metrics_tracker = metrics_tracker / num_steps + metrics_tracker = metrics_tracker / loss_tokens.item() for key in metrics_tracker: metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) @@ -193,8 +301,6 @@ def evaluate( def main() -> None: """main program""" - assert False - mode = Mode.training setup_tf32() @@ -217,6 +323,11 @@ def main() -> None: use_async_tensor_parallel=args.distributed_args.use_async_tensor_parallel, ) + StepTracker( + micro_batch_size=args.training_parameters.micro_batch_size, + gradient_accumulation_steps=args.training_parameters.gradient_accumulation_steps, + ) + set_seed(args.random_args.seed) assert args.distributed_args.num_pipeline_stages == 1, "pipeline parallel is not supported with finetuning" @@ -241,7 +352,7 @@ def main() -> None: is_encoder_decoder=model_container[0].is_encoder_decoder, ) - model_container, pipeline_schedule = wrap_model_container_for_distributed_training(args, model_container) + model_container, _ = wrap_model_container_for_distributed_training(args, model_container) optimizer_container = get_optimizer_container( optimizer_class_name=args.optimizer_args.class_name, @@ -261,6 +372,9 @@ def main() -> None: extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args, ) + assert len(model_container) == len(optimizer_container) + assert len(optimizer_container) == len(lr_scheduler_container) + log_model_optimizer_container(model_container, optimizer_container) starting_iteration = 0 @@ -283,7 +397,6 @@ def main() -> None: train( args, model_container=model_container, - pipeline_schedule=pipeline_schedule, optimizer_container=optimizer_container, lr_scheduler_container=lr_scheduler_container, train_dataloader=train_dataloader, diff --git a/dolomite_engine/model_wrapper/finetuning.py b/dolomite_engine/model_wrapper/finetuning.py index 6d043508..e227a7fb 100644 --- a/dolomite_engine/model_wrapper/finetuning.py +++ b/dolomite_engine/model_wrapper/finetuning.py @@ -1,14 +1,16 @@ import torch import torch.distributed +from torch.distributed._tensor.placement_types import Replicate from ..communication import Communication +from ..distributed import tensor_to_dtensor from ..hf_models import get_autoregressive_language_modeling_loss from ..utils import MetricsTrackingDict, ProcessGroupManager from .base import ModelWrapper class ModelWrapperForFinetuning(ModelWrapper): - def forward(self, batch: dict) -> MetricsTrackingDict: + def forward(self, batch: dict, lm_loss_multiplier: float = 1) -> MetricsTrackingDict: """forward function for a batch Args: @@ -25,17 +27,42 @@ def forward(self, batch: dict) -> MetricsTrackingDict: model_outputs = self.model(**batch) - loss = get_autoregressive_language_modeling_loss( - lm_logits=model_outputs.logits, + return self.get_loss( + model_outputs=model_outputs, labels=labels, - upcast_logits_for_loss=self.upcast_logits_for_loss, cu_seqlens=batch.get("cu_seqlens", None), + lm_loss_multiplier=lm_loss_multiplier, + ) + + def get_loss( + self, model_outputs, labels: torch.Tensor, cu_seqlens: torch.Tensor | None, lm_loss_multiplier: float = 1 + ) -> torch.Tensor | dict: + logits: torch.Tensor = model_outputs.logits + aux_loss = model_outputs.aux_loss if hasattr(model_outputs, "aux_loss") 
else None + + lm_loss = get_autoregressive_language_modeling_loss( + lm_logits=logits, + labels=labels, + upcast_logits_for_loss=self.upcast_logits_for_loss, + cu_seqlens=cu_seqlens, use_padding_free_transformer=self.use_padding_free_transformer, reduction="sum", tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, ) - return MetricsTrackingDict({"loss": loss}) + lm_loss = lm_loss * lm_loss_multiplier + + if aux_loss is None: + loss = lm_loss + output = {"loss": loss} + else: + if ProcessGroupManager.is_tensor_parallel_enabled(): + aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate()) + + loss = lm_loss + self.router_aux_loss_coef * aux_loss + output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss} + + return output def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict: device = torch.cuda.current_device() diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py index 9bf84e75..e2a4ab4b 100644 --- a/dolomite_engine/model_wrapper/pretraining.py +++ b/dolomite_engine/model_wrapper/pretraining.py @@ -120,7 +120,7 @@ def forward(self, batch: dict, prev_aux_loss: torch.Tensor | None = None, lm_los return output - def get_loss(self, model_outputs, labels: torch.Tensor, lm_loss_multiplier: float = 1) -> torch.Tensor: + def get_loss(self, model_outputs, labels: torch.Tensor, lm_loss_multiplier: float = 1) -> torch.Tensor | dict: if isinstance(model_outputs, torch.Tensor): logits = model_outputs aux_loss = None diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py index 24301ed2..d465bcf6 100644 --- a/dolomite_engine/pretrain.py +++ b/dolomite_engine/pretrain.py @@ -1,10 +1,12 @@ import logging import time -from contextlib import nullcontext +from contextlib import AbstractContextManager, nullcontext import torch from torch.distributed.pipelining.schedules import _PipelineSchedule from torch.distributed.tensor.parallel import loss_parallel +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader from transformers import set_seed @@ -12,23 +14,223 @@ from .checkpointing import ensure_last_checkpoint_is_saved, load_checkpoint_for_training, save_checkpoint from .communication import Communication from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container -from .data import get_next_batch, get_pretraining_dataloaders +from .data import ResumableDataLoader, get_next_batch, get_pretraining_dataloaders from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training from .enums import Mode, TuningMethod -from .model_wrapper import get_model_container +from .model_wrapper import ModelWrapper, get_model_container from .optimization import get_optimizer_container, get_scheduler_container -from .train_utils import all_reduce_metrics_tracker, get_model_tflops, get_torch_profiler, track_metrics, train_step +from .train_utils import all_reduce_metrics_tracker, get_model_tflops, get_torch_profiler, track_metrics from .utils import ( ExperimentsTracker, MetricsTrackingDict, ProcessGroupManager, StepTracker, init_distributed, + is_torchao_available, log_rank_0, setup_tf32, ) +if is_torchao_available(): + from .distributed import FP8Manager + + +def train_step_with_pipeline_parallel( + model_container: ModelContainer, + pipeline_schedule: _PipelineSchedule, + optimizer_container: OptimizerContainer, + lr_scheduler_container: 
LRSchedulerContainer, + train_dataloader: ResumableDataLoader, + gradient_clipping: float, + sequence_length: int, +) -> MetricsTrackingDict: + """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary + + Args: + model_container (ModelContainer): container of models + pipeline_schedule (_PipelineSchedule): pipeline schedule + optimizer_container (OptimizerContainer): container of optimizers + lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers + train_dataloader (ResumableDataLoader): training dataloader + gradient_clipping (float): gradient clipping value + sequence_length (int): sequence length + + Returns: + MetricsTrackingDict: metrics to track + """ + + fsdp_algorithm = 2 if hasattr(model_container[0], "set_requires_gradient_sync") else 1 + grad_norm = [] + + optimizer_container.zero_grad() + + batch = get_next_batch(train_dataloader) + + if ProcessGroupManager.is_tensor_parallel_first_rank(): + batch = batch["text"] + + batch = model_container[0].broadcast_tensor_parallel_input( + batch, (StepTracker.get_local_batch_size(), sequence_length + 1) + ) + + is_first_pipeline_rank = ProcessGroupManager.get_pipeline_parallel_rank() == 0 + is_last_pipeline_rank = ( + ProcessGroupManager.get_pipeline_parallel_rank() == ProcessGroupManager.get_pipeline_parallel_world_size() - 1 + ) + + if is_first_pipeline_rank: + pipeline_schedule.step(batch) + elif is_last_pipeline_rank: + losses = [] + labels = batch[:, 1:] + pipeline_schedule.step(target=labels, losses=losses) + else: + pipeline_schedule.step() + + if gradient_clipping is not None: + for model in model_container: + if fsdp_algorithm == 1: + grad_norm.append(model.clip_grad_norm_(gradient_clipping)) + else: + grad_norm.append(torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)) + + if is_torchao_available(): + FP8Manager.sync_float8_amax_and_scale_history(model_container) + + optimizer_container.step() + lr_scheduler_container.step() + + if is_torchao_available(): + FP8Manager.precompute_float8_dynamic_scale_for_fsdp(model_container) + + metrics_tracker = MetricsTrackingDict({}) + + with torch.inference_mode(): + grad_norm = dtensor_to_tensor(sum(grad_norm)) + torch.distributed.all_reduce(grad_norm, group=ProcessGroupManager.get_pipeline_parallel_group()) + + if is_last_pipeline_rank: + losses = sum(losses) + + metrics_tracker = metrics_tracker + {"loss": losses, "grad_norm": grad_norm} + metrics_tracker = metrics_tracker + model.get_extra_metrics() + model.reset_extra_metrics() + + metrics_tracker = metrics_tracker / StepTracker.get_gradient_accumulation_steps() + + metrics_tracker["grad_norm"] = grad_norm + + for key in metrics_tracker: + metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) + + metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) + + return metrics_tracker + + +def train_step_without_pipeline_parallel( + model: ModelWrapper, + optimizer: Optimizer, + lr_scheduler: LambdaLR, + train_dataloader: ResumableDataLoader, + gradient_clipping: float, + forward_context: AbstractContextManager, + backward_context: AbstractContextManager, + sync_every_gradient_accumulation_step: bool, + lm_loss_multiplier: float, +) -> MetricsTrackingDict: + """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary + + Args: + model (ModelWrapper): model + optimizer (Optimizer): optimizer + lr_scheduler (LamdaLR): learning rate scheduler + train_dataloader (ResumableDataLoader): training dataloader + 
gradient_clipping (float): gradient clipping value + forward_context (AbstractContextManager): a context that is used for every model forward call + backward_context (AbstractContextManager): a context that is used for every model backward call + sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step + lm_loss_multiplier (int): lm loss multiplier + + Returns: + MetricsTrackingDict: metrics to track + """ + + fsdp_algorithm = 2 if hasattr(model, "set_requires_gradient_sync") else 1 + + no_sync = nullcontext + if not sync_every_gradient_accumulation_step: + if fsdp_algorithm == 1: + no_sync = model.no_sync + else: + model.set_requires_gradient_sync(False) + + metrics_tracker = MetricsTrackingDict({}) + grad_norm = None + optimizer.zero_grad() + + gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps() + + with no_sync(): + for _ in range(gradient_accumulation_steps - 1): + batch = get_next_batch(train_dataloader) + with forward_context(): + loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) + + # compute gradients + with backward_context(): + loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps + loss_micro_step_scaled.backward() + + with torch.inference_mode(): + metrics_tracker = metrics_tracker + loss_micro_step_dict + + if fsdp_algorithm == 2: + model.set_requires_gradient_sync(True) + + batch = get_next_batch(train_dataloader) + with forward_context(): + loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) + + # compute gradients + with backward_context(): + loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps + loss_micro_step_scaled.backward() + + with torch.inference_mode(): + metrics_tracker = metrics_tracker + loss_micro_step_dict + + if gradient_clipping is not None: + if fsdp_algorithm == 1: + grad_norm = model.clip_grad_norm_(gradient_clipping) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping) + + if is_torchao_available(): + FP8Manager.sync_float8_amax_and_scale_history([model]) + + optimizer.step() + lr_scheduler.step() + + if is_torchao_available(): + FP8Manager.precompute_float8_dynamic_scale_for_fsdp([model]) + + with torch.inference_mode(): + metrics_tracker = metrics_tracker / gradient_accumulation_steps + + metrics_tracker["grad_norm"] = ( + torch.tensor(0, device=torch.cuda.current_device()) if grad_norm is None else grad_norm + ) + + for key in metrics_tracker: + metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) + + metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) + + return metrics_tracker + + def track_val_metrics( global_step: int, experiments_tracker: ExperimentsTracker, @@ -93,7 +295,6 @@ def train( """ num_training_steps = args.training_parameters.num_training_steps - gradient_accumulation_steps = args.training_parameters.gradient_accumulation_steps gradient_clipping = args.training_parameters.gradient_clipping eval_during_training = args.training_parameters.eval_during_training @@ -112,14 +313,15 @@ def train( eval_steps = args.datasets[0].class_args.get("eval_steps") evaluate(val_dataloaders, model_container, starting_iteration, experiments_tracker, eval_steps, group_names) - dp_world_size = ProcessGroupManager.get_data_parallel_world_size() - micro_batch_size = args.training_parameters.micro_batch_size sequence_length = args.datasets[0].class_args.get("sequence_length") - local_batch_size = 
micro_batch_size * gradient_accumulation_steps - global_batch_size = local_batch_size * dp_world_size + global_batch_size = StepTracker.get_global_batch_size() tokens_per_batch = global_batch_size * sequence_length + is_pipeline_parallel_enabled = args.distributed_args.num_pipeline_stages > 1 + if not is_pipeline_parallel_enabled: + assert len(model_container) == 1 + # model flops per GPU model_flops = ( get_model_tflops( @@ -149,22 +351,28 @@ def train( global_step += 1 steps_since_start_time += 1 - loss_step_dict = train_step( - model_container=model_container, - pipeline_schedule=pipeline_schedule, - optimizer_container=optimizer_container, - lr_scheduler_container=lr_scheduler_container, - train_dataloader=train_dataloader, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_clipping=gradient_clipping, - forward_context=forward_context, - backward_context=backward_context, - sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, - is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1, - local_batch_size=local_batch_size, - micro_batch_size=micro_batch_size, - sequence_length=sequence_length, - ) + if is_pipeline_parallel_enabled: + loss_step_dict = train_step_with_pipeline_parallel( + model_container=model_container, + pipeline_schedule=pipeline_schedule, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, + train_dataloader=train_dataloader, + gradient_clipping=gradient_clipping, + sequence_length=sequence_length, + ) + else: + loss_step_dict = train_step_without_pipeline_parallel( + model=model_container[0], + optimizer=optimizer_container[0], + lr_scheduler=lr_scheduler_container[0], + train_dataloader=train_dataloader, + gradient_clipping=gradient_clipping, + forward_context=forward_context, + backward_context=backward_context, + sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, + lm_loss_multiplier=1 / (micro_batch_size * sequence_length), + ) metrics_tracker = metrics_tracker + loss_step_dict @@ -208,9 +416,7 @@ def train( train_dataloader=None, experiments_tracker=experiments_tracker, iteration=global_step, - metadata={ - "consumed_samples": global_step * micro_batch_size * gradient_accumulation_steps * dp_world_size - }, + metadata={"consumed_samples": global_step * global_batch_size}, ) start_time = time.perf_counter() @@ -222,7 +428,7 @@ def train( ensure_last_checkpoint_is_saved() if torch_profiler is not None: - torch_profiler.__exit__() + torch_profiler.__exit__(None, None, None) @torch.no_grad() @@ -359,6 +565,9 @@ def main(mode: Mode = Mode.training) -> None: extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args, ) + assert len(model_container) == len(optimizer_container) + assert len(optimizer_container) == len(lr_scheduler_container) + log_model_optimizer_container(model_container, optimizer_container) starting_iteration = 0 diff --git a/dolomite_engine/train_utils.py b/dolomite_engine/train_utils.py index b3e96ab8..7e1923e2 100644 --- a/dolomite_engine/train_utils.py +++ b/dolomite_engine/train_utils.py @@ -1,298 +1,13 @@ import logging -from contextlib import AbstractContextManager, nullcontext import torch from torch.distributed import ReduceOp -from torch.distributed.pipelining.schedules import _PipelineSchedule -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR from transformers import AutoConfig -from .containers import LRSchedulerContainer, 
ModelContainer, OptimizerContainer -from .data import ResumableDataLoader, get_next_batch -from .distributed import dtensor_to_tensor from .enums import GradientCheckpointingMethod from .hf_models import is_custom_model from .hf_models.modeling_utils import is_glu -from .model_wrapper import ModelWrapper -from .utils import ExperimentsTracker, MetricsTrackingDict, ProcessGroupManager, is_torchao_available, log_metrics - - -if is_torchao_available(): - from .distributed import FP8Manager - - -def train_step( - model_container: ModelContainer, - pipeline_schedule: _PipelineSchedule, - optimizer_container: OptimizerContainer, - lr_scheduler_container: LRSchedulerContainer, - train_dataloader: ResumableDataLoader, - gradient_accumulation_steps: int, - gradient_clipping: float, - forward_context: AbstractContextManager, - backward_context: AbstractContextManager, - sync_every_gradient_accumulation_step: bool, - is_pipeline_parallel_enabled: bool, - local_batch_size: int, - micro_batch_size: int, - sequence_length: int, -) -> MetricsTrackingDict: - """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary - - Args: - model_container (ModelContainer): container of models - pipeline_schedule (_PipelineSchedule): pipeline schedule - optimizer_container (OptimizerContainer): container of optimizers - lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers - train_dataloader (ResumableDataLoader): training dataloader - gradient_accumulation_steps (int): gradient accumulation steps - gradient_clipping (float): gradient clipping value - forward_context (AbstractContextManager): a context that is used for every model forward call - backward_context (AbstractContextManager): a context that is used for every model backward call - sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step - is_pipeline_parallel_enabled (bool): whether to use pipeline parallel - local_batch_size (int): local batch size - sequence_length (int): sequence length - - Returns: - MetricsTrackingDict: metrics to track - """ - - assert len(model_container) == len(optimizer_container) - assert len(optimizer_container) == len(lr_scheduler_container) - - if is_pipeline_parallel_enabled: - metrics_tracker = _train_step_with_pipeline_parallel( - model_container=model_container, - pipeline_schedule=pipeline_schedule, - optimizer_container=optimizer_container, - lr_scheduler_container=lr_scheduler_container, - train_dataloader=train_dataloader, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_clipping=gradient_clipping, - local_batch_size=local_batch_size, - sequence_length=sequence_length, - ) - else: - assert len(model_container) == 1 - - metrics_tracker = _train_step_without_pipeline_parallel( - model=model_container[0], - optimizer=optimizer_container[0], - lr_scheduler=lr_scheduler_container[0], - train_dataloader=train_dataloader, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_clipping=gradient_clipping, - forward_context=forward_context, - backward_context=backward_context, - sync_every_gradient_accumulation_step=sync_every_gradient_accumulation_step, - micro_batch_size=micro_batch_size, - sequence_length=sequence_length, - ) - - return metrics_tracker - - -def _train_step_with_pipeline_parallel( - model_container: ModelContainer, - pipeline_schedule: _PipelineSchedule, - optimizer_container: OptimizerContainer, - lr_scheduler_container: LRSchedulerContainer, - train_dataloader: 
ResumableDataLoader, - gradient_accumulation_steps: int, - gradient_clipping: float, - local_batch_size: int, - sequence_length: int, -) -> MetricsTrackingDict: - """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary - - Args: - model_container (ModelContainer): container of models - pipeline_schedule (_PipelineSchedule): pipeline schedule - optimizer_container (OptimizerContainer): container of optimizers - lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers - train_dataloader (ResumableDataLoader): training dataloader - gradient_accumulation_steps (int): gradient accumulation steps - gradient_clipping (float): gradient clipping value - local_batch_size (int): local batch size - sequence_length (int): sequence length - - Returns: - MetricsTrackingDict: metrics to track - """ - - fsdp_algorithm = 2 if hasattr(model_container[0], "set_requires_gradient_sync") else 1 - grad_norm = [] - - optimizer_container.zero_grad() - - batch = get_next_batch(train_dataloader) - - if ProcessGroupManager.is_tensor_parallel_first_rank(): - batch = batch["text"] - - batch = model_container[0].broadcast_tensor_parallel_input(batch, (local_batch_size, sequence_length + 1)) - - is_first_pipeline_rank = ProcessGroupManager.get_pipeline_parallel_rank() == 0 - is_last_pipeline_rank = ( - ProcessGroupManager.get_pipeline_parallel_rank() == ProcessGroupManager.get_pipeline_parallel_world_size() - 1 - ) - - if is_first_pipeline_rank: - pipeline_schedule.step(batch) - elif is_last_pipeline_rank: - losses = [] - labels = batch[:, 1:] - pipeline_schedule.step(target=labels, losses=losses) - else: - pipeline_schedule.step() - - if gradient_clipping is not None: - for model in model_container: - if fsdp_algorithm == 1: - grad_norm.append(model.clip_grad_norm_(gradient_clipping)) - else: - grad_norm.append(torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)) - - if is_torchao_available(): - FP8Manager.sync_float8_amax_and_scale_history(model_container) - - optimizer_container.step() - lr_scheduler_container.step() - - if is_torchao_available(): - FP8Manager.precompute_float8_dynamic_scale_for_fsdp(model_container) - - metrics_tracker = MetricsTrackingDict({}) - - with torch.inference_mode(): - grad_norm = dtensor_to_tensor(sum(grad_norm)) - torch.distributed.all_reduce(grad_norm, group=ProcessGroupManager.get_pipeline_parallel_group()) - - if is_last_pipeline_rank: - losses = sum(losses) - - metrics_tracker = metrics_tracker + {"loss": losses, "grad_norm": grad_norm} - metrics_tracker = metrics_tracker + model.get_extra_metrics() - model.reset_extra_metrics() - - metrics_tracker = metrics_tracker / gradient_accumulation_steps - - metrics_tracker["grad_norm"] = grad_norm - - for key in metrics_tracker: - metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) - - metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) - - return metrics_tracker - - -def _train_step_without_pipeline_parallel( - model: ModelWrapper, - optimizer: Optimizer, - lr_scheduler: LambdaLR, - train_dataloader: ResumableDataLoader, - gradient_accumulation_steps: int, - gradient_clipping: float, - forward_context: AbstractContextManager, - backward_context: AbstractContextManager, - sync_every_gradient_accumulation_step: bool, - micro_batch_size: int, - sequence_length: int, -) -> MetricsTrackingDict: - """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary - - Args: - model (ModelWrapper): model - 
optimizer (Optimizer): optimizer - lr_scheduler (LamdaLR): learning rate scheduler - train_dataloader (ResumableDataLoader): training dataloader - gradient_accumulation_steps (int): gradient accumulation steps - gradient_clipping (float): gradient clipping value - forward_context (AbstractContextManager): a context that is used for every model forward call - backward_context (AbstractContextManager): a context that is used for every model backward call - sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step - micro_batch_size (int): micro batch size - sequence_length (int): sequence length - - Returns: - MetricsTrackingDict: metrics to track - """ - - fsdp_algorithm = 2 if hasattr(model, "set_requires_gradient_sync") else 1 - - no_sync = nullcontext - if not sync_every_gradient_accumulation_step: - if fsdp_algorithm == 1: - no_sync = model.no_sync - else: - model.set_requires_gradient_sync(False) - - metrics_tracker = MetricsTrackingDict({}) - grad_norm = None - optimizer.zero_grad() - - lm_loss_multiplier = 1 / (micro_batch_size * sequence_length) - - with no_sync(): - for _ in range(gradient_accumulation_steps - 1): - batch = get_next_batch(train_dataloader) - with forward_context(): - loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) - - # compute gradients - with backward_context(): - loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps - loss_micro_step_scaled.backward() - - with torch.inference_mode(): - metrics_tracker = metrics_tracker + loss_micro_step_dict - - if fsdp_algorithm == 2: - model.set_requires_gradient_sync(True) - - batch = get_next_batch(train_dataloader) - with forward_context(): - loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) - - # compute gradients - with backward_context(): - loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps - loss_micro_step_scaled.backward() - - with torch.inference_mode(): - metrics_tracker = metrics_tracker + loss_micro_step_dict - - if gradient_clipping is not None: - if fsdp_algorithm == 1: - grad_norm = model.clip_grad_norm_(gradient_clipping) - else: - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping) - - if is_torchao_available(): - FP8Manager.sync_float8_amax_and_scale_history([model]) - - optimizer.step() - lr_scheduler.step() - - if is_torchao_available(): - FP8Manager.precompute_float8_dynamic_scale_for_fsdp([model]) - - with torch.inference_mode(): - metrics_tracker = metrics_tracker / gradient_accumulation_steps - - metrics_tracker["grad_norm"] = ( - torch.tensor(0, device=torch.cuda.current_device()) if grad_norm is None else grad_norm - ) - - for key in metrics_tracker: - metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key]) - - metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) - - return metrics_tracker +from .utils import ExperimentsTracker, MetricsTrackingDict, ProcessGroupManager, log_metrics def all_reduce_metrics_tracker(metrics_tracker: MetricsTrackingDict) -> MetricsTrackingDict: diff --git a/dolomite_engine/utils/step_tracker.py b/dolomite_engine/utils/step_tracker.py index fd82e518..3450040f 100644 --- a/dolomite_engine/utils/step_tracker.py +++ b/dolomite_engine/utils/step_tracker.py @@ -14,9 +14,12 @@ def __init__(self, micro_batch_size: int, gradient_accumulation_steps: int) -> N @staticmethod def get_local_batch_size() -> int: - global _MICRO_BATCH_SIZE, 
_GRADIENT_ACCUMULATION_STEPS return _MICRO_BATCH_SIZE * _GRADIENT_ACCUMULATION_STEPS @staticmethod def get_global_batch_size() -> int: return StepTracker.get_local_batch_size() * ProcessGroupManager.get_data_parallel_world_size() + + @staticmethod + def get_gradient_accumulation_steps() -> int: + return _GRADIENT_ACCUMULATION_STEPS
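
For reviewers, a minimal sketch of the loss normalization the finetuning train step now relies on: each micro-step loss is produced with reduction="sum" and scaled by lm_loss_multiplier = 1 / (count of labels != -100 across the whole gradient-accumulation window), so summing the micro-step losses already yields a token-weighted mean and the old division by gradient_accumulation_steps is no longer needed. The helper name token_weighted_loss and the toy tensors below are illustrative assumptions, not engine APIs.

# Illustrative sketch only (not part of the patch): how summed per-token losses and
# lm_loss_multiplier = 1 / total_supervised_tokens combine into one token-weighted mean.
import torch

IGNORE_INDEX = -100  # assumed to match the -100 label mask used in the engine

def token_weighted_loss(summed_losses: list[torch.Tensor], batches: list[dict]) -> torch.Tensor:
    # each entry of summed_losses is assumed to be the un-normalized sum of per-token
    # cross-entropy for one micro batch (reduction="sum" in ModelWrapperForFinetuning)
    total_tokens = sum((batch["labels"] != IGNORE_INDEX).sum() for batch in batches)
    # equivalent to multiplying every micro-step loss by 1 / total_tokens
    # (the lm_loss_multiplier in train_step_without_pipeline_parallel) and summing
    return torch.stack(summed_losses).sum() / total_tokens

# toy usage: two micro batches with 2 and 3 supervised tokens respectively
batches = [
    {"labels": torch.tensor([[5, 7, IGNORE_INDEX]])},
    {"labels": torch.tensor([[IGNORE_INDEX, 3, 9, 4]])},
]
summed_losses = [torch.tensor(4.0), torch.tensor(9.0)]
print(token_weighted_loss(summed_losses, batches))  # (4 + 9) / 5 = 2.6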
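
A companion sketch of the batch-size bookkeeping that StepTracker centralizes: it is constructed once in main() with micro_batch_size and gradient_accumulation_steps, and the static getters derive everything else. The concrete numbers and the hard-coded data-parallel world size below are assumptions for illustration.

# Illustrative numbers only; in the engine the world size comes from
# ProcessGroupManager.get_data_parallel_world_size().
micro_batch_size = 4
gradient_accumulation_steps = 8
data_parallel_world_size = 16

local_batch_size = micro_batch_size * gradient_accumulation_steps   # StepTracker.get_local_batch_size()  -> 32
global_batch_size = local_batch_size * data_parallel_world_size     # StepTracker.get_global_batch_size() -> 512

# pretrain.py now records checkpoint metadata as global_step * global_batch_size
global_step = 1000
consumed_samples = global_step * global_batch_size                  # 512000
print(local_batch_size, global_batch_size, consumed_samples)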