diff --git a/configs/ultra-long-context-length/3b-131k-training.yml b/configs/ultra-long-context-length/3b-131k-training.yml
index 57656352..b277447f 100644
--- a/configs/ultra-long-context-length/3b-131k-training.yml
+++ b/configs/ultra-long-context-length/3b-131k-training.yml
@@ -47,7 +47,7 @@ mixed_precision_args:
   dtype: bf16
 
 distributed_args:
-  tensor_parallel_size: 8
+  tensor_parallel_world_size: 8
   fsdp_algorithm: 2
   sequence_parallel: true
   tensor_parallel_word_embeddings: true
diff --git a/configs/ultra-long-context-length/3b-65k-training.yml b/configs/ultra-long-context-length/3b-65k-training.yml
index 76ffab51..b63beb0f 100644
--- a/configs/ultra-long-context-length/3b-65k-training.yml
+++ b/configs/ultra-long-context-length/3b-65k-training.yml
@@ -47,7 +47,7 @@ mixed_precision_args:
   dtype: bf16
 
 distributed_args:
-  tensor_parallel_size: 8
+  tensor_parallel_world_size: 8
   fsdp_algorithm: 2
   sequence_parallel: true
   tensor_parallel_word_embeddings: true
diff --git a/configs/ultra-long-context-length/8b-65k-training.yml b/configs/ultra-long-context-length/8b-65k-training.yml
index 84b7b494..1e87dcef 100644
--- a/configs/ultra-long-context-length/8b-65k-training.yml
+++ b/configs/ultra-long-context-length/8b-65k-training.yml
@@ -47,7 +47,7 @@ mixed_precision_args:
   dtype: bf16
 
 distributed_args:
-  tensor_parallel_size: 8
+  tensor_parallel_world_size: 8
   fsdp_algorithm: 2
   sequence_parallel: true
   tensor_parallel_word_embeddings: true
diff --git a/dolomite_engine/arguments.py b/dolomite_engine/arguments.py
index 0a6180f4..1cac85b9 100644
--- a/dolomite_engine/arguments.py
+++ b/dolomite_engine/arguments.py
@@ -324,11 +324,13 @@ class DistributedArgs(BaseArgs):
     # whether to use a dispatching dataloader
     dispatching_dataloader: bool = False
     # tensor parallel world size
-    tensor_parallel_size: int = 1
+    tensor_parallel_world_size: int = 1
     # tensor parallel embeddings
     tensor_parallel_word_embeddings: bool = False
     # whether to use sequence parallel
     sequence_parallel: bool = False
+    # pipeline parallel world size
+    pipeline_parallel_world_size: int = 1
     # data parallel world size
     data_parallel_size: int | None = None
     # distributed timeout for NCCL in minutes
@@ -337,6 +339,10 @@ class DistributedArgs(BaseArgs):
     fsdp_algorithm: int = 1
     # whether to sync every gradient accumulation step
     sync_every_gradient_accumulation_step: bool = False
+    # total number of pipeline stages
+    num_pipeline_stages: int = 1
+    # pipeline parallel schedule to use
+    pipeline_parallel_schedule: str | None = None
     # whether to use async-TP
     use_async_tensor_parallel: bool = False
 
@@ -346,14 +352,14 @@ def model_post_init(self, __context: Any) -> None:
             self.communication_dtype = normalize_dtype_string(self.communication_dtype)
 
         if self.sequence_parallel:
-            assert self.tensor_parallel_size > 1, "tensor parallel needs to be enabled for sequence parallel"
+            assert self.tensor_parallel_world_size > 1, "tensor parallel needs to be enabled for sequence parallel"
 
         if self.tensor_parallel_word_embeddings:
             assert (
-                self.tensor_parallel_size > 1
+                self.tensor_parallel_world_size > 1
             ), "tensor parallel needs to be enabled when using tensor parallel work embeddings"
 
-        if self.tensor_parallel_size > 1:
+        if self.tensor_parallel_world_size > 1:
             version = Version(torch.__version__).release
             version = [str(i) for i in version]
             version = ".".join(version)
@@ -369,6 +375,13 @@ def model_post_init(self, __context: Any) -> None:
         if self.use_async_tensor_parallel:
            assert self.sequence_parallel, "sequence parallel should be enabled for using async-TP"
 
+
assert ( + self.num_pipeline_stages % self.pipeline_parallel_world_size == 0 + ), "num_pipeline_stages should be a multiple of pipeline_parallel_world_size" + + if self.num_pipeline_stages > 1: + _check_not_None([(self.pipeline_parallel_schedule, "pipeline_parallel_schedule")]) + class AimArgs(BaseArgs): # aim repo, experiment logs are saved here @@ -491,6 +504,9 @@ def model_post_init(self, __context: Any) -> None: # datasets _check_datasets(self.datasets) + if self.distributed_args.num_pipeline_stages > 1 and self.training_parameters.eval_during_training: + raise NotImplementedError("evaluation is not supported with pipeline parallel") + class GenerationParameters(BaseArgs): # batch size diff --git a/dolomite_engine/checkpointing.py b/dolomite_engine/checkpointing.py index 59ca7948..ab220874 100644 --- a/dolomite_engine/checkpointing.py +++ b/dolomite_engine/checkpointing.py @@ -18,15 +18,17 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.state_dict_loader import _load_state_dict +from torch.distributed.checkpoint.stateful import Stateful from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR from .arguments import InferenceArgs, TrainingArgs, UnshardingArgs +from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer from .data import ResumableDataLoader from .enums import Mode from .hf_models import fix_unsharded_state_dict -from .model_wrapper import ModelWrapper, get_model -from .optimization import get_scheduler +from .model_wrapper import ModelWrapper, get_model_container +from .optimization import get_scheduler_container from .utils import ExperimentsTracker, ProcessGroupManager, load_yaml, log_rank_0, run_rank_n, string_to_torch_dtype @@ -35,11 +37,88 @@ _KILLSWITCH = "KILLSWITCH" +class _ModelSaver(Stateful): + def __init__(self, model_container: ModelContainer) -> None: + self.model_container = model_container + + def state_dict(self) -> dict: + state_dict = {} + + for model in self.model_container: + model_state_dict = get_model_state_dict(model) + if model.has_teacher_model(): + model_state_dict = self._filter_out_teacher_state_dict(model_state_dict) + + state_dict.update(model_state_dict) + + return state_dict + + def load_state_dict(self, state_dict: dict) -> None: + for model in self.model_container: + model_state_dict = get_model_state_dict(model) + set_model_state_dict( + model, model_state_dict=state_dict, options=StateDictOptions(strict=not model.has_teacher_model()) + ) + + for key in model_state_dict: + del state_dict[key] + + assert len(state_dict) == 0, "unused keys found in the state dict" + + def _filter_out_teacher_state_dict(self, state_dict: dict) -> dict: + result = {} + for key, value in state_dict.items(): + if not "teacher_model" in key: + result[key] = value + + return result + + +class _OptimizerSaver(Stateful): + def __init__(self, model_container: ModelContainer, optimizer_container: OptimizerContainer) -> None: + self.model_container = model_container + self.optimizer_container = optimizer_container + + def state_dict(self) -> dict: + state_dict = {} + + for model, optimizer in zip(self.model_container, self.optimizer_container): + optimizer_state_dict = get_optimizer_state_dict( + model, optimizer, options=StateDictOptions(flatten_optimizer_state_dict=True) + ) + state_dict.update(optimizer_state_dict) + + return state_dict + + def load_state_dict(self, state_dict: dict) -> None: + for model, optimizer in zip(self.model_container, self.optimizer_container): + set_optimizer_state_dict( + 
model, + optimizer, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + + +class _LRSchedulerSaver(Stateful): + def __init__(self, lr_scheduler_container: LRSchedulerContainer) -> None: + self.lr_scheduler_container = lr_scheduler_container + + def state_dict(self) -> dict: + return [lr_scheduler.state_dict() for lr_scheduler in self.lr_scheduler_container] + + def load_state_dict(self, state_dict: list[dict]) -> None: + assert len(self.lr_scheduler_container) == len(state_dict) + + for lr_scheduler, lr_scheduler_state_dict in zip(self.lr_scheduler_container, state_dict): + lr_scheduler.load_state_dict(lr_scheduler_state_dict) + + def save_checkpoint( args: TrainingArgs, - model: ModelWrapper, - optimizer: Optimizer | None, - lr_scheduler: LambdaLR | None, + model_container: ModelContainer, + optimizer_container: OptimizerContainer | None, + lr_scheduler_container: LRSchedulerContainer | None, train_dataloader: ResumableDataLoader, experiments_tracker: ExperimentsTracker, iteration: int, @@ -49,9 +128,9 @@ def save_checkpoint( Args: args (TrainingArgs): arguments for training - model (ModelWrapper): model to save - optimizer (Optimizer): optimizer to save - lr_scheduler (LambdaLR): learning rate scheduler to save + model_container (ModelContainer): models to save + optimizer_container (OptimizerContainer): optimizers to save + lr_scheduler_container (LRSchedulerContainer): learning rate schedulers to save train_dataloader (DataLoader): train dataloader to save experiments_tracker (ExperimentsTracker): experiment tracker to save iteration (int): current iteration @@ -61,35 +140,33 @@ def save_checkpoint( ValueError: if unexpected distributed backend is found """ - save_optimizer = args.save_args.save_optimizer - save_path = _get_base_path(args.save_args.save_path, iteration) os.makedirs(save_path, exist_ok=True) - model_state_dict = get_model_state_dict(model) - if model.has_teacher_model(): - model_state_dict = _filter_out_teacher_state_dict(model_state_dict) + dcp.save({"state": _ModelSaver(model_container)}, checkpoint_id=_get_model_path(save_path)) - dcp.save(model_state_dict, checkpoint_id=_get_model_path(save_path)) - - if save_optimizer: - if optimizer is None: + if args.save_args.save_optimizer: + if optimizer_container is None: log_rank_0( logging.WARN, - "optimizer is not passed to save_checkpoint but save_optimizer is set to True. " + "optimizer_container is not passed to save_checkpoint but save_optimizer is set to True. " "Therefore, the function will not save the optimizer", ) else: - # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True)) - dcp.save(get_optimizer_state_dict(model, optimizer), checkpoint_id=_get_optimizer_path(save_path)) + dcp.save( + {"state": _OptimizerSaver(model_container, optimizer_container)}, + checkpoint_id=_get_optimizer_path(save_path), + ) - if lr_scheduler is None: + if lr_scheduler_container is None: log_rank_0( logging.WARN, - "lr_scheduler is not passed to save_checkpoint. " "Therefore, the function will not save the lr_scheduler", + "lr_scheduler_container is not passed to save_checkpoint. 
Therefore, the function will not save the lr_scheduler", ) else: - run_rank_n(torch.save)(lr_scheduler.state_dict(), _get_lr_scheduler_path(save_path)) + lr_scheduler_path = _get_lr_scheduler_path(save_path) + os.makedirs(os.path.dirname(lr_scheduler_path), exist_ok=True) + torch.save(_LRSchedulerSaver(lr_scheduler_container).state_dict(), _get_lr_scheduler_path(save_path)) rng_state = { "random_rng_state": random.getstate(), @@ -131,18 +208,18 @@ def save_checkpoint( def load_checkpoint_for_training( args: TrainingArgs, - model: ModelWrapper, - optimizer: Optimizer, - lr_scheduler: LambdaLR, + model_container: ModelContainer, + optimizer_container: OptimizerContainer, + lr_scheduler_container: LRSchedulerContainer, train_dataloader: ResumableDataLoader, ) -> tuple[int, dict, dict]: """load checkpoint for training Args: args (TrainingArgs): arguments for training - model (ModelWrapper): model to load - optimizer (Optimizer): optimizer to save - lr_scheduler (LambdaLR): learning rate scheduler to load + model_container (ModelContainer): models to save + optimizer_container (OptimizerContainer): optimizers to save + lr_scheduler_container (LRSchedulerContainer): learning rate schedulers to save train_dataloader (ResumableDataLoader): train dataloader to load Raises: @@ -156,7 +233,6 @@ def load_checkpoint_for_training( return load_optimizer = args.load_args.load_optimizer - load_lr_scheduler = args.load_args.load_lr_scheduler load_rng_state = args.load_args.load_rng_state load_dataloader_state = args.load_args.load_dataloader_state load_experiments_tracker_state = args.load_args.load_experiments_tracker_state @@ -172,31 +248,29 @@ def load_checkpoint_for_training( log_rank_0(logging.INFO, f"loading checkpoint saved at {load_path}") - has_teacher_model = model.has_teacher_model() - if has_teacher_model: - log_rank_0( - logging.WARN, - "the model will use non-strict loading of state dict during distillation, this has potential of incorrect behavior", - ) - - model_state_dict = get_model_state_dict(model) - dcp.load(model_state_dict, checkpoint_id=_get_model_path(load_path)) - set_model_state_dict(model, model_state_dict, options=StateDictOptions(strict=not has_teacher_model)) - del model_state_dict + # FIXME drop original_state_dict after https://github.com/pytorch/pytorch/pull/138575 is fixed + saver = _ModelSaver(model_container) + state_dict = {"state": saver.state_dict()} + original_state_dict = {"state": {key: value for key, value in state_dict["state"].items()}} + dcp.load(state_dict, checkpoint_id=_get_model_path(load_path)) + state_dict.update(original_state_dict) + saver.load_state_dict(state_dict["state"]) if load_optimizer: - # TODO add options=StateDictOptions(flatten_optimizer_state_dict=True)) - optimizer_state_dict = get_optimizer_state_dict(model, optimizer) - dcp.load(optimizer_state_dict, checkpoint_id=_get_optimizer_path(load_path)) - set_optimizer_state_dict(model, optimizer, optim_state_dict=optimizer_state_dict) - del optimizer_state_dict - - if load_lr_scheduler: + # FIXME drop original_state_dict after https://github.com/pytorch/pytorch/pull/138575 is fixed + saver = _OptimizerSaver(model_container, optimizer_container) + state_dict = {"state": saver.state_dict()} + original_state_dict = {"state": {key: value for key, value in state_dict["state"].items()}} + dcp.load(state_dict, checkpoint_id=_get_optimizer_path(load_path)) + state_dict.update(original_state_dict) + saver.load_state_dict(state_dict["state"]) + + if args.load_args.load_lr_scheduler: assert 
load_optimizer, "load_lr_scheduler requires loading of optimizer" - lr_scheduler.load_state_dict(torch.load(_get_lr_scheduler_path(load_path))) - else: - if args.load_args.resume_learning_rate: + _LRSchedulerSaver(lr_scheduler_container).load_state_dict(torch.load(_get_lr_scheduler_path(load_path))) + elif args.load_args.resume_learning_rate: + for optimizer, lr_scheduler in zip(optimizer_container, lr_scheduler_container): _resume_learning_rate( args, optimizer=optimizer, @@ -262,14 +336,21 @@ def load_checkpoint_for_inference( log_rank_0(logging.INFO, "overriding mixed precision args") args_from_checkpoint.mixed_precision_args = args.mixed_precision_args - checkpoint_tp_world_size = args_from_checkpoint.distributed_args.tensor_parallel_size + checkpoint_tp_world_size = args_from_checkpoint.distributed_args.tensor_parallel_world_size with ( torch.device("meta") if use_meta else torch.device(torch.cuda.current_device()), ProcessGroupManager.set_dummy_tensor_parallel_rank(0), ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), ): - model = get_model(args_from_checkpoint, mode) + original_num_stages = args_from_checkpoint.distributed_args.num_pipeline_stages + args_from_checkpoint.distributed_args.num_pipeline_stages = 1 + + model = get_model_container(args_from_checkpoint, mode)[0] + + args_from_checkpoint.distributed_args.num_pipeline_stages = original_num_stages if use_meta: model = model.to_empty(device="cpu") @@ -282,9 +363,11 @@ def load_checkpoint_for_inference( no_dist=True, ) + state = state["state"] + if checkpoint_tp_world_size > 1: state = fix_unsharded_state_dict( - model.config, state, tensor_parallel_size=checkpoint_tp_world_size, prefix="model." + model.config, state, tensor_parallel_world_size=checkpoint_tp_world_size, prefix="model." 
) was_compiled_model = args_from_checkpoint.distributed_args.torch_compile @@ -329,8 +412,8 @@ def _resume_learning_rate( # we create lr scheduler again here since optimizer is loaded from disk and lr scheduler is now out of sync # this helps to resume phase 2 - lr_scheduler_tmp = get_scheduler( - optimizer=optimizer, + lr_scheduler_tmp = get_scheduler_container( + optimizer_container=OptimizerContainer([optimizer]), num_warmup_steps=args.lr_scheduler_args.num_warmup_steps, num_constant_steps=args.lr_scheduler_args.num_constant_steps, num_decay_steps=args.lr_scheduler_args.num_decay_steps, @@ -339,7 +422,7 @@ def _resume_learning_rate( lr_decay_factor=args.lr_scheduler_args.lr_decay_factor, extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args, last_epoch=-1 if iteration is None else iteration - 1, - ) + )[0] for grp, lr_ in zip(optimizer.param_groups, initial_lr): grp["initial_lr"] = lr_ @@ -365,7 +448,7 @@ def _get_optimizer_path(path: str) -> str: def _get_lr_scheduler_path(path: str) -> str: - return os.path.join(path, "lr_scheduler.pt") + return os.path.join(path, "lr_scheduler", f"lr_scheduler-{ProcessGroupManager.get_global_rank()}.pt") def _get_dataloader_path(path: str) -> str: @@ -386,12 +469,3 @@ def _get_experiments_tracker_path(path: str) -> str: def _get_metadata_path(path: str) -> str: return os.path.join(path, "metadata.json") - - -def _filter_out_teacher_state_dict(state_dict: dict) -> dict: - result = {} - for key, value in state_dict.items(): - if not "teacher_model" in key: - result[key] = value - - return result diff --git a/dolomite_engine/containers.py b/dolomite_engine/containers.py new file mode 100644 index 00000000..da1ac606 --- /dev/null +++ b/dolomite_engine/containers.py @@ -0,0 +1,65 @@ +import logging + +import torch.nn as nn + +from .utils import log_rank_0 + + +class _Container: + def __init__(self, model_list: list[nn.Module]) -> None: + self.model_list = model_list + + def __iter__(self): + for model in self.model_list: + yield model + + def __getitem__(self, index: int) -> nn.Module: + return self.model_list[index] + + def __setitem__(self, index: int, model: nn.Module) -> None: + self.model_list[index] = model + + def __len__(self) -> int: + return len(self.model_list) + + def __str__(self): + return str(self.model_list) + + +class ModelContainer(_Container): + def train(self) -> "ModelContainer": + for model in self: + model.train() + + def eval(self) -> "ModelContainer": + for model in self: + model.eval() + + return self + + +class LRSchedulerContainer(_Container): + def step(self) -> None: + for lr_scheduler in self: + lr_scheduler.step() + + +class OptimizerContainer(LRSchedulerContainer): + def zero_grad(self) -> None: + for optimizer in self: + optimizer.zero_grad() + + +def log_model_optimizer_container(model_container: ModelContainer, optimizer_container: OptimizerContainer) -> None: + """print model and optimizer + + Args: + model_container (ModelContainer): container of models to print + optimizer_container (OptimizerContainer): container of optimizers to print + """ + + log_rank_0(logging.INFO, "------------------------ model & optimizer list ------------------------") + for model, optimizer in zip(model_container, optimizer_container): + log_rank_0(logging.INFO, model) + log_rank_0(logging.INFO, optimizer) + log_rank_0(logging.INFO, "-------------------- end of model & optimizer list ---------------------") diff --git a/dolomite_engine/data/megatron/__init__.py b/dolomite_engine/data/megatron/__init__.py index 
07658be0..e9042f55 100644 --- a/dolomite_engine/data/megatron/__init__.py +++ b/dolomite_engine/data/megatron/__init__.py @@ -26,6 +26,8 @@ def get_megatron_gpt_dataloaders(args: TrainingArgs, tokenizer: AutoTokenizer, c assert args.datasets[0].output_format == OUTPUT_FORMAT micro_batch_size = args.training_parameters.micro_batch_size + gradient_accumulation_steps = args.training_parameters.gradient_accumulation_steps + num_pipeline_stages = args.distributed_args.num_pipeline_stages sequence_length = class_args.get("sequence_length") compile_helpers() @@ -57,16 +59,14 @@ def _get_source_broadcast_mapping() -> dict: is_built_on_rank = ProcessGroupManager.get_global_rank() == node_rank * num_ranks_per_node else: # only build dataloader on first rank of each TP group - is_built_on_rank = ( - ProcessGroupManager.get_global_rank() == ProcessGroupManager.get_tensor_parallel_first_rank() - ) + is_built_on_rank = ProcessGroupManager.is_tensor_parallel_first_rank() gpt_dataset_builder = BlendedMegatronDatasetBuilder( GPTDataset, sizes=_get_train_val_test_samples( args.training_parameters.num_training_steps, micro_batch_size, - args.training_parameters.gradient_accumulation_steps, + gradient_accumulation_steps, args.training_parameters.eval_interval, class_args.get("eval_steps"), ), @@ -168,7 +168,11 @@ def _get_dataloader(dataset: GPTDataset | None, consumed_samples: int): batch_sampler = MegatronBatchSampler( total_samples=len(dataset), consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size * num_ranks_per_node, + micro_batch_size=( + micro_batch_size * num_ranks_per_node + if num_pipeline_stages == 1 + else micro_batch_size * gradient_accumulation_steps * num_ranks_per_node + ), num_replicas=num_nodes, rank=node_rank, ) @@ -184,7 +188,10 @@ def _get_dataloader(dataset: GPTDataset | None, consumed_samples: int): pin_memory=True, source_broadcast_mapping=source_broadcast_mapping, broadcast_world_size=num_ranks_per_node, - static_shape_per_rank=(micro_batch_size, sequence_length + 1), + static_shape_per_rank=( + (micro_batch_size if num_pipeline_stages == 1 else micro_batch_size * gradient_accumulation_steps), + sequence_length + 1, + ), keys=["text"], ) else: @@ -194,7 +201,9 @@ def _get_dataloader(dataset: GPTDataset | None, consumed_samples: int): batch_sampler = MegatronBatchSampler( total_samples=len(dataset), consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, + micro_batch_size=( + micro_batch_size if num_pipeline_stages == 1 else micro_batch_size * gradient_accumulation_steps + ), num_replicas=ProcessGroupManager.get_data_parallel_world_size(), rank=ProcessGroupManager.get_data_parallel_rank(), ) diff --git a/dolomite_engine/distributed/__init__.py b/dolomite_engine/distributed/__init__.py index 28d7a2af..433d556f 100644 --- a/dolomite_engine/distributed/__init__.py +++ b/dolomite_engine/distributed/__init__.py @@ -1,5 +1,6 @@ import logging from functools import partial +from typing import Callable import torch import torch.nn as nn @@ -11,12 +12,18 @@ from torch.distributed.fsdp import MixedPrecision as MixedPrecision1 from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + PipelineScheduleMulti, + PipelineScheduleSingle, + _PipelineSchedule, + get_schedule_class, +) from ..arguments import TrainingArgs +from ..containers import ModelContainer from ..enums import FP8Backend from 
..gradient_checkpointing import apply_gradient_checkpointing -from ..model_wrapper import ModelWrapper -from ..optimization import get_optimizer, get_scheduler from ..utils import ProcessGroupManager, get_module_class_from_name, log_rank_0, string_to_torch_dtype from .fp8 import convert_model_to_transformer_engine @@ -32,15 +39,17 @@ } -def wrap_model_for_distributed_training(args: TrainingArgs, model: ModelWrapper) -> ModelWrapper: +def wrap_model_container_for_distributed_training( + args: TrainingArgs, model_container: ModelContainer +) -> tuple[ModelContainer, _PipelineSchedule]: """converts the model to a ZeRO-DP sharded model Args: args (TrainingArgs): arguments based on training mode - model (ModelWrapper): any nn.Module object + model_container (ModelContainer): model container Returns: - ModelWrapper: parallelized model + tuple[ModelContainer, _PipelineSchedule]: container of parallelized models and pipeline schedule """ stage = args.distributed_args.stage @@ -51,6 +60,7 @@ def wrap_model_for_distributed_training(args: TrainingArgs, model: ModelWrapper) fp8_backend = args.mixed_precision_args.fp8_backend efficient_initialization = args.model_args.efficient_initialization fsdp_algorithm = args.distributed_args.fsdp_algorithm + num_pipeline_stages = args.distributed_args.num_pipeline_stages if dtype in ["fp16", "bf16"]: if communication_dtype != "fp32": @@ -60,11 +70,14 @@ def wrap_model_for_distributed_training(args: TrainingArgs, model: ModelWrapper) ) if dtype == "fp8" and fp8_backend == FP8Backend.nvte: + # FIXME this wont work convert_model_to_transformer_engine(model) dtype = "bf16" - block_names = model.model._no_split_modules - teacher_block_names = model.teacher_model._no_split_modules if model.has_teacher_model() else [] + block_names = model_container[0].model._no_split_modules + teacher_block_names = ( + model_container[0].teacher_model._no_split_modules if model_container[0].has_teacher_model() else [] + ) dtype = None if dtype is None else string_to_torch_dtype(dtype) communication_dtype = None if communication_dtype is None else string_to_torch_dtype(communication_dtype) @@ -72,17 +85,20 @@ def wrap_model_for_distributed_training(args: TrainingArgs, model: ModelWrapper) assert stage in [0, 2, 3] dp_mesh = ProcessGroupManager.get_data_parallel_mesh() - block_classes = [get_module_class_from_name(model, name) for name in block_names + teacher_block_names] + block_classes = [ + get_module_class_from_name(model_container[0], name) for name in block_names + teacher_block_names + ] if args.distributed_args.gradient_checkpointing_method is not None: assert len(block_names) == 1 - apply_gradient_checkpointing( - model, - args.distributed_args.gradient_checkpointing_method, - block_name=block_names[0], - **args.distributed_args.gradient_checkpointing_args, - ) + for model in model_container: + apply_gradient_checkpointing( + model, + args.distributed_args.gradient_checkpointing_method, + block_name=block_names[0], + **args.distributed_args.gradient_checkpointing_args, + ) if fsdp_algorithm == 1: if stage == 0: @@ -113,44 +129,46 @@ def _param_init(module: nn.Module) -> None: if efficient_initialization and ProcessGroupManager.get_data_parallel_rank() != 0: module = module.to_empty(device=torch.cuda.current_device()) - model = FSDP( - model, - sharding_strategy=sharding_strategy, - cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None, - mixed_precision=_get_fsdp_mixed_precision( - dtype=dtype, - communication_dtype=communication_dtype, - fsdp_algorithm=1, - ), - 
auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=block_classes), - device_id=torch.cuda.current_device(), - limit_all_gathers=True, - use_orig_params=True, - # https://github.com/meta-llama/llama-recipes/blob/492455dc080f6c25f356e283e443be0cce86aaeb/src/llama_recipes/finetuning.py#L191 - sync_module_states=efficient_initialization, - param_init_fn=_param_init if efficient_initialization else None, - device_mesh=dp_mesh, - ) - else: - if stage == 0: - log_rank_0(logging.INFO, "using DDP") - - assert not efficient_initialization - - model = FSDP( + for i, model in enumerate(model_container): + model_container[i] = FSDP( model, - sharding_strategy=ShardingStrategy.NO_SHARD, + sharding_strategy=sharding_strategy, cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None, mixed_precision=_get_fsdp_mixed_precision( dtype=dtype, communication_dtype=communication_dtype, fsdp_algorithm=1, ), + auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=block_classes), device_id=torch.cuda.current_device(), limit_all_gathers=True, use_orig_params=True, + # https://github.com/meta-llama/llama-recipes/blob/492455dc080f6c25f356e283e443be0cce86aaeb/src/llama_recipes/finetuning.py#L191 + sync_module_states=efficient_initialization, + param_init_fn=_param_init if efficient_initialization else None, device_mesh=dp_mesh, ) + else: + if stage == 0: + log_rank_0(logging.INFO, "using DDP") + + assert not efficient_initialization + + for i, model in enumerate(model_container): + model_container[i] = FSDP( + model, + sharding_strategy=ShardingStrategy.NO_SHARD, + cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None, + mixed_precision=_get_fsdp_mixed_precision( + dtype=dtype, + communication_dtype=communication_dtype, + fsdp_algorithm=1, + ), + device_id=torch.cuda.current_device(), + limit_all_gathers=True, + use_orig_params=True, + device_mesh=dp_mesh, + ) else: log_rank_0(logging.INFO, "using FSDP-2") @@ -162,36 +180,109 @@ def _param_init(module: nn.Module) -> None: zero3 = stage == 3 - for module in model.modules(): - if isinstance(module, tuple(block_classes)): - fully_shard( - module, - mesh=dp_mesh, - reshard_after_forward=zero3, - mp_policy=mixed_precision_policy, - offload_policy=CPUOffloadPolicy(pin_memory=True) if cpu_offload else OffloadPolicy(), - ) - - fully_shard( + for i, model in enumerate(model_container): + for module in model.modules(): + if isinstance(module, tuple(block_classes)): + fully_shard( + module, + mesh=dp_mesh, + reshard_after_forward=zero3, + mp_policy=mixed_precision_policy, + offload_policy=CPUOffloadPolicy(pin_memory=True) if cpu_offload else OffloadPolicy(), + ) + + fully_shard( + model, + mesh=dp_mesh, + reshard_after_forward=zero3, + mp_policy=mixed_precision_policy, + offload_policy=CPUOffloadPolicy(pin_memory=True) if cpu_offload else OffloadPolicy(), + ) + + if efficient_initialization and args.model_args.model_name is None: + model = model.to_empty(device=torch.cuda.current_device()) + + for module in model.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + if torch_compile: + log_rank_0(logging.INFO, "using torch compile") + + for i in range(len(model_container)): + model_container[i] = torch.compile(model_container[i]) + + pipeline_stages = [] + pipeline_schedule = None + + if num_pipeline_stages > 1: + micro_batch_size = args.training_parameters.micro_batch_size + sequence_length = args.datasets[0].class_args.get("sequence_length") + + for model in 
model_container: + intermediate_dtype = string_to_torch_dtype(args.mixed_precision_args.dtype) + + dummy_input_tensor = model.model.get_dummy_input_tensor( + micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype + ) + dummy_output_tensor = model.model.get_dummy_output_tensor( + micro_batch_size, + sequence_length, + intermediate_dtype=intermediate_dtype, + output_parallel_lm_logits_if_possible=True, + ) + + stage = PipelineStage( model, - mesh=dp_mesh, - reshard_after_forward=zero3, - mp_policy=mixed_precision_policy, - offload_policy=CPUOffloadPolicy(pin_memory=True) if cpu_offload else OffloadPolicy(), + stage_index=model.pipeline_stage_id, + num_stages=num_pipeline_stages, + device=torch.cuda.current_device(), + input_args=dummy_input_tensor, + output_args=dummy_output_tensor, + group=ProcessGroupManager.get_pipeline_parallel_group(), ) + pipeline_stages.append(stage) - if efficient_initialization and args.model_args.model_name is None: - model = model.to_empty(device=torch.cuda.current_device()) + pipeline_schedule = _get_pipeline_parallel_schedule( + pipeline_parallel_schedule=args.distributed_args.pipeline_parallel_schedule, + gradient_accumulation_steps=args.training_parameters.gradient_accumulation_steps, + pipeline_stages=pipeline_stages, + loss_fn=model.get_loss, + ) - for module in model.modules(): - if hasattr(module, "reset_parameters"): - module.reset_parameters() + return model_container, pipeline_schedule + + +def _get_pipeline_parallel_schedule( + pipeline_parallel_schedule: str, + gradient_accumulation_steps: int, + pipeline_stages: list[PipelineStage], + loss_fn: Callable, +) -> _PipelineSchedule: + try: + schedule_class = get_schedule_class(pipeline_parallel_schedule) + except ValueError: + raise ValueError( + f"unexpected schedule ({pipeline_parallel_schedule}), expected values are: ['1F1B', " + "'Interleaved1F1B', 'GPipe', 'FlexibleInterleaved1F1B', 'LoopedBFS', 'InterleavedZeroBubble', " + "'PipelineScheduleSingle', 'PipelineScheduleMulti']" + ) - if torch_compile: - log_rank_0(logging.INFO, "using torch compile") - model = torch.compile(model) + if schedule_class in [PipelineScheduleSingle, PipelineScheduleMulti]: + raise NotImplementedError() + + if issubclass(schedule_class, PipelineScheduleSingle): + assert len(pipeline_stages) == 1 + + def custom_loss_function(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + loss_dict = loss_fn(output, target) + return loss_dict["loss"] - return model + return schedule_class( + pipeline_stages if issubclass(schedule_class, PipelineScheduleMulti) else pipeline_stages[0], + n_microbatches=gradient_accumulation_steps, + loss_fn=custom_loss_function, + ) def _get_fsdp_mixed_precision( diff --git a/dolomite_engine/finetune.py b/dolomite_engine/finetune.py index fd3ab063..58d879d6 100644 --- a/dolomite_engine/finetune.py +++ b/dolomite_engine/finetune.py @@ -3,6 +3,7 @@ import torch from torch.distributed._tensor.api import DTensor +from torch.distributed.pipelining.schedules import _PipelineSchedule from torch.distributed.tensor.parallel import loss_parallel from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR @@ -10,11 +11,12 @@ from .arguments import TrainingArgs, get_args from .checkpointing import load_checkpoint_for_training, save_checkpoint +from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container from .data import ResumableDataLoader, custom_iterator, get_dataloader, get_next_batch -from .distributed import 
wrap_model_for_distributed_training +from .distributed import wrap_model_container_for_distributed_training from .enums import DatasetSplit, FP8Backend, Mode, TuningMethod -from .model_wrapper import ModelWrapperForFinetuning, get_model, log_model -from .optimization import get_optimizer, get_scheduler, log_optimizer +from .model_wrapper import ModelWrapperForFinetuning, get_model_container +from .optimization import get_optimizer_container, get_scheduler_container from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics, train_step from .utils import ( ExperimentsTracker, @@ -33,9 +35,10 @@ def train( args: TrainingArgs, - model: ModelWrapperForFinetuning, - optimizer: Optimizer, - lr_scheduler: LambdaLR, + model_container: ModelContainer, + pipeline_schedule: _PipelineSchedule, + optimizer_container: OptimizerContainer, + lr_scheduler_container: LRSchedulerContainer, train_dataloader: ResumableDataLoader, val_dataloader: ResumableDataLoader, experiments_tracker: ExperimentsTracker, @@ -45,9 +48,10 @@ def train( Args: args (TrainingArgs): training args - model (ModelWrapperForFinetuning): model - optimizer (Optimizer): optimizer - lr_scheduler (LRScheduler): learning rate scheduler + model_container (ModelContainer): container of models + pipeline_schedule (_PipelineSchedule): pipeline schedule + optimizer_container (OptimizerContainer): container of optimizers + lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers train_dataloader (ResumableDataLoader): training dataloader val_dataloader (ResumableDataLoader): validation dataloader experiments_tracker (ExperimentsTracker): metrics tracker @@ -63,13 +67,13 @@ def train( save_interval = args.save_args.save_interval log_interval = args.logging_args.log_interval - model.train() + model_container.train() # need this for iterating infinitely train_dataloader_infinite = custom_iterator(train_dataloader, infinite=True) if eval_during_training: - evaluate(val_dataloader, model, starting_iteration, experiments_tracker) + evaluate(val_dataloader, model_container, starting_iteration, experiments_tracker) forward_context = ( partial( @@ -95,15 +99,19 @@ def train( global_step += 1 loss_step_dict = train_step( - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, + model_container=model_container, + pipeline_schedule=pipeline_schedule, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, train_dataloader=train_dataloader_infinite, gradient_accumulation_steps=gradient_accumulation_steps, gradient_clipping=gradient_clipping, forward_context=forward_context, backward_context=backward_context, sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, + is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1, + batch_size=None, + sequence_length=None, ) metrics_tracker = metrics_tracker + loss_step_dict @@ -113,7 +121,7 @@ def train( if global_step % log_interval == 0: metrics_tracker = metrics_tracker / log_interval - metrics_tracker["learning_rate"] = lr_scheduler.get_lr()[0] + metrics_tracker["learning_rate"] = lr_scheduler_container[0].get_lr()[0] track_metrics( global_step=global_step, @@ -125,10 +133,18 @@ def train( metrics_tracker = MetricsTrackingDict({}) if eval_during_training and (global_step % eval_interval == 0 or global_step == num_training_steps): - evaluate(val_dataloader, model, global_step, experiments_tracker) + evaluate(val_dataloader, model_container, global_step, 
experiments_tracker) if global_step % save_interval == 0 or global_step == num_training_steps: - save_checkpoint(args, model, optimizer, lr_scheduler, train_dataloader, experiments_tracker, global_step) + save_checkpoint( + args=args, + model_container=model_container, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, + train_dataloader=train_dataloader, + experiments_tracker=experiments_tracker, + iteration=global_step, + ) if torch_profiler is not None: torch_profiler.__exit__() @@ -137,7 +153,7 @@ def train( @torch.no_grad() def evaluate( val_dataloader: ResumableDataLoader, - model: ModelWrapperForFinetuning, + model_container: ModelContainer, global_step: int, experiments_tracker: ExperimentsTracker, ) -> MetricsTrackingDict: @@ -145,7 +161,7 @@ def evaluate( Args: val_dataloader (ResumableDataLoader): validation dataloader - model (ModelWrapperForFinetuning): model + model_container (ModelContainer): model container global_step (int): global step during training experiments_tracker (ExperimentsTracker): metrics tracker @@ -154,7 +170,7 @@ def evaluate( """ if ProcessGroupManager.is_tensor_parallel_enabled(): - if ProcessGroupManager.get_tensor_parallel_rank() == 0: + if ProcessGroupManager.is_tensor_parallel_first_rank(): num_steps = 0 if val_dataloader is None else len(val_dataloader) else: num_steps = 0 @@ -168,14 +184,14 @@ def evaluate( if num_steps == 0: return - model.eval() + model_container.eval() metrics_tracker = MetricsTrackingDict({}) val_dataloader = custom_iterator(val_dataloader, infinite=False) for _ in range(num_steps): batch = get_next_batch(val_dataloader) - loss_step_dict = model(batch) + loss_step_dict = model_container[0](batch) metrics_tracker = metrics_tracker + loss_step_dict metrics_tracker = metrics_tracker / num_steps @@ -193,7 +209,7 @@ def evaluate( context="val", ) - model.train() + model_container.train() return metrics_tracker @@ -215,7 +231,8 @@ def main() -> None: # initialize distributed with nccl for multi-node communications init_distributed( - tensor_parallel_size=args.distributed_args.tensor_parallel_size, + tensor_parallel_world_size=args.distributed_args.tensor_parallel_world_size, + pipeline_parallel_world_size=args.distributed_args.pipeline_parallel_world_size, data_parallel_size=args.distributed_args.data_parallel_size, data_parallel_replication_world_size=args.distributed_args.zero_topology.data_parallel_replication_world_size, data_parallel_sharding_world_size=args.distributed_args.zero_topology.data_parallel_sharding_world_size, @@ -223,16 +240,19 @@ def main() -> None: timeout_minutes=args.distributed_args.timeout_minutes, use_async_tensor_parallel=args.distributed_args.use_async_tensor_parallel, ) + set_seed(args.random_args.seed) - model = get_model(args, mode) + assert args.distributed_args.num_pipeline_stages == 1, "pipeline parallel is not supported with finetuning" + + model_container = get_model_container(args, mode) train_dataloader = get_dataloader( args, split=DatasetSplit.train, mode=mode, - tokenizer=model.tokenizer, - is_encoder_decoder=model.is_encoder_decoder, + tokenizer=model_container[0].tokenizer, + is_encoder_decoder=model_container[0].is_encoder_decoder, ) val_dataloader = None @@ -241,21 +261,21 @@ def main() -> None: args, split=DatasetSplit.val, mode=mode, - tokenizer=model.tokenizer, - is_encoder_decoder=model.is_encoder_decoder, + tokenizer=model_container[0].tokenizer, + is_encoder_decoder=model_container[0].is_encoder_decoder, ) - model = 
wrap_model_for_distributed_training(args, model) + model_container, pipeline_schedule = wrap_model_container_for_distributed_training(args, model_container) - optimizer = get_optimizer( + optimizer_container = get_optimizer_container( optimizer_class_name=args.optimizer_args.class_name, optimizer_class_args=args.optimizer_args.class_args, - model=model, + model_container=model_container, params_group_method=args.optimizer_args.params_group_method, ) - lr_scheduler = get_scheduler( - optimizer=optimizer, + lr_scheduler_container = get_scheduler_container( + optimizer_container=optimizer_container, num_warmup_steps=args.lr_scheduler_args.num_warmup_steps, num_constant_steps=args.lr_scheduler_args.num_constant_steps, num_decay_steps=args.lr_scheduler_args.num_decay_steps, @@ -265,14 +285,13 @@ def main() -> None: extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args, ) - log_model(model) - log_optimizer(optimizer) + log_model_optimizer_container(model_container, optimizer_container) starting_iteration = 0 experiments_tracker_state_dict = None if args.load_args is not None: starting_iteration, _, experiments_tracker_state_dict = load_checkpoint_for_training( - args, model, optimizer, lr_scheduler, train_dataloader + args, model_container, optimizer_container, lr_scheduler_container, train_dataloader ) experiments_tracker = ExperimentsTracker( @@ -287,9 +306,10 @@ def main() -> None: # main training loop train( args, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, + model_container=model_container, + pipeline_schedule=pipeline_schedule, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, train_dataloader=train_dataloader, val_dataloader=val_dataloader, experiments_tracker=experiments_tracker, diff --git a/dolomite_engine/generate.py b/dolomite_engine/generate.py index 38408bfa..9f2274ae 100644 --- a/dolomite_engine/generate.py +++ b/dolomite_engine/generate.py @@ -86,6 +86,8 @@ def main() -> None: torch.device(torch.cuda.current_device()), ProcessGroupManager.set_dummy_tensor_parallel_rank(0), ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), ): model = ModelWrapperForFinetuning( mode=mode, diff --git a/dolomite_engine/hf_models/__init__.py b/dolomite_engine/hf_models/__init__.py index feadfa40..4ff9941d 100644 --- a/dolomite_engine/hf_models/__init__.py +++ b/dolomite_engine/hf_models/__init__.py @@ -19,7 +19,7 @@ RNNDolomiteModel, convert_gpt_dolomite_to_gpt_crosslayer, ) -from .register_hf import get_tensor_parallel_class, is_custom_model, register_model_classes +from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts from .utils import convert_padding_free_lists_to_tensors diff --git a/dolomite_engine/hf_models/mixins/dense/base.py b/dolomite_engine/hf_models/mixins/dense/base.py index 71007435..6be1ad6f 100644 --- a/dolomite_engine/hf_models/mixins/dense/base.py +++ b/dolomite_engine/hf_models/mixins/dense/base.py @@ -5,11 +5,12 @@ from transformers import DynamicCache, PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast +from ....utils import divide_if_divisible from ...config import CommonConfig from ...defaults import DEFAULT_NORMALIZATION_IMPLEMENTATION from ...enums import AttentionHeadType, PositionEmbeddingType from 
...modeling_utils import Alibi, ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function -from ...utils import convert_padding_free_lists_to_tensors, divide_if_divisible +from ...utils import convert_padding_free_lists_to_tensors class PreTrainedModelMixin(PreTrainedModel): diff --git a/dolomite_engine/hf_models/mixins/dense_TP/base.py b/dolomite_engine/hf_models/mixins/dense_TP/base.py index 87028278..8a59eff1 100644 --- a/dolomite_engine/hf_models/mixins/dense_TP/base.py +++ b/dolomite_engine/hf_models/mixins/dense_TP/base.py @@ -1,6 +1,9 @@ +import torch import torch.nn as nn +from transformers import DynamicCache +from transformers.modeling_outputs import BaseModelOutputWithPast -from ....utils import ProcessGroupManager +from ....utils import ProcessGroupManager, divide_if_divisible from ...config import CommonConfig from ...enums import AttentionHeadType, PositionEmbeddingType from ...modeling_utils import RoPE, YaRNScaledRoPE @@ -9,12 +12,22 @@ class PreTrainedModelMixin_TP(PreTrainedModelMixin): - def __init__(self, config: CommonConfig, *args, **kwargs): + def __init__(self, config: CommonConfig, *args, **kwargs) -> None: self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) self.sequence_parallel = kwargs.get("sequence_parallel", False) + self.num_pipeline_stages = kwargs.get("num_pipeline_stages", 1) + self.pipeline_stage_id = kwargs.get("pipeline_stage_id", 0) + + self.is_first_stage = self.pipeline_stage_id == 0 + self.is_last_stage = self.pipeline_stage_id == self.num_pipeline_stages - 1 + self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1 + super().__init__(config, *args, **kwargs) + if self.is_pipeline_parallel_enabled and self._tied_word_embeddings: + raise NotImplementedError() + class BaseModelMixin_TP(PreTrainedModelMixin_TP, BaseModelMixin): def _init_model(self, config: CommonConfig, **kwargs) -> None: @@ -26,27 +39,36 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.initializer_range = config.initializer_range self.head_dim = self.embed_dim // self.num_heads - self.wte = Embedding_TP( - config.vocab_size, - self.embed_dim, - std=self.initializer_range, - tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, - use_padding_free_transformer=self._use_padding_free_transformer, - sequence_parallel=self.sequence_parallel, + self.layers_per_stage = divide_if_divisible( + config.n_layer, self.num_pipeline_stages, "layers should be divisible by num_pipeline_stages" ) - self.drop = ( - nn.Identity() - if config.embd_pdrop == 0 - else Dropout_TP( - config.embd_pdrop, + self.layer_start_id = self.layers_per_stage * self.pipeline_stage_id + self.layer_end_id = self.layers_per_stage * (self.pipeline_stage_id + 1) + + if self.is_first_stage: + self.wte = Embedding_TP( + config.vocab_size, + self.embed_dim, + std=self.initializer_range, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, use_padding_free_transformer=self._use_padding_free_transformer, sequence_parallel=self.sequence_parallel, ) - ) - self.h = nn.ModuleList( - [ - self.layer_class( + + self.drop = ( + nn.Identity() + if config.embd_pdrop == 0 + else Dropout_TP( + config.embd_pdrop, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + ) + + self.h = nn.ModuleDict( + { + str(i): self.layer_class( config, normalization_implementation=self.normalization_implementation, attention_implementation=self.attention_implementation, @@ -54,37 
+76,133 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: layer_idx=i, sequence_parallel=self.sequence_parallel, ) - for i in range(config.num_hidden_layers) - ] - ) - self.ln_f = get_normalization_function_TP( - config.normalization_function, - self.embed_dim, - eps=config.layer_norm_epsilon, - normalization_implementation=self.normalization_implementation, - use_padding_free_transformer=self._use_padding_free_transformer, - sequence_parallel=self.sequence_parallel, + for i in range(self.layer_start_id, self.layer_end_id) + } ) + if self.is_last_stage: + self.ln_f = get_normalization_function_TP( + config.normalization_function, + self.embed_dim, + eps=config.layer_norm_epsilon, + normalization_implementation=self.normalization_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) self._setup_positional_encoding() # Initialize weights and apply final processing self.post_init() + def forward( + self, + input_ids: torch.Tensor | None = None, + past_key_values: DynamicCache | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + ) -> tuple | BaseModelOutputWithPast: + if self.is_first_stage: + ( + output_hidden_states, + use_cache, + hidden_states, + attention_mask, + position_ids, + rope_cos_sin, + past_key_values, + ) = self._prepare_a_bunch_of_stuff( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + assert past_key_values is None + assert attention_mask is None + + hidden_states = input_ids + past_length = 0 + + if self._use_padding_free_transformer: + key_length = max_seqlen + # query length will change if past_key_values is not None + query_length = key_length - past_length + else: + key_length = ( + hidden_states.size(1) * ProcessGroupManager.get_tensor_parallel_world_size() + if self.sequence_parallel + else hidden_states.size(1) + ) + query_length = key_length - past_length + + position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=hidden_states.device) + position_ids = position_ids.unsqueeze(0).view(-1, query_length) + + rope_cos_sin = self._get_rope_cos_sin( + key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device + ) + + past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values + all_hidden_states = () if output_hidden_states else None + + for layer_idx in range(self.layer_start_id, self.layer_end_id): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.h[str(layer_idx)]( + hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + rope_cos_sin=rope_cos_sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + if self.is_last_stage: + hidden_states = self.ln_f(hidden_states) + + # Add last hidden state + if output_hidden_states: + all_hidden_states += 
(hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + ) + def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - self.wpe = Embedding_TP( - max_position_embeddings, - self.embed_dim, - std=self.initializer_range, - tensor_parallel_word_embeddings=False, - use_padding_free_transformer=self._use_padding_free_transformer, - sequence_parallel=self.sequence_parallel, - ) + if self.is_first_stage: + self.wpe = Embedding_TP( + max_position_embeddings, + self.embed_dim, + std=self.initializer_range, + tensor_parallel_word_embeddings=False, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) elif self.position_embedding_type == PositionEmbeddingType.alibi: + if self.is_pipeline_parallel_enabled: + raise NotImplementedError() + self.alibi = Alibi_TP(self.num_heads) elif self.position_embedding_type == PositionEmbeddingType.rope: if self.config.rope_scaling is None: diff --git a/dolomite_engine/hf_models/mixins/dense_TP/main.py b/dolomite_engine/hf_models/mixins/dense_TP/main.py index 4505921c..2e05a70c 100644 --- a/dolomite_engine/hf_models/mixins/dense_TP/main.py +++ b/dolomite_engine/hf_models/mixins/dense_TP/main.py @@ -9,7 +9,7 @@ from transformers import DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ....utils import ProcessGroupManager, SafeTensorsWeightsManager +from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible from ...config import CommonConfig from ...enums import PositionEmbeddingType from ...modeling_utils_TP import LMHead_TP, dtensor_to_tensor, tensor_to_dtensor @@ -18,23 +18,24 @@ class CausalLMModelMixin_TP(PreTrainedModelMixin_TP, CausalLMModelMixin): - tensor_parallel_state_dict_function = None + model_parallel_state_dict_function = None def _init_model(self, config: CommonConfig, **kwargs) -> None: self.vocab_size = config.vocab_size self.transformer = self.base_model_class(config, **kwargs) - if not self._tied_word_embeddings: - self.lm_head = LMHead_TP( - self.vocab_size, - config.n_embd, - std=config.initializer_range, - tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, - sequence_parallel=self.sequence_parallel, - ) + if self.is_last_stage: + if not self._tied_word_embeddings: + self.lm_head = LMHead_TP( + self.vocab_size, + config.n_embd, + std=config.initializer_range, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + sequence_parallel=self.sequence_parallel, + ) - self.m_width = config.m_width - self.upcast_logits_for_loss = config.upcast_logits_for_loss + self.m_width = config.m_width + self.upcast_logits_for_loss = config.upcast_logits_for_loss self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() @@ -57,20 +58,21 @@ def forward( output_parallel_lm_logits: bool = False, cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, - ) -> tuple | CausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - 
attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + ) -> CausalLMOutputWithPast | torch.Tensor: + if not self.is_pipeline_parallel_enabled or self.is_first_stage: + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) transformer_outputs: BaseModelOutputWithPast = self.transformer( input_ids, @@ -85,28 +87,38 @@ def forward( max_seqlen=max_seqlen, ) - lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state) - - if self.m_width is not None: - lm_logits = lm_logits / self.m_width - - loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) - - if output_parallel_lm_logits: - assert self.tensor_parallel_word_embeddings + if not self.is_pipeline_parallel_enabled or self.is_last_stage: + lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + + if not self.is_pipeline_parallel_enabled: + loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + + if not self.is_pipeline_parallel_enabled or self.is_last_stage: + if output_parallel_lm_logits: + assert self.tensor_parallel_word_embeddings + else: + if self.tensor_parallel_word_embeddings: + # all gather + lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) + + if not self.is_pipeline_parallel_enabled: + output = CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + elif self.is_last_stage: + output = lm_logits else: - if self.tensor_parallel_word_embeddings: - # all gather - lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) + output = transformer_outputs.last_hidden_state + + return output def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: return ( @@ -187,10 +199,90 @@ def load_from_safetensors_weights_manager(self, safetensors_weights_manager: Saf elif position_embedding_type == PositionEmbeddingType.rope: self.transformer.rope.reset_parameters() - state_dict = self.__class__.tensor_parallel_state_dict_function( + state_dict = self.__class__.model_parallel_state_dict_function( config=self.config, safetensors_weights_manager=safetensors_weights_manager, tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + num_pipeline_stages=self.num_pipeline_stages, + pipeline_stage_id=self.pipeline_stage_id, ) self.load_state_dict(state_dict) + + def get_dummy_input_tensor( + self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype + ) -> tuple[int]: + if self.is_first_stage: 
+ # 1 is added to sequence length since megatron's dataloader gives an extra token and for good reason + tensor = torch.empty( + micro_batch_size, sequence_length + 1, device=torch.cuda.current_device(), dtype=torch.long + ) + else: + tensor = self._get_dummy_intermediate_tensor( + micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype + ) + + return tensor + + def get_dummy_output_tensor( + self, + micro_batch_size: int, + sequence_length: int, + intermediate_dtype: torch.dtype, + output_parallel_lm_logits_if_possible: bool, + ) -> tuple[int]: + if self.is_last_stage: + vocab_size = self.config.vocab_size + if self.tensor_parallel_word_embeddings and output_parallel_lm_logits_if_possible: + vocab_size = divide_if_divisible(vocab_size, ProcessGroupManager.get_tensor_parallel_world_size(), "") + + if self._use_padding_free_transformer: + tensor = torch.empty( + micro_batch_size * sequence_length, + vocab_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = torch.empty( + micro_batch_size, + sequence_length, + vocab_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = self._get_dummy_intermediate_tensor( + micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype + ) + + return tensor + + def _get_dummy_intermediate_tensor( + self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype + ) -> tuple[int]: + sharded_sequence_length = ( + divide_if_divisible(sequence_length, ProcessGroupManager.get_tensor_parallel_world_size(), "") + if self.sequence_parallel + else sequence_length + ) + + hidden_size = self.config.hidden_size + + if self._use_padding_free_transformer: + tensor = torch.empty( + micro_batch_size * sharded_sequence_length, + hidden_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = torch.empty( + micro_batch_size, + sharded_sequence_length, + hidden_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + + return tensor diff --git a/dolomite_engine/hf_models/modeling_utils/attention/base.py b/dolomite_engine/hf_models/modeling_utils/attention/base.py index 51903f15..d37fb217 100644 --- a/dolomite_engine/hf_models/modeling_utils/attention/base.py +++ b/dolomite_engine/hf_models/modeling_utils/attention/base.py @@ -5,9 +5,9 @@ import torch.nn.functional as F from transformers import DynamicCache +from ....utils import divide_if_divisible from ...config import CommonConfig from ...enums import AttentionHeadType, InitMethod, PositionEmbeddingType -from ...utils import divide_if_divisible from ..linear import ParameterizedLinear from ..position_embedding import apply_rotary_pos_emb from .utils import repeat_key_value diff --git a/dolomite_engine/hf_models/modeling_utils_TP/TP.py b/dolomite_engine/hf_models/modeling_utils_TP/TP.py index 129a012b..d575f098 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/TP.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/TP.py @@ -6,8 +6,7 @@ from torch.distributed._tensor.placement_types import Placement, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh -from ...utils import ProcessGroupManager -from ..utils import divide_if_divisible +from ...utils import ProcessGroupManager, divide_if_divisible def tensor_parallel_split_safetensor_slice(slice, dim: int, start_end: tuple[int, int] | None = None) -> torch.Tensor: @@ -49,6 +48,9 @@ def tensor_to_dtensor( desired_placement: Placement | list[Placement] | None = None, 
run_check: bool = False, ) -> DTensor: + if isinstance(tensor, DTensor): + return tensor + if isinstance(current_placement, Placement): current_placement = [current_placement] @@ -69,6 +71,9 @@ def dtensor_to_tensor( desired_placement: Placement | list[Placement] | None = None, grad_placement: Placement | list[Placement] | None = None, ) -> torch.Tensor: + if not isinstance(dtensor, DTensor): + return dtensor + if desired_placement is not None: if isinstance(desired_placement, Placement): desired_placement = [desired_placement] @@ -97,7 +102,14 @@ def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefi param = module_state_dict[stripped_key] device_mesh = param.device_mesh placements = param.placements - result[key] = DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements) + + if isinstance(tensor, DTensor): + assert tensor.device_mesh == device_mesh + assert tensor.placements == placements + + result[key] = tensor + else: + result[key] = tensor_to_dtensor(tensor, device_mesh=device_mesh, current_placement=placements) return result diff --git a/dolomite_engine/hf_models/modeling_utils_TP/attention/base.py b/dolomite_engine/hf_models/modeling_utils_TP/attention/base.py index afefd9dc..4989ac5f 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/attention/base.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/attention/base.py @@ -4,11 +4,10 @@ import torch.distributed import torch.nn as nn -from ....utils import ProcessGroupManager +from ....utils import ProcessGroupManager, divide_if_divisible from ...config import CommonConfig from ...enums import AttentionHeadType, InitMethod, PositionEmbeddingType from ...modeling_utils import Attention -from ...utils import divide_if_divisible from ..dropout import Dropout_TP from ..linear import ColumnParallelLinear, ReplicatedLinear, RowParallelLinear diff --git a/dolomite_engine/hf_models/modeling_utils_TP/embedding.py b/dolomite_engine/hf_models/modeling_utils_TP/embedding.py index 38e9c568..169a3bd8 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/embedding.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/embedding.py @@ -5,9 +5,8 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Replicate, Shard -from ...utils import ProcessGroupManager +from ...utils import ProcessGroupManager, divide_if_divisible from ..modeling_utils import ParameterizedEmbedding -from ..utils import divide_if_divisible from .dtensor_module import DTensorModule from .TP import dtensor_to_tensor, get_module_placements, tensor_to_dtensor diff --git a/dolomite_engine/hf_models/modeling_utils_TP/linear.py b/dolomite_engine/hf_models/modeling_utils_TP/linear.py index 9445965c..4d3aaed9 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/linear.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/linear.py @@ -5,9 +5,8 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Partial, Replicate, Shard -from ...utils import ProcessGroupManager +from ...utils import ProcessGroupManager, divide_if_divisible from ..modeling_utils import ParameterizedLinear -from ..utils import divide_if_divisible from .dtensor_module import DTensorModule from .TP import ( all_gather_from_sequence_parallel_region, diff --git a/dolomite_engine/hf_models/modeling_utils_TP/position_embedding/alibi.py b/dolomite_engine/hf_models/modeling_utils_TP/position_embedding/alibi.py index 7ef43ae3..e8ba4977 100644 --- 
a/dolomite_engine/hf_models/modeling_utils_TP/position_embedding/alibi.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/position_embedding/alibi.py @@ -2,7 +2,7 @@ import torch -from ....utils import ProcessGroupManager +from ....utils import ProcessGroupManager, divide_if_divisible from ...modeling_utils import Alibi @@ -20,8 +20,7 @@ def reset_parameters(self) -> None: slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) tp_rank = ProcessGroupManager.get_tensor_parallel_rank() - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - num_heads_tp = self.num_heads // tp_world_size + num_heads_tp = divide_if_divisible(self.num_heads, ProcessGroupManager.get_tensor_parallel_world_size(), "") slopes = slopes[tp_rank * num_heads_tp : (tp_rank + 1) * num_heads_tp] self.register_buffer("slopes", slopes, persistent=False) diff --git a/dolomite_engine/hf_models/models/gpt_dolomite_TP/main.py b/dolomite_engine/hf_models/models/gpt_dolomite_TP/main.py index 6565bbfc..8fd25a8a 100644 --- a/dolomite_engine/hf_models/models/gpt_dolomite_TP/main.py +++ b/dolomite_engine/hf_models/models/gpt_dolomite_TP/main.py @@ -1,8 +1,8 @@ from ...mixins import CausalLMModelMixin_TP from .base import GPTDolomiteModel_TP, GPTDolomitePreTrainedModel_TP -from .weights import get_gpt_dolomite_tensor_parallel_state_dict +from .weights import get_gpt_dolomite_model_parallel_state_dict class GPTDolomiteForCausalLM_TP(GPTDolomitePreTrainedModel_TP, CausalLMModelMixin_TP): base_model_class = GPTDolomiteModel_TP - tensor_parallel_state_dict_function = get_gpt_dolomite_tensor_parallel_state_dict + model_parallel_state_dict_function = get_gpt_dolomite_model_parallel_state_dict diff --git a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py index a2f416a3..2ecb3e22 100644 --- a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py +++ b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py @@ -1,2 +1,2 @@ -from .shard import get_gpt_dolomite_tensor_parallel_state_dict +from .shard import get_gpt_dolomite_model_parallel_state_dict from .unshard import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts diff --git a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/shard.py b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/shard.py index 75b1b49c..e93c1727 100644 --- a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/shard.py +++ b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/shard.py @@ -1,38 +1,54 @@ import torch -from .....utils import ProcessGroupManager, SafeTensorsWeightsManager +from .....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible from ....enums import AttentionHeadType, PositionEmbeddingType from ....modeling_utils import is_glu from ....modeling_utils_TP import get_tensor_parallel_vocab_info, tensor_parallel_split_safetensor_slice -from ....utils import divide_if_divisible from ...gpt_dolomite import GPTDolomiteConfig -def get_gpt_dolomite_tensor_parallel_state_dict( +def get_gpt_dolomite_model_parallel_state_dict( config: GPTDolomiteConfig, safetensors_weights_manager: SafeTensorsWeightsManager, tensor_parallel_word_embeddings: bool, + num_pipeline_stages: int, + pipeline_stage_id: int, ) -> dict: - # word embeddings - state_dict = _get_embeddings_or_lm_head( - safetensors_weights_manager, - prefix="transformer.wte.", - vocab_size=config.vocab_size, - 
tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + is_first_pipeline_stage = pipeline_stage_id == 0 + is_last_pipeline_stage = pipeline_stage_id == num_pipeline_stages - 1 + + layers_per_stage = divide_if_divisible( + config.n_layer, num_pipeline_stages, "layers should be divisible by num_pipeline_stages" ) - # positional embeddings - if PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute: + layer_start_id = layers_per_stage * pipeline_stage_id + layer_end_id = layers_per_stage * (pipeline_stage_id + 1) + + state_dict = {} + + if is_first_pipeline_stage: + # word embeddings state_dict.update( _get_embeddings_or_lm_head( safetensors_weights_manager, - prefix="transformer.wpe.", - vocab_size=config.n_positions, - tensor_parallel_word_embeddings=False, + prefix="transformer.wte.", + vocab_size=config.vocab_size, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, ) ) - for layer_idx in range(config.n_layer): + # positional embeddings + if PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute: + state_dict.update( + _get_embeddings_or_lm_head( + safetensors_weights_manager, + prefix="transformer.wpe.", + vocab_size=config.n_positions, + tensor_parallel_word_embeddings=False, + ) + ) + + for layer_idx in range(layer_start_id, layer_end_id): prefix = f"transformer.h.{layer_idx}." state_dict.update(_get_layernorm(safetensors_weights_manager, prefix=prefix + "ln_1.")) @@ -59,17 +75,18 @@ def get_gpt_dolomite_tensor_parallel_state_dict( ) ) - state_dict.update(_get_layernorm(safetensors_weights_manager, prefix="transformer.ln_f.")) - - if not config.tie_word_embeddings: - state_dict.update( - _get_embeddings_or_lm_head( - safetensors_weights_manager=safetensors_weights_manager, - prefix="lm_head.", - vocab_size=config.vocab_size, - tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + if is_last_pipeline_stage: + state_dict.update(_get_layernorm(safetensors_weights_manager, prefix="transformer.ln_f.")) + + if not config.tie_word_embeddings: + state_dict.update( + _get_embeddings_or_lm_head( + safetensors_weights_manager=safetensors_weights_manager, + prefix="lm_head.", + vocab_size=config.vocab_size, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + ) ) - ) return state_dict diff --git a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py index 7094e31c..275e52e6 100644 --- a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py +++ b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py @@ -107,13 +107,13 @@ def unshard_gpt_dolomite_tensor_parallel_state_dicts( def fix_gpt_dolomite_unsharded_state_dict( - config: GPTDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str = "" + config: GPTDolomiteConfig, state_dict: dict, tensor_parallel_world_size: int, prefix: str = "" ) -> dict: state_dict[prefix + "transformer.wte.weight"] = state_dict[prefix + "transformer.wte.weight"][ : config.vocab_size, : ] state_dict = _fix_attention(config, state_dict, prefix) - state_dict = _fix_mlp(config, state_dict, tensor_parallel_size, prefix) + state_dict = _fix_mlp(config, state_dict, tensor_parallel_world_size, prefix) return state_dict @@ -268,11 +268,11 @@ def _fix_attention(config: GPTDolomiteConfig, state_dict: dict, prefix: str) -> return state_dict -def _fix_mlp(config: GPTDolomiteConfig, state_dict: dict, tensor_parallel_size: 
int, prefix: str) -> dict: +def _fix_mlp(config: GPTDolomiteConfig, state_dict: dict, tensor_parallel_world_size: int, prefix: str) -> dict: if is_glu(config.activation_function): for layer_idx in range(config.n_layer): key = f"{prefix}transformer.h.{layer_idx}.mlp.c_fc.weight" - weight = state_dict[key].chunk(tensor_parallel_size) + weight = state_dict[key].chunk(tensor_parallel_world_size) weight = [w.chunk(2) for w in weight] w0 = torch.cat([w[0] for w in weight]) w1 = torch.cat([w[1] for w in weight]) @@ -280,7 +280,7 @@ def _fix_mlp(config: GPTDolomiteConfig, state_dict: dict, tensor_parallel_size: if config.add_bias: key = f"{prefix}transformer.h.{layer_idx}.mlp.c_fc.bias" - weight = state_dict[key].chunk(tensor_parallel_size) + weight = state_dict[key].chunk(tensor_parallel_world_size) weight = [w.chunk(2) for w in weight] w0 = torch.cat([w[0] for w in weight]) w1 = torch.cat([w[1] for w in weight]) diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/main.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/main.py index 9c73fd10..39edecce 100644 --- a/dolomite_engine/hf_models/models/moe_dolomite_TP/main.py +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/main.py @@ -5,4 +5,4 @@ class MoEDolomiteForCausalLM_TP(MoEDolomitePreTrainedModel_TP, CausalLMMoEModelMixin_TP): base_model_class = MoEDolomiteModel_TP - tensor_parallel_state_dict_function = get_moe_dolomite_tensor_parallel_state_dict + model_parallel_state_dict_function = get_moe_dolomite_tensor_parallel_state_dict diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py index eac7f349..af80888b 100644 --- a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py @@ -7,11 +7,10 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import Partial, Replicate, Shard -from .....utils import ProcessGroupManager, is_kernel_hyperdrive_available +from .....utils import ProcessGroupManager, divide_if_divisible, is_kernel_hyperdrive_available from ....enums import InitMethod from ....modeling_utils import ParameterizedTransposedLinear, get_activation_function, is_glu from ....modeling_utils_TP import Dropout_TP, DTensorModule, dtensor_to_tensor, tensor_to_dtensor -from ....utils import divide_if_divisible from ...moe_dolomite import MoEDolomiteConfig from ...moe_dolomite.moe import ScatterMoE from ...moe_dolomite.moe.scatter import ParameterizedScatteredExperts diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py index 543d0731..ae1f7292 100644 --- a/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py @@ -114,13 +114,13 @@ def unshard_moe_dolomite_tensor_parallel_state_dicts( def fix_moe_dolomite_unsharded_state_dict( - config: MoEDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str = "" + config: MoEDolomiteConfig, state_dict: dict, tensor_parallel_world_size: int, prefix: str = "" ) -> dict: state_dict[prefix + "transformer.wte.weight"] = state_dict[prefix + "transformer.wte.weight"][ : config.vocab_size, : ] state_dict = _fix_attention(config, state_dict, prefix) - state_dict = _fix_moe(config, state_dict, tensor_parallel_size, prefix) + state_dict = _fix_moe(config, state_dict, 
tensor_parallel_world_size, prefix) return state_dict @@ -162,14 +162,14 @@ def _get_moe( return output -def _fix_moe(config: MoEDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str) -> dict: +def _fix_moe(config: MoEDolomiteConfig, state_dict: dict, tensor_parallel_world_size: int, prefix: str) -> dict: assert not config.add_bias if is_glu(config.activation_function): for layer_idx in range(config.n_layer): key = f"{prefix}transformer.h.{layer_idx}.moe.c_fc.weight" weight = state_dict[key] - weight = weight.chunk(tensor_parallel_size, dim=0) + weight = weight.chunk(tensor_parallel_world_size, dim=0) weight = [w.chunk(2, dim=0) for w in weight] w0 = torch.cat([w[0] for w in weight]) w1 = torch.cat([w[1] for w in weight]) diff --git a/dolomite_engine/hf_models/models/rnn_dolomite/base.py b/dolomite_engine/hf_models/models/rnn_dolomite/base.py index 362a92ac..53309971 100644 --- a/dolomite_engine/hf_models/models/rnn_dolomite/base.py +++ b/dolomite_engine/hf_models/models/rnn_dolomite/base.py @@ -3,11 +3,10 @@ from transformers import Cache from transformers.modeling_outputs import BaseModelOutputWithPast -from ....utils import is_fla_available +from ....utils import divide_if_divisible, is_fla_available from ...enums import AttentionHeadType, PositionEmbeddingType from ...mixins import BaseModelMixin, PreTrainedModelMixin from ...modeling_utils import ParameterizedEmbedding, get_normalization_function -from ...utils import divide_if_divisible from .config import RNNDolomiteConfig from .layer import RNNDolomiteBlock diff --git a/dolomite_engine/hf_models/register_hf.py b/dolomite_engine/hf_models/register_hf.py index 9d930fb0..dfc18397 100644 --- a/dolomite_engine/hf_models/register_hf.py +++ b/dolomite_engine/hf_models/register_hf.py @@ -45,14 +45,14 @@ def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForS return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES -_TENSOR_PARALLEL_CLASS_MAPPING = { +_MODEL_PARALLEL_CLASS_MAPPING = { GPTDolomiteConfig.model_type: GPTDolomiteForCausalLM_TP, MoEDolomiteConfig.model_type: MoEDolomiteForCausalLM_TP, } -def get_tensor_parallel_class(model_type: str) -> AutoModelForCausalLM: - if model_type in _TENSOR_PARALLEL_CLASS_MAPPING: - return _TENSOR_PARALLEL_CLASS_MAPPING[model_type] +def get_model_parallel_class(model_type: str) -> AutoModelForCausalLM: + if model_type in _MODEL_PARALLEL_CLASS_MAPPING: + return _MODEL_PARALLEL_CLASS_MAPPING[model_type] - raise ValueError(f"tensor parallel is not supported with `model_type` ({model_type})") + raise ValueError(f"model parallelism is not supported with `model_type` ({model_type})") diff --git a/dolomite_engine/hf_models/unshard.py b/dolomite_engine/hf_models/unshard.py index f85ad426..262852a2 100644 --- a/dolomite_engine/hf_models/unshard.py +++ b/dolomite_engine/hf_models/unshard.py @@ -41,11 +41,11 @@ def unshard_tensor_parallel_state_dicts( def fix_unsharded_state_dict( - config: CommonConfig, state_dict: dict, tensor_parallel_size: int, prefix: str = "" + config: CommonConfig, state_dict: dict, tensor_parallel_world_size: int, prefix: str = "" ) -> dict: if config.model_type in _FIX_UNSHARDED_STATE_DICT_FUNCTIONS: return _FIX_UNSHARDED_STATE_DICT_FUNCTIONS[config.model_type]( - config=config, state_dict=state_dict, tensor_parallel_size=tensor_parallel_size, prefix=prefix + config=config, state_dict=state_dict, tensor_parallel_world_size=tensor_parallel_world_size, prefix=prefix ) raise ValueError(f"unsupported `model_type` 
({config.model_type})") diff --git a/dolomite_engine/hf_models/utils.py b/dolomite_engine/hf_models/utils.py index c3e3b293..c7c2bf04 100644 --- a/dolomite_engine/hf_models/utils.py +++ b/dolomite_engine/hf_models/utils.py @@ -1,22 +1,6 @@ import torch -def divide_if_divisible(dividend: int, divisor: int, msg: str) -> int: - """divide if divisible else raise an error - - Args: - dividend (int): dividend - divisor (int): divisor - msg (str): error message - - Returns: - int: result - """ - - assert dividend % divisor == 0, msg - return dividend // divisor - - def convert_padding_free_lists_to_tensors( input_ids: list[list[int]] | None = None, inputs_embeds: list[list[float]] | None = None, diff --git a/dolomite_engine/model_wrapper/__init__.py b/dolomite_engine/model_wrapper/__init__.py index a762ec7d..f0404f21 100644 --- a/dolomite_engine/model_wrapper/__init__.py +++ b/dolomite_engine/model_wrapper/__init__.py @@ -1,8 +1,7 @@ -import logging - from ..arguments import DistillationArgs, InferenceArgs, TrainingArgs, UnshardingArgs +from ..containers import ModelContainer from ..enums import Mode, TuningMethod -from ..utils import log_rank_0, run_rank_n +from ..utils import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank from .base import ModelWrapper from .distillation import ModelWrapperForDistillation from .finetuning import ModelWrapperForFinetuning @@ -19,8 +18,17 @@ } -def get_model(args: TrainingArgs | InferenceArgs | UnshardingArgs | DistillationArgs, mode: Mode) -> ModelWrapper: +def get_model_container( + args: TrainingArgs | InferenceArgs | UnshardingArgs | DistillationArgs, mode: Mode +) -> ModelContainer: tuning_method = args.tuning_args.tuning_method + num_pipeline_stages = args.distributed_args.num_pipeline_stages + + if tuning_method != TuningMethod.pretraining: + assert num_pipeline_stages == 1, "pipeline parallelism is only supported with pretraining" + + if tuning_method not in _MODEL_CLASS_MAPPING: + raise ValueError(f"unexpected tuning_method ({tuning_method})") kwargs = { "mode": mode, @@ -34,6 +42,7 @@ def get_model(args: TrainingArgs | InferenceArgs | UnshardingArgs | Distillation "use_padding_free_transformer": args.model_args.use_padding_free_transformer, "tensor_parallel_word_embeddings": args.distributed_args.tensor_parallel_word_embeddings, "sequence_parallel": args.distributed_args.sequence_parallel, + "num_pipeline_stages": num_pipeline_stages, "neft_alpha": args.research_args.neft_alpha, "trust_remote_code": args.model_args.trust_remote_code, "tokenizer_name": args.tokenizer_args.tokenizer_name, @@ -54,20 +63,9 @@ def get_model(args: TrainingArgs | InferenceArgs | UnshardingArgs | Distillation kwargs["kl_divergence_method"] = args.teacher_args.kl_divergence_method kwargs["kl_divergence_weight"] = args.teacher_args.kl_divergence_weight - if tuning_method in _MODEL_CLASS_MAPPING: - return _MODEL_CLASS_MAPPING[tuning_method](**kwargs) - - raise ValueError(f"unexpected tuning_method ({tuning_method})") - - -@run_rank_n -def log_model(model: ModelWrapper) -> None: - """print model - - Args: - model (ModelWrapper): model to print - """ + model_list = [] + for pipeline_stage_id in get_pipeline_stage_ids_on_current_rank(num_pipeline_stages): + kwargs["pipeline_stage_id"] = pipeline_stage_id + model_list.append(_MODEL_CLASS_MAPPING[tuning_method](**kwargs)) - log_rank_0(logging.INFO, "------------------------ model ------------------------") - log_rank_0(logging.INFO, model) - log_rank_0(logging.INFO, "-------------------- end of model 
---------------------") + return ModelContainer(model_list) diff --git a/dolomite_engine/model_wrapper/base.py b/dolomite_engine/model_wrapper/base.py index 1e639dd4..cde4efb1 100644 --- a/dolomite_engine/model_wrapper/base.py +++ b/dolomite_engine/model_wrapper/base.py @@ -5,7 +5,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer from ..enums import AttentionImplementation, Mode, MoEImplementation -from ..hf_models import get_tensor_parallel_class, is_custom_model +from ..hf_models import get_model_parallel_class, is_custom_model from ..utils import ProcessGroupManager, SafeTensorsWeightsManager, log_rank_0, string_to_torch_dtype @@ -25,6 +25,8 @@ def __init__( use_padding_free_transformer: bool, tensor_parallel_word_embeddings: bool, sequence_parallel: bool, + num_pipeline_stages: int, + pipeline_stage_id: int, neft_alpha: float | None = None, trust_remote_code: bool = False, tokenizer_name: str | None = None, @@ -43,6 +45,8 @@ def __init__( use_padding_free_transformer (bool): whether to use padding free transformer tensor_parallel_word_embeddings (bool): whether to use tensor parallel word embeddings sequence_parallel (bool): whether to use sequence parallel + num_pipeline_stages (int): number of stages for the pipeline + pipeline_stage_id (int): current pipeline stage id neft_alpha (float | None, optional): alpha parameter for NEFTune. Defaults to None. trust_remote_code (bool, optional): whether the model has remote code in the HF bucket. Defaults to False. tokenizer_name (str | None, optional): path of the model on disk or HF hub. Defaults to None. If None, the `model_name` is used for tokenizer. @@ -65,13 +69,19 @@ def __init__( self.tokenizer_name = self.model_name if tokenizer_name is None else tokenizer_name self.trust_remote_code = trust_remote_code - self.tp_rank = ProcessGroupManager.get_tensor_parallel_rank() + self.num_pipeline_stages = num_pipeline_stages + self.pipeline_stage_id = pipeline_stage_id + self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1 + + use_model_parallelism = ProcessGroupManager.is_tensor_parallel_enabled() or self.is_pipeline_parallel_enabled self._setup_config() - if ProcessGroupManager.is_tensor_parallel_enabled(): + log_rank_0(logging.INFO, f"num parameters in the model = {self.calculate_num_parameters():,}") + + if use_model_parallelism: self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - self.model_class = get_tensor_parallel_class(self.config.model_type) + self.model_class = get_model_parallel_class(self.config.model_type) if self.use_padding_free_transformer: assert is_custom_model( @@ -178,6 +188,9 @@ def _setup_model(self) -> None: model_kwargs["sequence_parallel"] = True if self.trust_remote_code: model_kwargs["trust_remote_code"] = True + if self.is_pipeline_parallel_enabled: + model_kwargs["num_pipeline_stages"] = self.num_pipeline_stages + model_kwargs["pipeline_stage_id"] = self.pipeline_stage_id if self.model_name is None: if self.tokenizer.bos_token_id is not None: @@ -191,7 +204,7 @@ def _setup_model(self) -> None: def _get_model(**extras): if self.model_name is None: - if ProcessGroupManager.is_tensor_parallel_enabled(): + if self.is_pipeline_parallel_enabled or ProcessGroupManager.is_tensor_parallel_enabled(): # avoid inferring the model class so use _from_config instead of from_config model = self.model_class._from_config(**model_kwargs, **extras) else: @@ -223,12 +236,6 @@ def _get_model(**extras): self.model = _get_model(torch_dtype=torch_dtype) - 
num_parameters = 0 - for param in self.model.parameters(): - num_parameters += param.numel() - - log_rank_0(logging.INFO, f"num parameters in the model = {num_parameters:,}") - def _override_embedding_forward_with_neft_forward(self, neft_alpha: float) -> None: if not hasattr(self.model, "get_input_embeddings"): raise Exception( @@ -251,6 +258,19 @@ def _noisy_forward(x: torch.Tensor) -> torch.Tensor: # overrides the forward function of torch.nn.Embedding self.model.get_input_embeddings().forward = _noisy_forward + def calculate_num_parameters(self) -> int: + with torch.device("meta"): + if self.model_name is None: + model = self.model_class.from_config(config=self.config) + else: + model = self.model_class.from_pretrained(pretrained_model_name_or_path=self.model_name) + + num_parameters = 0 + for param in model.parameters(): + num_parameters += param.numel() + + return num_parameters + def has_teacher_model(self) -> bool: return False diff --git a/dolomite_engine/model_wrapper/distillation.py b/dolomite_engine/model_wrapper/distillation.py index 87182993..9a80dc0f 100644 --- a/dolomite_engine/model_wrapper/distillation.py +++ b/dolomite_engine/model_wrapper/distillation.py @@ -26,6 +26,8 @@ def __init__( sequence_parallel: bool, micro_batch_size: int, sequence_length: int, + num_pipeline_stages: int, + pipeline_stage_id: int, teacher_model_name: str | None, teacher_model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, teacher_model_dtype: torch.dtype, @@ -51,6 +53,8 @@ def __init__( use_padding_free_transformer (bool): whether to use padding free transformer tensor_parallel_word_embeddings (bool): whether to use tensor parallel word embeddings sequence_parallel (bool): whether to use sequence parallel + num_pipeline_stages (int): number of stages for the pipeline + pipeline_stage_id (int): current pipeline stage id micro_batch_size (int): micro batch size for pretraining sequence_length (int): sequence length for pretraining neft_alpha (float | None, optional): alpha parameter for NEFTune. Defaults to None. 
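A quick aside on the `calculate_num_parameters` method added to the model wrapper above: building the model under `torch.device("meta")` gives shapes and dtypes without materializing any weights, so the full parameter count can still be logged even when the current rank will only ever construct its own pipeline stages. A minimal, self-contained sketch of that pattern, using a Hugging Face model name purely as a placeholder:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

def count_parameters_on_meta(model_name: str) -> int:
    # meta tensors carry only shape and dtype, so even very large models are cheap to "build" here
    config = AutoConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)
    return sum(p.numel() for p in model.parameters())

# e.g. count_parameters_on_meta("gpt2") reports roughly 124M parameters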
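A related note on the stage-wise sharding in `get_gpt_dolomite_model_parallel_state_dict` earlier in this diff: layers are split into equal contiguous blocks per pipeline stage, the word (and learned positional) embeddings live only on the first stage, and the final layer norm plus `lm_head` live only on the last stage. A rough stand-alone sketch of that ownership rule, with illustrative names rather than the engine's API:

def layers_for_stage(n_layer: int, num_pipeline_stages: int, pipeline_stage_id: int) -> range:
    # same contract as divide_if_divisible: every stage must own an equal number of layers
    assert n_layer % num_pipeline_stages == 0, "n_layer must be divisible by num_pipeline_stages"
    layers_per_stage = n_layer // num_pipeline_stages
    return range(layers_per_stage * pipeline_stage_id, layers_per_stage * (pipeline_stage_id + 1))

def stage_ownership(n_layer: int, num_pipeline_stages: int, pipeline_stage_id: int) -> dict:
    return {
        "embeddings": pipeline_stage_id == 0,
        "transformer_layers": list(layers_for_stage(n_layer, num_pipeline_stages, pipeline_stage_id)),
        "final_layernorm_and_lm_head": pipeline_stage_id == num_pipeline_stages - 1,
    }

# e.g. 8 layers over 4 stages: stage 2 owns layers [4, 5] and neither the embeddings nor the head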
@@ -75,11 +79,14 @@ def __init__( dtype=dtype, efficient_initialization=efficient_initialization, attention_implementation=attention_implementation, + moe_implementation=moe_implementation, use_padding_free_transformer=use_padding_free_transformer, tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, sequence_parallel=sequence_parallel, micro_batch_size=micro_batch_size, sequence_length=sequence_length, + num_pipeline_stages=num_pipeline_stages, + pipeline_stage_id=pipeline_stage_id, neft_alpha=neft_alpha, trust_remote_code=trust_remote_code, tokenizer_name=tokenizer_name, diff --git a/dolomite_engine/model_wrapper/finetuning.py b/dolomite_engine/model_wrapper/finetuning.py index f350b1e7..3d483c7f 100644 --- a/dolomite_engine/model_wrapper/finetuning.py +++ b/dolomite_engine/model_wrapper/finetuning.py @@ -28,13 +28,14 @@ def forward(self, batch: dict) -> dict: def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict: device = torch.cuda.current_device() + is_tp_first_rank = ProcessGroupManager.is_tensor_parallel_first_rank() tp_source_rank = ProcessGroupManager.get_tensor_parallel_first_rank() tp_group = ProcessGroupManager.get_tensor_parallel_group() if self.use_padding_free_transformer: keys = ["input_ids", "position_ids", "labels", "cu_seqlens", "max_seqlen"] - if self.tp_rank == 0: + if is_tp_first_rank: metadata = torch.tensor([batch["cu_seqlens"].numel(), batch["input_ids"].numel()], device=device) else: metadata = torch.empty(2, dtype=torch.long, device=device) @@ -42,7 +43,7 @@ def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict: torch.distributed.broadcast(metadata, src=tp_source_rank, group=tp_group) cu_seqlens_num_elements, input_ids_num_elements = metadata - if self.tp_rank != 0: + if not is_tp_first_rank: batch = { "input_ids": torch.empty(input_ids_num_elements, dtype=torch.long, device=device), "position_ids": torch.empty(input_ids_num_elements, dtype=torch.long, device=device), @@ -53,10 +54,10 @@ def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict: else: keys = ["input_ids", "attention_mask", "labels"] - batch_shape = batch["input_ids"].shape if self.tp_rank == 0 else None + batch_shape = batch["input_ids"].shape if is_tp_first_rank else None batch_shape = Communication.broadcast_object(batch_shape, src=tp_source_rank, group=tp_group) - if self.tp_rank != 0: + if not is_tp_first_rank: batch = {key: torch.empty(batch_shape, dtype=torch.long, device=device) for key in keys} for key in keys: diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py index d055eed0..fb997b75 100644 --- a/dolomite_engine/model_wrapper/pretraining.py +++ b/dolomite_engine/model_wrapper/pretraining.py @@ -29,6 +29,8 @@ def __init__( sequence_parallel: bool, micro_batch_size: int, sequence_length: int, + num_pipeline_stages: int, + pipeline_stage_id: int, neft_alpha: float | None = None, trust_remote_code: bool = False, tokenizer_name: str | None = None, @@ -51,6 +53,8 @@ def __init__( sequence_parallel (bool): whether to use sequence parallel micro_batch_size (int): micro batch size for pretraining sequence_length (int): sequence length for pretraining + num_pipeline_stages (int): number of stages for the pipeline + pipeline_stage_id (int): current pipeline stage id neft_alpha (float | None, optional): alpha parameter for NEFTune. Defaults to None. trust_remote_code (bool, optional): whether the model has remote code in the HF bucket. Defaults to False. 
tokenizer_name (str | None, optional): path of the model on disk or HF hub. Defaults to None. If None, the `model_name` is used for tokenizer. @@ -76,12 +80,18 @@ def __init__( use_padding_free_transformer=use_padding_free_transformer, tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, sequence_parallel=sequence_parallel, + num_pipeline_stages=num_pipeline_stages, + pipeline_stage_id=pipeline_stage_id, neft_alpha=neft_alpha, trust_remote_code=trust_remote_code, tokenizer_name=tokenizer_name, additional_special_tokens=additional_special_tokens, ) + if self.is_pipeline_parallel_enabled: + assert not self.reset_attention_mask, "reset_attention_mask is not supported with pipeline parallelism" + assert not self.reset_position_ids, "reset_position_ids is not supported with pipeline parallelism" + def forward(self, batch: dict) -> dict: """forward function for a batch @@ -97,11 +107,25 @@ def forward(self, batch: dict) -> dict: # instead of (sequence_length), so we need to trim the input_ids before forward pass. # transformers does forward pass before however and then trims the tokens. + if isinstance(batch, torch.Tensor): + batch = {"text": batch} + input_ids, labels = self._prepare_inputs_ids_and_labels_for_forward(batch) batch = self._prepare_model_inputs(input_ids) model_outputs = self.model(**batch, return_dict=True) - logits: torch.Tensor = model_outputs.logits + + # without pipeline parallel, we compute the loss outside + if not self.is_pipeline_parallel_enabled: + model_outputs = self.get_loss(model_outputs, labels) + + return model_outputs + + def get_loss(self, model_outputs, labels: torch.Tensor) -> torch.Tensor: + if isinstance(model_outputs, torch.Tensor): + logits = model_outputs + else: + logits: torch.Tensor = model_outputs.logits if self.upcast_logits_for_loss: logits = logits.float() @@ -133,6 +157,20 @@ def forward(self, batch: dict) -> dict: return output + def broadcast_tensor_parallel_input(self, tokens: dict, shape: tuple[int]) -> torch.Tensor: + if ProcessGroupManager.is_tensor_parallel_first_rank(): + tokens = tokens.to(torch.cuda.current_device()) + else: + tokens = torch.empty(shape, dtype=torch.long, device=torch.cuda.current_device()) + + torch.distributed.broadcast( + tokens, + src=ProcessGroupManager.get_tensor_parallel_first_rank(), + group=ProcessGroupManager.get_tensor_parallel_group(), + ) + + return tokens + def _prepare_model_inputs(self, input_ids: torch.Tensor) -> dict: batch = {} @@ -151,7 +189,8 @@ def _prepare_model_inputs(self, input_ids: torch.Tensor) -> dict: cu_seqlens = cu_seqlens.to(torch.int32) seqlen = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = seqlen.max() + # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers + max_seqlen = seqlen.max().item() if self.reset_position_ids: position_ids = torch.cat( @@ -165,8 +204,7 @@ def _prepare_model_inputs(self, input_ids: torch.Tensor) -> dict: position_ids = self.position_ids batch["cu_seqlens"] = cu_seqlens - # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers - batch["max_seqlen"] = max_seqlen.item() + batch["max_seqlen"] = max_seqlen batch["position_ids"] = position_ids batch["input_ids"] = input_ids @@ -176,28 +214,29 @@ def _prepare_model_inputs(self, input_ids: torch.Tensor) -> dict: return batch - def _prepare_inputs_ids_and_labels_for_forward(self, batch: dict) -> torch.Tensor: - if ProcessGroupManager.is_tensor_parallel_enabled(): - tp_source_rank = 
ProcessGroupManager.get_tensor_parallel_first_rank() - tp_group = ProcessGroupManager.get_tensor_parallel_group() + def _prepare_inputs_ids_and_labels_for_forward(self, batch: dict) -> tuple[torch.Tensor]: + if self.is_pipeline_parallel_enabled: + # when using pipeline parallel, we broadcast the input outside the model function + tokens = batch["text"] + tokens = tokens.to(torch.cuda.current_device()) - if self.tp_rank == 0: - tokens: torch.Tensor = batch["text"] - tokens = tokens.to(torch.cuda.current_device()) + if self.pipeline_stage_id == 0: + input_ids = tokens[:, :-1] else: - tokens = torch.empty( - (self.micro_batch_size, self.sequence_length + 1), - dtype=torch.long, - device=torch.cuda.current_device(), - ) + input_ids = tokens - torch.distributed.broadcast(tokens, src=tp_source_rank, group=tp_group) + labels = None else: - tokens: torch.Tensor = batch["text"] - tokens = tokens.to(torch.cuda.current_device()) + if ProcessGroupManager.is_tensor_parallel_enabled(): + tokens = self.broadcast_tensor_parallel_input( + None if batch is None else batch["text"], (self.micro_batch_size, self.sequence_length + 1) + ) + else: + tokens = batch["text"] + tokens = tokens.to(torch.cuda.current_device()) - input_ids = tokens[:, :-1] - labels = tokens[:, 1:] + input_ids = tokens[:, :-1] + labels = tokens[:, 1:] return input_ids, labels @@ -219,12 +258,7 @@ def _setup_model(self) -> None: ), persistent=False, ) - self.register_buffer( - "max_seqlen", - torch.tensor(self.sequence_length, device=torch.cuda.current_device()), - persistent=False, - ) - + self.max_seqlen = self.sequence_length if self.reset_position_ids: assert self.reset_attention_mask, "reset_attention_mask should be specified with reset_position_ids" else: diff --git a/dolomite_engine/optimization/__init__.py b/dolomite_engine/optimization/__init__.py index ea88f771..e8f88e83 100644 --- a/dolomite_engine/optimization/__init__.py +++ b/dolomite_engine/optimization/__init__.py @@ -1,2 +1,2 @@ -from .optimizer import get_optimizer, log_optimizer -from .scheduler import get_scheduler +from .optimizer import get_optimizer_container +from .scheduler import get_scheduler_container diff --git a/dolomite_engine/optimization/optimizer.py b/dolomite_engine/optimization/optimizer.py index e3ba1bfe..21143a7e 100644 --- a/dolomite_engine/optimization/optimizer.py +++ b/dolomite_engine/optimization/optimizer.py @@ -1,6 +1,3 @@ -import logging - -from torch.optim import Optimizer from torch.optim.adadelta import Adadelta as TorchAdadelta from torch.optim.adagrad import Adagrad as TorchAdagrad from torch.optim.adam import Adam as TorchAdam @@ -14,10 +11,10 @@ from torch.optim.rprop import Rprop as TorchRprop from torch.optim.sgd import SGD as TorchSGD +from ..containers import ModelContainer, OptimizerContainer from ..enums import ParamsGroupMethod -from ..model_wrapper import ModelWrapper -from ..utils import is_apex_available, log_rank_0, run_rank_n -from .params_group import get_param_groups +from ..utils import is_apex_available +from .params_group import get_param_groups_list if is_apex_available(): @@ -54,22 +51,22 @@ } -def get_optimizer( +def get_optimizer_container( optimizer_class_name: str, optimizer_class_args: dict, - model: ModelWrapper, + model_container: ModelContainer, params_group_method: ParamsGroupMethod, -) -> Optimizer: - """setup optimizer for the model +) -> OptimizerContainer: + """setup list of optimizers for the model Args: optimizer_class_name (str): optimizer class name optimizer_class_args (dict): args for the optimizer 
class - model (ModelWrapper): model + model_container (ModelContainer): model container params_group_method (ParamsGroupMethod): the params grouping to use Returns: - Optimizer: an optimizer + OptimizerContainer: optimizer container """ if optimizer_class_name not in _OPTIMIZER_CLASSES: @@ -79,20 +76,7 @@ def get_optimizer( if optimizer_class is None: raise ImportError("relevant package for the optimizer is not installed") - params_group = get_param_groups(model, optimizer_class_args, params_group_method) - optimizer = optimizer_class(params_group, **optimizer_class_args) - - return optimizer - - -@run_rank_n -def log_optimizer(optimizer: Optimizer) -> None: - """print optimizer - - Args: - optimizer (Optimizer): optimizer to print - """ + params_groups_list = get_param_groups_list(model_container, optimizer_class_args, params_group_method) + optimizer_list = [optimizer_class(params_group, **optimizer_class_args) for params_group in params_groups_list] - log_rank_0(logging.INFO, "------------------------ optimizer ------------------------") - log_rank_0(logging.INFO, optimizer) - log_rank_0(logging.INFO, "-------------------- end of optimizer ---------------------") + return OptimizerContainer(optimizer_list) diff --git a/dolomite_engine/optimization/params_group.py b/dolomite_engine/optimization/params_group.py index fa32b3af..e9bf38ab 100644 --- a/dolomite_engine/optimization/params_group.py +++ b/dolomite_engine/optimization/params_group.py @@ -2,6 +2,7 @@ import torch.nn as nn +from ..containers import ModelContainer from ..enums import ParamsGroupMethod from ..hf_models import ( GPTDolomiteForCausalLM, @@ -51,12 +52,17 @@ def get_normal_group_with_names(model: ModelWrapper, optimizer_class_args: dict) list(model.parameters()) ), "params in groups don't sum up to total parameters" - trainable_parameters_or_param_groups = [ - {"params": list(normal_params.values())}, - {"params": list(no_weight_decay_params.values()), "weight_decay": 0}, - ] + trainable_parameters_or_param_groups = [] + names = {} - names = {"normal": list(normal_params.keys()), "no_weight_decay": list(no_weight_decay_params.keys())} + if len(normal_params) > 0: + trainable_parameters_or_param_groups.append({"params": list(normal_params.values())}) + names["normal"] = list(normal_params.keys()) + if len(no_weight_decay_params) > 0: + trainable_parameters_or_param_groups.append( + {"params": list(no_weight_decay_params.values()), "weight_decay": 0} + ) + names["no_weight_decay"] = list(no_weight_decay_params.keys()) return trainable_parameters_or_param_groups, names @@ -112,17 +118,22 @@ def get_mup_group_with_names(model: ModelWrapper, optimizer_class_args: dict) -> list(model.parameters()) ), "params in groups don't sum up to total parameters" - trainable_parameters_or_param_groups = [ - {"params": list(normal_params.values())}, - {"params": list(no_weight_decay_params.values()), "weight_decay": 0}, - {"params": list(mup_params.values()), "lr": optimizer_class_args["lr"] / model.config.m_width}, - ] - - names = { - "normal": list(normal_params.keys()), - "no_weight_decay": list(no_weight_decay_params.keys()), - "mup": list(mup_params.keys()), - } + trainable_parameters_or_param_groups = [] + names = {} + + if len(normal_params) > 0: + trainable_parameters_or_param_groups.append({"params": list(normal_params.values())}) + names["normal"] = list(normal_params.keys()) + if len(no_weight_decay_params) > 0: + trainable_parameters_or_param_groups.append( + {"params": list(no_weight_decay_params.values()), "weight_decay": 0} 
+ ) + names["no_weight_decay"] = list(no_weight_decay_params.keys()) + if len(mup_params) > 0: + trainable_parameters_or_param_groups.append( + {"params": list(mup_params.values()), "lr": optimizer_class_args["lr"] / model.config.m_width} + ) + names["mup"] = list(mup_params.keys()) return trainable_parameters_or_param_groups, names @@ -133,10 +144,10 @@ def get_mup_group_with_names(model: ModelWrapper, optimizer_class_args: dict) -> } -def get_param_groups( - model: ModelWrapper, optimizer_class_args: dict, params_group_method: ParamsGroupMethod | None -) -> list[dict]: - if params_group_method in _PARAM_GROUPS: - return _PARAM_GROUPS[params_group_method](model, optimizer_class_args)[0] +def get_param_groups_list( + model_container: ModelContainer, optimizer_class_args: dict, params_group_method: ParamsGroupMethod | None +) -> list[list[dict]]: + if params_group_method not in _PARAM_GROUPS: + raise ValueError(f"unexpected `params_group_method` {params_group_method}") - raise ValueError(f"unexpected `params_group_method` {params_group_method}") + return [_PARAM_GROUPS[params_group_method](model, optimizer_class_args)[0] for model in model_container] diff --git a/dolomite_engine/optimization/scheduler.py b/dolomite_engine/optimization/scheduler.py index f6d31ff5..c3c5acef 100644 --- a/dolomite_engine/optimization/scheduler.py +++ b/dolomite_engine/optimization/scheduler.py @@ -3,6 +3,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR +from ..containers import LRSchedulerContainer, OptimizerContainer from ..enums import LRDecaySchedule @@ -190,8 +191,8 @@ def _lr_lambda(self, num_steps: int) -> float: } -def get_scheduler( - optimizer: Optimizer, +def get_scheduler_container( + optimizer_container: OptimizerContainer, num_warmup_steps: int, num_constant_steps: int, num_decay_steps: int, @@ -204,16 +205,18 @@ def get_scheduler( if lr_decay_style not in _LR_SCHEDULER_CLASSES: raise ValueError(f"invalid lr_decay_style ({lr_decay_style})") - lr_scheduler_class = _LR_SCHEDULER_CLASSES[lr_decay_style] - - lr_scheduler = lr_scheduler_class( - optimizer, - num_warmup_steps=num_warmup_steps, - num_constant_steps=num_constant_steps, - num_decay_steps=num_decay_steps, - num_training_steps=num_training_steps, - lr_decay_factor=lr_decay_factor, - **extra_lr_scheduler_args, - last_epoch=last_epoch, - ) - return lr_scheduler + lr_scheduler_list = [ + _LR_SCHEDULER_CLASSES[lr_decay_style]( + optimizer, + num_warmup_steps=num_warmup_steps, + num_constant_steps=num_constant_steps, + num_decay_steps=num_decay_steps, + num_training_steps=num_training_steps, + lr_decay_factor=lr_decay_factor, + **extra_lr_scheduler_args, + last_epoch=last_epoch, + ) + for optimizer in optimizer_container + ] + + return LRSchedulerContainer(lr_scheduler_list) diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py index 4a13117a..450b6b9b 100644 --- a/dolomite_engine/pretrain.py +++ b/dolomite_engine/pretrain.py @@ -5,20 +5,20 @@ import torch from torch.distributed._tensor.api import DTensor +from torch.distributed.pipelining.schedules import _PipelineSchedule from torch.distributed.tensor.parallel import loss_parallel -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader from transformers import set_seed from .arguments import TrainingArgs, get_args from .checkpointing import load_checkpoint_for_training, save_checkpoint from .communication import Communication +from .containers import LRSchedulerContainer, 
ModelContainer, OptimizerContainer, log_model_optimizer_container from .data import get_megatron_gpt_dataloaders, get_next_batch -from .distributed import wrap_model_for_distributed_training +from .distributed import wrap_model_container_for_distributed_training from .enums import FP8Backend, Mode, TuningMethod -from .model_wrapper import ModelWrapperForPretraining, get_model, log_model -from .optimization import get_optimizer, get_scheduler, log_optimizer +from .model_wrapper import get_model_container +from .optimization import get_optimizer_container, get_scheduler_container from .train_utils import all_reduce_metrics_tracker, get_model_tflops, get_torch_profiler, track_metrics, train_step from .utils import ( ExperimentsTracker, @@ -74,9 +74,10 @@ def track_val_metrics( def train( args: TrainingArgs, - model: ModelWrapperForPretraining, - optimizer: Optimizer, - lr_scheduler: LambdaLR, + model_container: ModelContainer, + pipeline_schedule: _PipelineSchedule, + optimizer_container: OptimizerContainer, + lr_scheduler_container: LRSchedulerContainer, train_dataloader: DataLoader, val_dataloaders: list[DataLoader], test_dataloaders: list[DataLoader], @@ -87,9 +88,10 @@ def train( Args: args (TrainingArgs): training args - model (ModelWrapperForPretraining): model - optimizer (Optimizer): optimizer - lr_scheduler (LRScheduler): learning rate scheduler + model_container (ModelContainer): container of models + pipeline_schedule (_PipelineSchedule): pipeline schedule + optimizer_container (OptimizerContainer): container of optimizers + lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers train_dataloader (DataLoader): training dataloader val_dataloaders (list[DataLoader]): validation dataloaders test_dataloaders (list[DataLoader]): test dataloaders @@ -111,17 +113,16 @@ def train( if val_weighted_split_paths is not None: group_names = [key for key in val_weighted_split_paths.keys()[0]] - model.train() + model_container.train() if eval_during_training: eval_steps = args.datasets[0].class_args.get("eval_steps") - evaluate(val_dataloaders, model, starting_iteration, experiments_tracker, eval_steps, group_names) + evaluate(val_dataloaders, model_container, starting_iteration, experiments_tracker, eval_steps, group_names) micro_batch_size = args.training_parameters.micro_batch_size sequence_length = args.datasets[0].class_args.get("sequence_length") - global_batch_size = ( - micro_batch_size * gradient_accumulation_steps * ProcessGroupManager.get_data_parallel_world_size() - ) + local_batch_size = micro_batch_size * gradient_accumulation_steps + global_batch_size = local_batch_size * ProcessGroupManager.get_data_parallel_world_size() tokens_per_batch = global_batch_size * sequence_length dp_world_size = ProcessGroupManager.get_data_parallel_world_size() @@ -130,7 +131,7 @@ def train( model_flops = ( get_model_tflops( model_class=args.model_args.model_class, - config=model.config, + config=model_container[0].config, batch_size=global_batch_size, sequence_length=sequence_length, gradient_checkpointing_method=args.distributed_args.gradient_checkpointing_method, @@ -166,15 +167,19 @@ def train( steps_since_start_time += 1 loss_step_dict = train_step( - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, + model_container=model_container, + pipeline_schedule=pipeline_schedule, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, train_dataloader=train_dataloader, gradient_accumulation_steps=gradient_accumulation_steps, 
gradient_clipping=gradient_clipping, forward_context=forward_context, backward_context=backward_context, sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, + is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1, + batch_size=local_batch_size, + sequence_length=sequence_length, ) metrics_tracker = metrics_tracker + loss_step_dict @@ -188,7 +193,7 @@ def train( time_elapsed = time.perf_counter() - start_time step_time = time_elapsed / steps_since_start_time - metrics_tracker["learning_rate"] = lr_scheduler.get_lr()[0] + metrics_tracker["learning_rate"] = lr_scheduler_container[0].get_lr()[0] if model_flops is not None: metrics_tracker["FLOPs"] = model_flops * steps_since_start_time / time_elapsed @@ -208,25 +213,27 @@ def train( metrics_tracker = MetricsTrackingDict({}) if eval_during_training and (global_step % eval_interval == 0 or global_step == num_training_steps): - evaluate(val_dataloaders, model, global_step, experiments_tracker, eval_steps, group_names) + evaluate(val_dataloaders, model_container, global_step, experiments_tracker, eval_steps, group_names) if global_step % save_interval == 0 or global_step == num_training_steps: save_checkpoint( - args, - model, - optimizer, - lr_scheduler, - None, - experiments_tracker, - global_step, - {"consumed_samples": global_step * micro_batch_size * gradient_accumulation_steps * dp_world_size}, + args=args, + model_container=model_container, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, + train_dataloader=None, + experiments_tracker=experiments_tracker, + iteration=global_step, + metadata={ + "consumed_samples": global_step * micro_batch_size * gradient_accumulation_steps * dp_world_size + }, ) start_time = time.perf_counter() steps_since_start_time = 0 if eval_during_training: - evaluate(test_dataloaders, model, global_step, experiments_tracker, eval_steps, group_names) + evaluate(test_dataloaders, model_container, global_step, experiments_tracker, eval_steps, group_names) if torch_profiler is not None: torch_profiler.__exit__() @@ -235,7 +242,7 @@ def train( @torch.no_grad() def evaluate( val_dataloaders: list[DataLoader], - model: ModelWrapperForPretraining, + model_container: ModelContainer, global_step: int, experiments_tracker: ExperimentsTracker, eval_steps: int, @@ -245,7 +252,7 @@ def evaluate( Args: val_dataloaders (list[DataLoader]): list of validation dataloaders - model (ModelWrapperForPretraining): model + model_container (ModelContainer): container of models global_step (int): global step during training experiments_tracker (ExperimentsTracker): metrics tracker eval_steps (int): number of steps to run eval for @@ -255,11 +262,14 @@ def evaluate( MetricsTrackingDict: metrics tracker """ + assert len(model_container) == 1 + model = model_container[0] + if ProcessGroupManager.is_tensor_parallel_enabled(): # other tensor parallel ranks need to be told if val dataloader is None or not is_val_dataloader_none = ( val_dataloaders is None or len(val_dataloaders) == 0 - if ProcessGroupManager.get_tensor_parallel_rank() == 0 + if ProcessGroupManager.is_tensor_parallel_first_rank() else None ) is_val_dataloader_none = Communication.broadcast_object( @@ -323,7 +333,8 @@ def main(mode: Mode = Mode.training) -> None: # initialize distributed with nccl for multi-node communications init_distributed( - tensor_parallel_size=args.distributed_args.tensor_parallel_size, + 
tensor_parallel_world_size=args.distributed_args.tensor_parallel_world_size, + pipeline_parallel_world_size=args.distributed_args.pipeline_parallel_world_size, data_parallel_size=args.distributed_args.data_parallel_size, data_parallel_replication_world_size=args.distributed_args.zero_topology.data_parallel_replication_world_size, data_parallel_sharding_world_size=args.distributed_args.zero_topology.data_parallel_sharding_world_size, @@ -331,20 +342,24 @@ def main(mode: Mode = Mode.training) -> None: timeout_minutes=args.distributed_args.timeout_minutes, use_async_tensor_parallel=args.distributed_args.use_async_tensor_parallel, ) + set_seed(args.random_args.seed) - model = get_model(args, mode) - model = wrap_model_for_distributed_training(args, model) + if mode == Mode.distillation: + assert args.distributed_args.num_pipeline_stages == 1, "pipeline parallel is not supported with distillation" + + model_container = get_model_container(args, mode) + model_container, pipeline_schedule = wrap_model_container_for_distributed_training(args, model_container) - optimizer = get_optimizer( + optimizer_container = get_optimizer_container( optimizer_class_name=args.optimizer_args.class_name, optimizer_class_args=args.optimizer_args.class_args, - model=model, + model_container=model_container, params_group_method=args.optimizer_args.params_group_method, ) - lr_scheduler = get_scheduler( - optimizer=optimizer, + lr_scheduler_container = get_scheduler_container( + optimizer_container=optimizer_container, num_warmup_steps=args.lr_scheduler_args.num_warmup_steps, num_constant_steps=args.lr_scheduler_args.num_constant_steps, num_decay_steps=args.lr_scheduler_args.num_decay_steps, @@ -354,15 +369,14 @@ def main(mode: Mode = Mode.training) -> None: extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args, ) - log_model(model) - log_optimizer(optimizer) + log_model_optimizer_container(model_container, optimizer_container) starting_iteration = 0 metadata = None experiments_tracker_state_dict = None if args.load_args is not None: starting_iteration, metadata, experiments_tracker_state_dict = load_checkpoint_for_training( - args, model, optimizer, lr_scheduler, None + args, model_container, optimizer_container, lr_scheduler_container, None ) # metadata field contains the dataloader state so we need to reset it here @@ -370,7 +384,7 @@ def main(mode: Mode = Mode.training) -> None: metadata["consumed_samples"] = 0 train_dataloader, val_dataloaders, test_dataloaders = get_megatron_gpt_dataloaders( - args, model.tokenizer, 0 if metadata is None else metadata["consumed_samples"] + args, model_container[0].tokenizer, 0 if metadata is None else metadata["consumed_samples"] ) experiments_tracker = ExperimentsTracker( @@ -385,9 +399,10 @@ def main(mode: Mode = Mode.training) -> None: # main training loop train( args, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, + model_container=model_container, + pipeline_schedule=pipeline_schedule, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, test_dataloaders=test_dataloaders, diff --git a/dolomite_engine/train_utils.py b/dolomite_engine/train_utils.py index 84440009..e78b2854 100644 --- a/dolomite_engine/train_utils.py +++ b/dolomite_engine/train_utils.py @@ -2,22 +2,184 @@ from contextlib import AbstractContextManager, nullcontext import torch -from torch.distributed import ReduceOp from torch.distributed._tensor.api import DTensor 
+from torch.distributed.pipelining.schedules import _PipelineSchedule from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM +from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer from .data import ResumableDataLoader, get_next_batch from .enums import GradientCheckpointingMethod from .hf_models import is_custom_model from .hf_models.modeling_utils import is_glu -from .model_wrapper import ModelWrapperForFinetuning -from .utils import ExperimentsTracker, MetricsTrackingDict, ProcessGroupManager, log_rank_0 +from .model_wrapper import ModelWrapper +from .utils import ExperimentsTracker, MetricsTrackingDict, ProcessGroupManager, log_metrics def train_step( - model: ModelWrapperForFinetuning, + model_container: ModelContainer, + pipeline_schedule: _PipelineSchedule, + optimizer_container: OptimizerContainer, + lr_scheduler_container: LRSchedulerContainer, + train_dataloader: ResumableDataLoader, + gradient_accumulation_steps: int, + gradient_clipping: float, + forward_context: AbstractContextManager, + backward_context: AbstractContextManager, + sync_every_gradient_accumulation_step: bool, + is_pipeline_parallel_enabled: bool, + batch_size: int, + sequence_length: int, +) -> MetricsTrackingDict: + """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary + + Args: + model_container (ModelContainer): container of models + pipeline_schedule (_PipelineSchedule): pipeline schedule + optimizer_container (OptimizerContainer): container of optimizers + lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers + train_dataloader (ResumableDataLoader): training dataloader + gradient_accumulation_steps (int): gradient accumulation steps + gradient_clipping (float): gradient clipping value + forward_context (AbstractContextManager): a context that is used for every model forward call + backward_context (AbstractContextManager): a context that is used for every model backward call + sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step + is_pipeline_parallel_enabled (bool): whether to use pipeline parallel + batch_size (int): batch size + sequence_length (int): sequence length + + Returns: + MetricsTrackingDict: metrics to track + """ + + assert len(model_container) == len(optimizer_container) + assert len(optimizer_container) == len(lr_scheduler_container) + + if is_pipeline_parallel_enabled: + metrics_tracker = _train_step_with_pipeline_parallel( + model_container=model_container, + pipeline_schedule=pipeline_schedule, + optimizer_container=optimizer_container, + lr_scheduler_container=lr_scheduler_container, + train_dataloader=train_dataloader, + gradient_accumulation_steps=gradient_accumulation_steps, + gradient_clipping=gradient_clipping, + batch_size=batch_size, + sequence_length=sequence_length, + ) + else: + assert len(model_container) == 1 + + metrics_tracker = _train_step_without_pipeline_parallel( + model=model_container[0], + optimizer=optimizer_container[0], + lr_scheduler=lr_scheduler_container[0], + train_dataloader=train_dataloader, + gradient_accumulation_steps=gradient_accumulation_steps, + gradient_clipping=gradient_clipping, + forward_context=forward_context, + backward_context=backward_context, + sync_every_gradient_accumulation_step=sync_every_gradient_accumulation_step, + ) + + return metrics_tracker + + +def 
_train_step_with_pipeline_parallel( + model_container: ModelContainer, + pipeline_schedule: _PipelineSchedule, + optimizer_container: OptimizerContainer, + lr_scheduler_container: LRSchedulerContainer, + train_dataloader: ResumableDataLoader, + gradient_accumulation_steps: int, + gradient_clipping: float, + batch_size: int, + sequence_length: int, +) -> MetricsTrackingDict: + """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary + + Args: + model_container (ModelContainer): container of models + pipeline_schedule (_PipelineSchedule): pipeline schedule + optimizer_container (OptimizerContainer): container of optimizers + lr_scheduler_container (LRSchedulerContainer): container of learning rate schedulers + train_dataloader (ResumableDataLoader): training dataloader + gradient_accumulation_steps (int): gradient accumulation steps + gradient_clipping (float): gradient clipping value + batch_size (int): batch size + sequence_length (int): sequence length + + Returns: + MetricsTrackingDict: metrics to track + """ + + fsdp_algorithm = 2 if hasattr(model_container[0], "set_requires_gradient_sync") else 1 + grad_norm = [] + + optimizer_container.zero_grad() + + batch = get_next_batch(train_dataloader) + + if ProcessGroupManager.is_tensor_parallel_first_rank(): + batch = batch["text"] + + batch = model_container[0].broadcast_tensor_parallel_input(batch, (batch_size, sequence_length + 1)) + + is_first_pipeline_rank = ProcessGroupManager.get_pipeline_parallel_rank() == 0 + is_last_pipeline_rank = ( + ProcessGroupManager.get_pipeline_parallel_rank() == ProcessGroupManager.get_pipeline_parallel_world_size() - 1 + ) + + if is_first_pipeline_rank: + pipeline_schedule.step(batch) + elif is_last_pipeline_rank: + losses = [] + labels = batch[:, 1:] + pipeline_schedule.step(target=labels, losses=losses) + else: + pipeline_schedule.step() + + if gradient_clipping is not None: + for model in model_container: + if fsdp_algorithm == 1: + grad_norm.append(model.clip_grad_norm_(gradient_clipping)) + else: + grad_norm.append(torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)) + + optimizer_container.step() + lr_scheduler_container.step() + + metrics_tracker = MetricsTrackingDict({}) + + with torch.inference_mode(): + grad_norm = sum(grad_norm) + if not isinstance(grad_norm, torch.Tensor): + grad_norm = torch.tensor(grad_norm, device=torch.cuda.current_device()) + elif isinstance(grad_norm, DTensor): + grad_norm = grad_norm.to_local() + + torch.distributed.all_reduce(grad_norm, group=ProcessGroupManager.get_pipeline_parallel_group()) + + if is_last_pipeline_rank: + losses = sum(losses) + + metrics_tracker = metrics_tracker + {"loss": losses, "grad_norm": grad_norm} + metrics_tracker = metrics_tracker / gradient_accumulation_steps + + metrics_tracker["grad_norm"] = grad_norm + + for key in metrics_tracker: + if isinstance(metrics_tracker[key], DTensor): + metrics_tracker[key] = metrics_tracker[key].to_local() + + metrics_tracker = all_reduce_metrics_tracker(metrics_tracker) + + return metrics_tracker + + +def _train_step_without_pipeline_parallel( + model: ModelWrapper, optimizer: Optimizer, lr_scheduler: LambdaLR, train_dataloader: ResumableDataLoader, @@ -30,7 +192,7 @@ def train_step( """runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary Args: - model (ModelWrapperForFinetuning): model + model (ModelWrapper): model optimizer (Optimizer): optimizer lr_scheduler (LamdaLR): learning rate scheduler 
train_dataloader (ResumableDataLoader): training dataloader @@ -63,13 +225,13 @@ def train_step( with forward_context(): loss_micro_step_dict = model(batch) - with torch.inference_mode(): - metrics_tracker = metrics_tracker + loss_micro_step_dict - # compute gradients with backward_context(): loss_micro_step_dict["loss"].backward() + with torch.inference_mode(): + metrics_tracker = metrics_tracker + loss_micro_step_dict + if fsdp_algorithm == 2: model.set_requires_gradient_sync(True) @@ -77,13 +239,13 @@ def train_step( with forward_context(): loss_micro_step_dict = model(batch) - with torch.inference_mode(): - metrics_tracker = metrics_tracker + loss_micro_step_dict - # compute gradients with backward_context(): loss_micro_step_dict["loss"].backward() + with torch.inference_mode(): + metrics_tracker = metrics_tracker + loss_micro_step_dict + if gradient_clipping is not None: if fsdp_algorithm == 1: grad_norm = model.clip_grad_norm_(gradient_clipping) @@ -145,7 +307,7 @@ def track_metrics( else: message += f", {context}-{key} = {metrics_tracker[key]:.4f}" - log_rank_0(logging.INFO, message) + log_metrics(logging.INFO, message) def get_torch_profiler(torch_profiler_trace_path: str) -> torch.profiler.profile: diff --git a/dolomite_engine/utils/__init__.py b/dolomite_engine/utils/__init__.py index 2d36c9c5..5152d2d5 100644 --- a/dolomite_engine/utils/__init__.py +++ b/dolomite_engine/utils/__init__.py @@ -4,8 +4,9 @@ import torch.distributed from .hf_hub import download_repo -from .logger import log_rank_0, print_rank_0, print_ranks_all, set_logger +from .logger import log_metrics, log_rank_0, print_rank_0, print_ranks_all, set_logger from .loss_dict import MetricsTrackingDict +from .miscellaneous import divide_if_divisible from .mixed_precision import normalize_dtype_string, string_to_torch_dtype, torch_dtype_to_string from .packages import ( is_apex_available, @@ -19,7 +20,7 @@ is_triton_available, log_environment, ) -from .parallel import ProcessGroupManager, run_rank_n +from .parallel import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank, run_rank_n from .pydantic import BaseArgs from .safetensors import SafeTensorsWeightsManager from .tracking import ExperimentsTracker, ProgressBar @@ -28,7 +29,8 @@ def init_distributed( - tensor_parallel_size: int, + tensor_parallel_world_size: int, + pipeline_parallel_world_size: int, data_parallel_size: int, data_parallel_replication_world_size: int, data_parallel_sharding_world_size: int, @@ -39,7 +41,8 @@ def init_distributed( """intialize distributed Args: - tensor_parallel_size (int): tensor parallel size + tensor_parallel_world_size (int): tensor parallel size + pipeline_parallel_world_size (int): pipeline parallel size data_parallel_size (int): data parallel size data_parallel_replication_world_size (int): data parallel replication world size data_parallel_sharding_world_size (int): data parallel sharding world size @@ -49,7 +52,8 @@ def init_distributed( """ process_group_manager = ProcessGroupManager( - tensor_parallel_size=tensor_parallel_size, + tensor_parallel_world_size=tensor_parallel_world_size, + pipeline_parallel_world_size=pipeline_parallel_world_size, data_parallel_size=data_parallel_size, data_parallel_replication_world_size=data_parallel_replication_world_size, data_parallel_sharding_world_size=data_parallel_sharding_world_size, diff --git a/dolomite_engine/utils/logger.py b/dolomite_engine/utils/logger.py index bb3057ce..e8e84b10 100644 --- a/dolomite_engine/utils/logger.py +++ b/dolomite_engine/utils/logger.py @@ 
-1,7 +1,7 @@ import logging from warnings import warn -from .parallel import ProcessGroupManager, run_rank_n +from .parallel import ProcessGroupManager, is_tracking_rank, run_rank_n _LOGGER: logging.Logger = None @@ -22,7 +22,7 @@ def set_logger(level: int = logging.INFO, colored_log: bool = False) -> None: logging.basicConfig(level=level, handlers=[stream], format="%(asctime)s - [%(levelname)-8s] ▶ %(message)s") global _LOGGER - _LOGGER = run_rank_n(logging.getLogger)() + _LOGGER = logging.getLogger() def get_logger() -> logging.Logger: @@ -32,10 +32,21 @@ def get_logger() -> logging.Logger: @run_rank_n def log_rank_0(level: int, msg: str) -> None: logger = get_logger() - if logger is not None: + + if logger is None: + set_logger() + log_rank_0(logging.WARN, "logger is not initialized yet, initializing now") + else: logger.log(level=level, msg=msg, stacklevel=3) +def log_metrics(level: int, msg: str) -> None: + if not is_tracking_rank(): + return + + get_logger().log(level=level, msg=msg, stacklevel=3) + + @run_rank_n def print_rank_0(*args, **kwargs) -> None: """print on a single process""" diff --git a/dolomite_engine/utils/miscellaneous.py b/dolomite_engine/utils/miscellaneous.py new file mode 100644 index 00000000..12568c02 --- /dev/null +++ b/dolomite_engine/utils/miscellaneous.py @@ -0,0 +1,14 @@ +def divide_if_divisible(dividend: int, divisor: int, msg: str) -> int: + """divide if divisible else raise an error + + Args: + dividend (int): dividend + divisor (int): divisor + msg (str): error message + + Returns: + int: result + """ + + assert dividend % divisor == 0, msg + return dividend // divisor diff --git a/dolomite_engine/utils/parallel.py b/dolomite_engine/utils/parallel.py index f37fa96a..472034b6 100644 --- a/dolomite_engine/utils/parallel.py +++ b/dolomite_engine/utils/parallel.py @@ -9,6 +9,8 @@ from torch.distributed._symmetric_memory import enable_symm_mem_for_group from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from .miscellaneous import divide_if_divisible + # general _MESH: DeviceMesh | None = None @@ -23,6 +25,12 @@ _TENSOR_PARALLEL_WORLD_SIZE: int | None = None _TENSOR_PARALLEL_FIRST_RANK: int | None = None +# pipeline parallel +_PIPELINE_PARALLEL_MESH: DeviceMesh | None = None +_PIPELINE_PARALLEL_GROUP: ProcessGroup | None = None +_PIPELINE_PARALLEL_RANK: int | None = None +_PIPELINE_PARALLEL_WORLD_SIZE: int | None = None + # data parallel _DATA_PARALLEL_MESH: DeviceMesh | None = None _DATA_PARALLEL_GROUP: ProcessGroup | None = None @@ -33,7 +41,8 @@ class ProcessGroupManager: def __init__( self, - tensor_parallel_size: int = 1, + tensor_parallel_world_size: int = 1, + pipeline_parallel_world_size: int = 1, data_parallel_size: int | None = None, data_parallel_replication_world_size: int | None = None, data_parallel_sharding_world_size: int | None = None, @@ -53,9 +62,9 @@ def __init__( total_gpus = int(os.getenv("WORLD_SIZE", 1)) if data_parallel_size is None: - data_parallel_size = total_gpus // tensor_parallel_size + data_parallel_size = total_gpus // (tensor_parallel_world_size * pipeline_parallel_world_size) - assert tensor_parallel_size * data_parallel_size == total_gpus + assert tensor_parallel_world_size * pipeline_parallel_world_size * data_parallel_size == total_gpus if zero_stage == 0: assert data_parallel_sharding_world_size is None or data_parallel_sharding_world_size == 1 @@ -77,8 +86,13 @@ def __init__( _MESH = init_device_mesh( "cuda", - (data_parallel_replication_world_size, data_parallel_sharding_world_size, 
tensor_parallel_size), - mesh_dim_names=("ddp", "fsdp", "tp"), + ( + pipeline_parallel_world_size, + data_parallel_replication_world_size, + data_parallel_sharding_world_size, + tensor_parallel_world_size, + ), + mesh_dim_names=("pp", "ddp", "fsdp", "tp"), ) local_rank = int(os.getenv("LOCAL_RANK", 0)) @@ -93,7 +107,7 @@ def is_initialized() -> bool: return torch.distributed.is_initialized() @staticmethod - def get_mesh() -> int: + def get_mesh() -> DeviceMesh: global _MESH return _MESH @@ -127,8 +141,7 @@ def get_tensor_parallel_mesh() -> DeviceMesh: global _TENSOR_PARALLEL_MESH if _TENSOR_PARALLEL_MESH is None: - global _MESH - _TENSOR_PARALLEL_MESH = _MESH["tp"] + _TENSOR_PARALLEL_MESH = ProcessGroupManager.get_mesh()["tp"] return _TENSOR_PARALLEL_MESH @staticmethod @@ -205,14 +218,74 @@ def set_dummy_tensor_parallel_first_rank(rank: int): def is_tensor_parallel_enabled() -> bool: return ProcessGroupManager.get_tensor_parallel_world_size() > 1 + @staticmethod + def is_tensor_parallel_first_rank() -> bool: + return ProcessGroupManager.get_tensor_parallel_rank() == 0 + + # pipeline parallel + @staticmethod + def get_pipeline_parallel_mesh() -> DeviceMesh: + global _PIPELINE_PARALLEL_MESH + + if _PIPELINE_PARALLEL_MESH is None: + _PIPELINE_PARALLEL_MESH = ProcessGroupManager.get_mesh()["pp"] + return _PIPELINE_PARALLEL_MESH + + @staticmethod + def get_pipeline_parallel_group() -> ProcessGroup: + global _PIPELINE_PARALLEL_GROUP + + if _PIPELINE_PARALLEL_GROUP is None: + _PIPELINE_PARALLEL_GROUP = ProcessGroupManager.get_pipeline_parallel_mesh().get_group() + return _PIPELINE_PARALLEL_GROUP + + @staticmethod + def get_pipeline_parallel_rank() -> int: + global _PIPELINE_PARALLEL_RANK + + if _PIPELINE_PARALLEL_RANK is None: + _PIPELINE_PARALLEL_RANK = ProcessGroupManager.get_pipeline_parallel_mesh().get_local_rank() + return _PIPELINE_PARALLEL_RANK + + @contextmanager + @staticmethod + def set_dummy_pipeline_parallel_rank(rank: int): + global _PIPELINE_PARALLEL_RANK + + original_rank = _PIPELINE_PARALLEL_RANK + _PIPELINE_PARALLEL_RANK = rank + + yield + + _PIPELINE_PARALLEL_RANK = original_rank + + @staticmethod + def get_pipeline_parallel_world_size() -> int: + global _PIPELINE_PARALLEL_WORLD_SIZE + + if _PIPELINE_PARALLEL_WORLD_SIZE is None: + _PIPELINE_PARALLEL_WORLD_SIZE = ProcessGroupManager.get_pipeline_parallel_mesh().size() + return _PIPELINE_PARALLEL_WORLD_SIZE + + @contextmanager + @staticmethod + def set_dummy_pipeline_parallel_world_size(world_size: int): + global _PIPELINE_PARALLEL_WORLD_SIZE + + original_world_size = _PIPELINE_PARALLEL_WORLD_SIZE + _PIPELINE_PARALLEL_WORLD_SIZE = world_size + + yield + + _PIPELINE_PARALLEL_WORLD_SIZE = original_world_size + # data parallel @staticmethod def get_data_parallel_mesh() -> DeviceMesh: global _DATA_PARALLEL_MESH if _DATA_PARALLEL_MESH is None: - global _MESH - _DATA_PARALLEL_MESH = _MESH["ddp", "fsdp"] + _DATA_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp"] return _DATA_PARALLEL_MESH @staticmethod @@ -308,3 +381,25 @@ def func_rank_other(*args, **kwargs): wrapped_func = func_rank_other return wrapped_func + + +def is_tracking_rank() -> bool: + return ( + ProcessGroupManager.get_data_parallel_rank() == 0 + and ProcessGroupManager.is_tensor_parallel_first_rank() + and ProcessGroupManager.get_pipeline_parallel_rank() + == ProcessGroupManager.get_pipeline_parallel_world_size() - 1 + ) + + +def get_pipeline_stage_ids_on_current_rank(num_pipeline_stages: int) -> int: + pp_rank = ProcessGroupManager.get_pipeline_parallel_rank() 
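# --- Editor's aside (illustrative, not part of the patch) --------------------
# Worked example of the topology set up above and of the stage placement this
# helper computes. The concrete numbers are hypothetical: with WORLD_SIZE = 16,
# tensor_parallel_world_size = 2, pipeline_parallel_world_size = 2 and no
# explicit data_parallel_size, the manager derives
# data_parallel_size = 16 // (2 * 2) = 4 and builds the 4-D
# ("pp", "ddp", "fsdp", "tp") device mesh whose dimensions multiply to 16.
# Stage ids are then handed out round-robin across pipeline ranks: with
# num_pipeline_stages = 4 and pp_world_size = 2, each rank owns 4 // 2 = 2
# stages and
#     pp_rank 0 -> stage ids (0, 2)
#     pp_rank 1 -> stage ids (1, 3)
# i.e. exactly tuple(pp_rank + i * pp_world_size for i in range(2)), as
# returned below.
# -----------------------------------------------------------------------------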
+ pp_world_size = ProcessGroupManager.get_pipeline_parallel_world_size() + + num_pipeline_stages_per_rank = divide_if_divisible( + num_pipeline_stages, + pp_world_size, + "num_pipeline_stages should be divisible by pipeline_parallel_world_size", + ) + + return tuple(pp_rank + i * pp_world_size for i in range(num_pipeline_stages_per_rank)) diff --git a/dolomite_engine/utils/tracking.py b/dolomite_engine/utils/tracking.py index 5fba003a..6444ce96 100644 --- a/dolomite_engine/utils/tracking.py +++ b/dolomite_engine/utils/tracking.py @@ -2,7 +2,7 @@ from ..enums import ExperimentsTrackerName from .packages import is_aim_available, is_wandb_available -from .parallel import run_rank_n +from .parallel import is_tracking_rank from .pydantic import BaseArgs @@ -17,10 +17,13 @@ class ProgressBar: """progress bar for training or validation""" def __init__(self, start: int, end: int, desc: str | None = None) -> None: - self.progress_bar: tqdm = run_rank_n(tqdm)(total=end, desc=desc) + self.is_tracking_rank = is_tracking_rank() + if not self.is_tracking_rank: + return + + self.progress_bar = tqdm(total=end, desc=desc) self.update(start) - @run_rank_n def update(self, n: int = 1) -> None: """updates progress bar @@ -28,12 +31,17 @@ def update(self, n: int = 1) -> None: n (int, optional): Number of steps to update the progress bar with. Defaults to 1. """ + if not self.is_tracking_rank: + return + self.progress_bar.update(n=n) - @run_rank_n def track(self, **loss_kwargs) -> None: """track specific metrics in progress bar""" + if not self.is_tracking_rank: + return + # for key in loss_kwargs: # loss_kwargs[key] = "{0:.5f}".format(loss_kwargs[key]) self.progress_bar.set_postfix(**loss_kwargs) @@ -49,27 +57,30 @@ def __init__( wandb_args: BaseArgs, checkpoint_metadata: dict, ) -> None: + self.is_tracking_rank = is_tracking_rank() + if not self.is_tracking_rank: + return + self.experiments_tracker_name = experiments_tracker_name self.tracking_enabled = experiments_tracker_name is not None if experiments_tracker_name == ExperimentsTrackerName.aim: kwargs = aim_args.to_dict() if checkpoint_metadata is None else checkpoint_metadata - self.run: AimRun = run_rank_n(AimRun)(**kwargs) + self.run = AimRun(**kwargs) elif experiments_tracker_name == ExperimentsTrackerName.wandb: kwargs = wandb_args.to_dict() if checkpoint_metadata is None else checkpoint_metadata resume = None if checkpoint_metadata is None else "auto" - run_rank_n(wandb.init)(resume=resume, **kwargs) + wandb.init(resume=resume, **kwargs) # this is for a custom step, we can't use the wandb step # since it doesn't allow time travel to the past - run_rank_n(wandb.define_metric)("iteration", hidden=True) - run_rank_n(wandb.define_metric)("train/*", step_metric="iteration", step_sync=True) - run_rank_n(wandb.define_metric)("val/*", step_metric="iteration", step_sync=True) + wandb.define_metric("iteration", hidden=True) + wandb.define_metric("train/*", step_metric="iteration", step_sync=True) + wandb.define_metric("val/*", step_metric="iteration", step_sync=True) elif experiments_tracker_name is not None: raise ValueError(f"unexpected experiments_tracker ({experiments_tracker_name})") - @run_rank_n def log_args(self, args: BaseArgs) -> None: """log args @@ -77,6 +88,9 @@ def log_args(self, args: BaseArgs) -> None: args (BaseArgs): pydantic object """ + if not self.is_tracking_rank: + return + if self.tracking_enabled: args: dict = args.to_dict() @@ -91,7 +105,6 @@ def log_args(self, args: BaseArgs) -> None: else: raise ValueError(f"unexpected 
experiments_tracker ({self.experiments_tracker_name})") - @run_rank_n def track(self, values: dict, step: int | None = None, context: str | None = None) -> None: """main tracking method @@ -101,6 +114,9 @@ def track(self, values: dict, step: int | None = None, context: str | None = Non context (str, optional): context for tracking. Defaults to None. """ + if not self.is_tracking_rank: + return + if self.tracking_enabled: if self.experiments_tracker_name == ExperimentsTrackerName.aim: if context is not None: @@ -119,8 +135,10 @@ def track(self, values: dict, step: int | None = None, context: str | None = Non else: raise ValueError(f"unexpected experiments_tracker ({self.experiments_tracker_name})") - @run_rank_n def finish(self) -> None: + if not self.is_tracking_rank: + return + if self.tracking_enabled: if self.experiments_tracker_name == ExperimentsTrackerName.aim: self.run.close() @@ -129,8 +147,10 @@ def finish(self) -> None: else: raise ValueError(f"unexpected experiments_tracker ({self.experiments_tracker_name})") - @run_rank_n def state_dict(self) -> dict: + if not self.is_tracking_rank: + return + state_dict = {} if self.tracking_enabled: if self.experiments_tracker_name == ExperimentsTrackerName.aim: diff --git a/tests/hf_models/multi_gpu/dcp/dcp.py b/tests/hf_models/multi_gpu/dcp/dcp.py index a732e076..06c9769b 100644 --- a/tests/hf_models/multi_gpu/dcp/dcp.py +++ b/tests/hf_models/multi_gpu/dcp/dcp.py @@ -6,10 +6,10 @@ from dolomite_engine.arguments import TrainingArgs, UnshardingArgs from dolomite_engine.checkpointing import load_checkpoint_for_inference, save_checkpoint -from dolomite_engine.distributed import wrap_model_for_distributed_training +from dolomite_engine.distributed import wrap_model_container_for_distributed_training from dolomite_engine.enums import Mode from dolomite_engine.hf_models import AttentionHeadType -from dolomite_engine.model_wrapper import get_model +from dolomite_engine.model_wrapper import get_model_container from dolomite_engine.utils import ProcessGroupManager, load_yaml @@ -42,14 +42,9 @@ # activation function train_config.model_args.pretrained_config["activation_function"] = args.activation_function -tp_world_size = train_config.distributed_args.tensor_parallel_size -dp_world_size = int(os.getenv("WORLD_SIZE")) // tp_world_size - ProcessGroupManager( - tensor_parallel_size=tp_world_size, - data_parallel_size=dp_world_size, - data_parallel_replication_world_size=args.data_parallel_replication_world_size, - data_parallel_sharding_world_size=args.data_parallel_sharding_world_size, + tensor_parallel_world_size=train_config.distributed_args.tensor_parallel_world_size, + pipeline_parallel_world_size=train_config.distributed_args.pipeline_parallel_world_size, ) global_rank = ProcessGroupManager.get_global_rank() @@ -58,9 +53,16 @@ with ( ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), ProcessGroupManager.set_dummy_tensor_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), ): - model = get_model(train_config, Mode.training) - model.save_pretrained(os.path.join(args.tmp_path, "single_rank")) + original_num_stages = train_config.distributed_args.num_pipeline_stages + train_config.distributed_args.num_pipeline_stages = 1 + + model_container = get_model_container(train_config, Mode.training) + model_container[0].save_pretrained(os.path.join(args.tmp_path, "single_rank")) + + train_config.distributed_args.num_pipeline_stages = original_num_stages 
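# --- Editor's aside (illustrative, not part of the patch) --------------------
# The ProgressBar / ExperimentsTracker changes above all key off
# is_tracking_rank(). Because the loss is only materialized on the last
# pipeline stage (see _train_step_with_pipeline_parallel, which passes losses=
# only on the last pipeline rank), the tracking rank is the first data- and
# tensor-parallel rank sitting on the *last* pipeline rank. A standalone
# restatement with explicit arguments in place of ProcessGroupManager; the
# function name and signature here are the editor's, not the repository's:

def _is_tracking_rank(dp_rank: int, tp_rank: int, pp_rank: int, pp_world_size: int) -> bool:
    return dp_rank == 0 and tp_rank == 0 and pp_rank == pp_world_size - 1


# with tensor_parallel_world_size = 2 and pipeline_parallel_world_size = 2, as
# in the dcp train.yml below, only ranks on the last pipeline stage log metrics
assert _is_tracking_rank(dp_rank=0, tp_rank=0, pp_rank=1, pp_world_size=2)
assert not _is_tracking_rank(dp_rank=0, tp_rank=0, pp_rank=0, pp_world_size=2)
# -----------------------------------------------------------------------------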
torch.distributed.barrier() @@ -75,14 +77,14 @@ unshard_config.load_args.iteration = iteration unshard_config.unsharded_path = os.path.join(args.tmp_path, "unsharded_path") -model_tp = get_model(train_config, Mode.training) -model_tp = wrap_model_for_distributed_training(train_config, model_tp) +parallel_model_container = get_model_container(train_config, Mode.training) +parallel_model_container, _ = wrap_model_container_for_distributed_training(train_config, parallel_model_container) save_checkpoint( train_config, - model=model_tp, - optimizer=None, - lr_scheduler=None, + model_container=parallel_model_container, + optimizer_container=None, + lr_scheduler_container=None, train_dataloader=None, experiments_tracker=None, iteration=iteration, @@ -94,7 +96,7 @@ _, _, consolidated_state_dict = load_checkpoint_for_inference(unshard_config, mode=Mode.unsharding, use_meta=False) if global_rank == 0: - original_state_dict = model.state_dict() + original_state_dict = model_container[0].state_dict() assert consolidated_state_dict.keys() == original_state_dict.keys() for key in original_state_dict: diff --git a/tests/hf_models/multi_gpu/dcp/train.yml b/tests/hf_models/multi_gpu/dcp/train.yml index 7445eb2a..bfa7c4d0 100644 --- a/tests/hf_models/multi_gpu/dcp/train.yml +++ b/tests/hf_models/multi_gpu/dcp/train.yml @@ -35,7 +35,7 @@ model_args: attention_head_type: mha scale_attn_weights: true vocab_size: 50304 - tie_word_embeddings: true + tie_word_embeddings: false bos_token_id: 0 eos_token_id: 0 pad_token_id: 0 @@ -77,4 +77,7 @@ mixed_precision_args: distributed_args: fsdp_algorithm: 2 stage: 3 - tensor_parallel_size: 2 + tensor_parallel_world_size: 2 + pipeline_parallel_world_size: 2 + num_pipeline_stages: 4 + pipeline_parallel_schedule: 1F1B diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 14d35e27..4e4e402a 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -5,14 +5,7 @@ import torch.distributed from transformers import set_seed -from dolomite_engine.hf_models import ( - AttentionHeadType, - GPTDolomiteConfig, - GPTDolomiteForCausalLM_TP, - MoEDolomiteConfig, - MoEDolomiteForCausalLM_TP, - get_tensor_parallel_class, -) +from dolomite_engine.hf_models import AttentionHeadType, GPTDolomiteConfig, MoEDolomiteConfig, get_model_parallel_class from dolomite_engine.utils import ProcessGroupManager, SafeTensorsWeightsManager, string_to_torch_dtype from ...test_common import TestCommons @@ -32,7 +25,7 @@ set_seed(42) -ProcessGroupManager(tensor_parallel_size=int(os.getenv("WORLD_SIZE"))) +ProcessGroupManager(tensor_parallel_world_size=int(os.getenv("WORLD_SIZE"))) torch_dtype = string_to_torch_dtype(args.torch_dtype) @@ -83,7 +76,7 @@ with torch.device("meta"): # try sharding vocab matrices if really struggling for memory - model_tp = get_tensor_parallel_class(args.model_type)._from_config( + model_tp = get_model_parallel_class(args.model_type)._from_config( config, tensor_parallel_word_embeddings=args.tensor_parallel_word_embeddings, attn_implementation=args.attention_implementation, diff --git a/tests/hf_models/multi_gpu/unsharding/unsharding.py b/tests/hf_models/multi_gpu/unsharding/unsharding.py index 1bd605be..1898b60a 100644 --- a/tests/hf_models/multi_gpu/unsharding/unsharding.py +++ b/tests/hf_models/multi_gpu/unsharding/unsharding.py @@ -10,7 +10,7 @@ GPTDolomiteConfig, 
MoEDolomiteConfig, fix_unsharded_state_dict, - get_tensor_parallel_class, + get_model_parallel_class, unshard_tensor_parallel_state_dicts, ) from dolomite_engine.utils import ProcessGroupManager @@ -27,9 +27,9 @@ args = parser.parse_args() -ProcessGroupManager(tensor_parallel_size=int(os.getenv("WORLD_SIZE"))) +ProcessGroupManager(tensor_parallel_world_size=int(os.getenv("WORLD_SIZE"))) -tp_rank = ProcessGroupManager.get_tensor_parallel_rank() +is_tp_first_rank = ProcessGroupManager.is_tensor_parallel_first_rank() num_key_value_heads = None if AttentionHeadType(args.attention_head_type) == AttentionHeadType.gqa: @@ -62,13 +62,13 @@ kwargs["moe_implementation"] = "scattermoe" -if tp_rank == 0: +if is_tp_first_rank: model = TestCommons.from_config(None, config) model.save_pretrained(args.tmp_path, safe_serialization=True) torch.distributed.barrier() -model_tp = get_tensor_parallel_class(args.model_type).from_pretrained( +model_tp = get_model_parallel_class(args.model_type).from_pretrained( args.tmp_path, tensor_parallel_word_embeddings=args.tensor_parallel_word_embeddings, **kwargs ) @@ -109,7 +109,7 @@ def run_check(fix: bool): torch.distributed.barrier() - if tp_rank == 0: + if is_tp_first_rank: original_state_dict = model.state_dict() assert tp_state_dict_unsharded.keys() == original_state_dict.keys() diff --git a/tests/training/params_group/params_group_test.py b/tests/training/params_group/params_group_test.py index b4f7dfe6..1c1e9d19 100644 --- a/tests/training/params_group/params_group_test.py +++ b/tests/training/params_group/params_group_test.py @@ -5,7 +5,7 @@ from parameterized import parameterized from dolomite_engine.enums import Mode -from dolomite_engine.model_wrapper import get_model +from dolomite_engine.model_wrapper import get_model_container from dolomite_engine.optimization.params_group import get_mup_group_with_names, get_normal_group_with_names from dolomite_engine.utils import ProcessGroupManager @@ -34,8 +34,10 @@ def test_mup_group(self, config_filename: str, expected_groups_filename: str) -> torch.device("meta"), ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), ProcessGroupManager.set_dummy_tensor_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), ): - model = get_model(args, Mode.training) + model_container = get_model_container(args, Mode.training) except RuntimeError: self.skipTest("skipping rnn_dolomite test since causal-conv1d is not installed") else: @@ -43,10 +45,12 @@ def test_mup_group(self, config_filename: str, expected_groups_filename: str) -> torch.device("meta"), ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), ProcessGroupManager.set_dummy_tensor_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), ): - model = get_model(args, Mode.training) + model_container = get_model_container(args, Mode.training) - _, names = get_mup_group_with_names(model, args.optimizer_args.class_args) + _, names = get_mup_group_with_names(model_container[0], args.optimizer_args.class_args) expected_group = json.load( open(os.path.join(os.path.dirname(__file__), "groups", expected_groups_filename), "r") @@ -74,8 +78,10 @@ def test_normal_group(self, config_filename: str, expected_groups_filename: str) torch.device("meta"), ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), ProcessGroupManager.set_dummy_tensor_parallel_rank(0), + 
ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), ): - model = get_model(args, Mode.training) + model_container = get_model_container(args, Mode.training) except RuntimeError: self.skipTest("skipping rnn_dolomite test since causal-conv1d is not installed") else: @@ -83,10 +89,12 @@ def test_normal_group(self, config_filename: str, expected_groups_filename: str) torch.device("meta"), ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), ProcessGroupManager.set_dummy_tensor_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), ): - model = get_model(args, Mode.training) + model_container = get_model_container(args, Mode.training) - _, names = get_normal_group_with_names(model, args.optimizer_args.class_args) + _, names = get_normal_group_with_names(model_container[0], args.optimizer_args.class_args) expected_group = json.load( open(os.path.join(os.path.dirname(__file__), "groups", expected_groups_filename), "r") diff --git a/tests/training/params_group/training_configs/gpt_dolomite_config.yml b/tests/training/params_group/training_configs/gpt_dolomite_config.yml index 1d9dea79..b793a49f 100644 --- a/tests/training/params_group/training_configs/gpt_dolomite_config.yml +++ b/tests/training/params_group/training_configs/gpt_dolomite_config.yml @@ -82,4 +82,4 @@ mixed_precision_args: distributed_args: fsdp_algorithm: 2 stage: 3 - tensor_parallel_size: 2 + tensor_parallel_world_size: 2 diff --git a/tests/training/params_group/training_configs/moe_dolomite_config.yml b/tests/training/params_group/training_configs/moe_dolomite_config.yml index a3d959ad..677b9df2 100644 --- a/tests/training/params_group/training_configs/moe_dolomite_config.yml +++ b/tests/training/params_group/training_configs/moe_dolomite_config.yml @@ -83,4 +83,4 @@ mixed_precision_args: distributed_args: fsdp_algorithm: 2 stage: 3 - tensor_parallel_size: 2 + tensor_parallel_world_size: 2 diff --git a/tests/training/params_group/training_configs/rnn_dolomite_config.yml b/tests/training/params_group/training_configs/rnn_dolomite_config.yml index bacb6ad6..29d8314b 100644 --- a/tests/training/params_group/training_configs/rnn_dolomite_config.yml +++ b/tests/training/params_group/training_configs/rnn_dolomite_config.yml @@ -89,4 +89,4 @@ mixed_precision_args: distributed_args: fsdp_algorithm: 2 stage: 3 - tensor_parallel_size: 2 + tensor_parallel_world_size: 2 diff --git a/tools/tensor_parallel_inference.py b/tools/tensor_parallel_inference.py index 27a57ab3..07e5af56 100644 --- a/tools/tensor_parallel_inference.py +++ b/tools/tensor_parallel_inference.py @@ -14,7 +14,7 @@ torch.cuda.set_device(local_rank) -ProcessGroupManager(tensor_parallel_size=8) +ProcessGroupManager(tensor_parallel_world_size=8) model_name = "save/"
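# --- Editor's aside (illustrative, not part of the patch) --------------------
# Control flow of _train_step_with_pipeline_parallel in train_utils.py above,
# reduced to its rank-dependent branches. DummySchedule is a hypothetical
# stand-in for the pipeline schedule object; only the branching mirrors the
# patch: the first pipeline rank feeds the input tokens, the last rank supplies
# the shifted labels and collects the per-microbatch losses, and every other
# rank just runs its stages. (The patch then clips per-model grad norms and
# all-reduces the grad norm over the pipeline-parallel group before logging.)

class DummySchedule:
    def step(self, batch=None, target=None, losses=None):
        # a real schedule would run forward/backward over micro-batches here;
        # this stub only records that a loss is produced on the last stage
        if losses is not None:
            losses.append(1.0)


def pipeline_step(schedule, batch, labels, pp_rank: int, pp_world_size: int) -> float:
    losses = []
    if pp_rank == 0:
        schedule.step(batch)                          # first stage: feed inputs
    elif pp_rank == pp_world_size - 1:
        schedule.step(target=labels, losses=losses)   # last stage: compute loss
    else:
        schedule.step()                               # middle stages: relay only
    return sum(losses)


# with 2 pipeline ranks, only the last rank ends up holding a loss value
assert pipeline_step(DummySchedule(), batch=None, labels=None, pp_rank=1, pp_world_size=2) == 1.0
assert pipeline_step(DummySchedule(), batch=None, labels=None, pp_rank=0, pp_world_size=2) == 0.0
# -----------------------------------------------------------------------------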