diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index ab949c42e43..e3cea98a7f4 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -67,6 +67,7 @@
     ProjectConfiguration,
     RNGType,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     apply_fp8_autowrap,
     check_os_kernel,
     clean_state_dict_for_safetensors,
@@ -188,6 +189,9 @@ class Accelerator:
         fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
             Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
             using *accelerate config*
+        torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
+            Tweak your torch tensor parallel related args using this argument. This argument is optional and can be
+            configured directly using *accelerate config*
         megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
             Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*
@@ -254,6 +258,7 @@ def __init__(
         dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        torch_tp_plugin: TorchTensorParallelPlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
         rng_types: list[str | RNGType] | None = None,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -418,6 +423,7 @@ def __init__(
             dynamo_plugin=dynamo_plugin,
             deepspeed_plugin=deepspeed_plugins,
             fsdp_plugin=fsdp_plugin,
+            torch_tp_plugin=torch_tp_plugin,
             megatron_lm_plugin=megatron_lm_plugin,
             _from_accelerator=True,
             **kwargs,
@@ -1461,6 +1467,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                     )
                     if self.ddp_handler is not None:
                         self.ddp_handler.register_comm_hook(model)
+            elif self.distributed_type == DistributedType.TP:
+                if not getattr(model, "supports_tp_plan", False):
+                    raise NotImplementedError("Provided model does not support tensor parallelism")
+                model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
             elif self.distributed_type == DistributedType.FSDP:
                 # We need to fix the optimizer *before* sharding the model
                 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -2117,6 +2127,7 @@ def prepare_data_loader(
             data_seed=self.dataloader_config.data_seed,
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
+            torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index bf3f35fb7e8..99c759456ba 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -528,6 +528,7 @@ def __init__(
         use_stateful_dataloader=False,
         _drop_last: bool = False,
         _non_blocking: bool = False,
+        torch_device_mesh=None,
         **kwargs,
     ):
         super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
@@ -713,6 +714,7 @@ def __init__(
         _drop_last: bool = False,
         _non_blocking: bool = False,
         slice_fn=None,
+        torch_device_mesh=None,
         **kwargs,
     ):
         shuffle = False
@@ -732,26 +734,66 @@ def __init__(
         self._drop_last = _drop_last
         self._non_blocking = _non_blocking
         self.skip_batches = skip_batches
+        self.torch_device_mesh = torch_device_mesh
 
         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
         self.iteration = 0
 
+        # if a device mesh is provided, extract each dimension (dp, fsdp, tp)
+        # device mesh may hold any number of dimensions, however,
+        # the code below only offers targeted support for dp, fsdp and tp
+
+        # device mesh will be used only if there is tp involved
+        # or any multi-dimensional parallelism involving tp
+        # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
+        # otherwise the default behaviour not using device mesh should be sufficient
+        # since multi dimensional parallelism devoid of tp would anyway need
+        # different batches for each process irrespective of dp or fsdp
+        self.submesh_tp = None
+        self.submesh_dp = None
+        self.submesh_fsdp = None
+        if self.torch_device_mesh is not None and "tp" in self.torch_device_mesh.mesh_dim_names:
+            self.submesh_tp = self.torch_device_mesh["tp"]
+            if "dp" in self.torch_device_mesh.mesh_dim_names:
+                self.submesh_dp = self.torch_device_mesh["dp"]
+            if "fsdp" in self.torch_device_mesh.mesh_dim_names:
+                self.submesh_fsdp = self.torch_device_mesh["fsdp"]
+        if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
+            raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")
+
     def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
         if self.state.process_index == 0:
+            # Procedure to support TP only is simpler
+            # since we want to dispatch the same batch of samples across all ranks
+            # this removes complexity of handling multiple tp rank groups when TP + DP
+            # combination is involved.
+
             try:
+                # for TP case avoid using split_batches
+                # since it would mean that the dataloader should be spilling out
+                # duplicates of batches.
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
+                    if self.submesh_tp:
+                        logger.warning("Use of split_batches for TP would need the dataloader to produce duplicate batches, "
+                                       "otherwise, use dispatch_batches=True instead.")
                     self._update_state_dict()
                     batch = next(iterator)
                 else:
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
                     batches = []
-                    for _ in range(self.state.num_processes):
+                    if self.submesh_tp:
+                        # when tp, extract single batch and then replicate
                         self._update_state_dict()
-                        batches.append(next(iterator))
+                        batch = next(iterator)
+                        batches = [batch] * self.state.num_processes
+                    else:
+                        for _ in range(self.state.num_processes):
+                            self._update_state_dict()
+                            batches.append(next(iterator))
                     try:
                         batch = concatenate(batches, dim=0)
                     except RuntimeError as e:
@@ -942,6 +984,7 @@ def prepare_data_loader(
     data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
+    torch_device_mesh: "torch.distributed.DeviceMesh" = None,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -1009,6 +1052,8 @@ def prepare_data_loader(
             "If set to true, the dataloader prepared by the Accelerator will be backed by "
             "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
+        torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
+            PyTorch device mesh.
 
 
     Returns:
@@ -1033,8 +1078,21 @@ def prepare_data_loader(
     state = PartialState()
     if num_processes is None:
         num_processes = state.num_processes
+    submesh_fsdp_size = 1
+    submesh_dp_size = 1
+    submesh_tp_size = 1
+    if torch_device_mesh is not None:
+        if "tp" in torch_device_mesh.mesh_dim_names:
+            submesh_tp_size = torch_device_mesh["tp"].size()
+        if "dp" in torch_device_mesh.mesh_dim_names:
+            submesh_dp_size = torch_device_mesh["dp"].size()
+        if "fsdp" in torch_device_mesh.mesh_dim_names:
+            submesh_fsdp_size = torch_device_mesh["fsdp"].size()
+        num_processes = submesh_fsdp_size * submesh_dp_size
     if process_index is None:
         process_index = state.process_index
+    if torch_device_mesh is not None:
+        process_index = process_index // submesh_tp_size
 
     # Sanity check
     if split_batches:
@@ -1144,6 +1202,7 @@ def prepare_data_loader(
             _non_blocking=non_blocking,
             slice_fn=slice_fn_for_dispatch,
             use_stateful_dataloader=use_stateful_dataloader,
+            torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 47d718704a6..8d226ac305b 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -850,6 +850,7 @@ def __init__(
         dynamo_plugin=None,
         deepspeed_plugin=None,
         fsdp_plugin=None,
+        torch_tp_plugin=None,
         megatron_lm_plugin=None,
         _from_accelerator: bool = False,
         **kwargs,
@@ -864,6 +865,7 @@ def __init__(
         if not self.initialized:
             self.deepspeed_plugins = None
             self.use_ipex = None
+            self.torch_tp_plugin = torch_tp_plugin
             mixed_precision = (
                 parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                 if mixed_precision is None
@@ -921,6 +923,8 @@ def __init__(
                     self.distributed_type = DistributedType.MEGATRON_LM
                     megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                     self.megatron_lm_plugin = megatron_lm_plugin
+                if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
+                    self.distributed_type = DistributedType.TP
             elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                 if is_ipex_available():
                     # check if user disables it explicitly
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 5b8917fcd48..558cc2f4769 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -57,6 +57,7 @@
     SageMakerDistributedType,
     TensorInformation,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     add_model_config_to_megatron_parser,
 )
 from .environment import (
diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py
index a6d7d262678..c1437a35b3c 100644
--- a/src/accelerate/utils/constants.py
+++ b/src/accelerate/utils/constants.py
@@ -46,6 +46,7 @@
 ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
 XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
 MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
+BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
 
 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
 
@@ -76,7 +77,7 @@
     "master_port",
 ]
 
-CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM"]
+CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
 TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
     "MULTI_NPU",
     "MULTI_MLU",
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 39e048a6039..0869de8eb0f 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -30,6 +30,7 @@
 import torch
 
 from .constants import (
+    BETA_TP_AVAILABLE_PYTORCH_VERSION,
     FSDP_AUTO_WRAP_POLICY,
     FSDP_BACKWARD_PREFETCH,
     FSDP_SHARDING_STRATEGY,
@@ -540,6 +541,7 @@ class DistributedType(str, enum.Enum):
     MULTI_XPU = "MULTI_XPU"
     DEEPSPEED = "DEEPSPEED"
     FSDP = "FSDP"
+    TP = "TP"
     XLA = "XLA"
     MEGATRON_LM = "MEGATRON_LM"
 
@@ -1810,6 +1812,32 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F
             self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)
 
 
+@dataclass
+class TorchTensorParallelPlugin:
+    """
+    This plugin is used to enable tensor parallelism using PyTorch >= 2.3.0.
+    """
+
+    tp_size: int = field(
+        default=1,
+        metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
+    )
+
+    # quoted so the annotation is never evaluated where `torch.distributed.DeviceMesh` is unavailable
+    torch_device_mesh: "torch.distributed.DeviceMesh" = field(default=None)
+
+    def __post_init__(self):
+        if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
+            raise ValueError(
+                f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
+            )
+        from torch.distributed.device_mesh import init_device_mesh
+
+        mesh_dim_name = "tp"
+        device = "cuda"  # support for other devices has to be investigated
+        self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
+
+
 @dataclass
 class MegatronLMPlugin:
     """