diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index ab949c42e43..796c22dfc63 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -67,6 +67,7 @@
     ProjectConfiguration,
     RNGType,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     apply_fp8_autowrap,
     check_os_kernel,
     clean_state_dict_for_safetensors,
@@ -107,7 +108,7 @@
     save_fsdp_optimizer,
     wait_for_everyone,
 )
-from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
+from .utils.constants import BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
@@ -188,6 +189,9 @@ class Accelerator:
         fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
             Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
             using *accelerate config*
+        torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
+            Tweak your torch tensor parallel related args using this argument. This argument is optional and can be
+            configured directly using *accelerate config*
         megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
             Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*
@@ -254,6 +258,7 @@ def __init__(
         dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        torch_tp_plugin: TorchTensorParallelPlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
         rng_types: list[str | RNGType] | None = None,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -344,6 +349,12 @@ def __init__(
             if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
                 raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
 
+        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
+            torch_tp_plugin, TorchTensorParallelPlugin
+        ):
+            if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
+                raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
+
         if fsdp_plugin is None:  # init from env variables
             fsdp_plugin = (
                 FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
@@ -353,6 +364,15 @@ def __init__(
                 raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
             os.environ["ACCELERATE_USE_FSDP"] = "true"  # use FSDP if plugin is provided
 
+        if torch_tp_plugin is None:
+            torch_tp_plugin = (
+                TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
+            )
+        else:
+            if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
+                raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
+            os.environ["ACCELERATE_USE_TP"] = "true"
+
         if megatron_lm_plugin is None:  # init from env variables
             megatron_lm_plugin = (
                 MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
@@ -418,6 +438,7 @@ def __init__(
             dynamo_plugin=dynamo_plugin,
             deepspeed_plugin=deepspeed_plugins,
             fsdp_plugin=fsdp_plugin,
+            torch_tp_plugin=torch_tp_plugin,
             megatron_lm_plugin=megatron_lm_plugin,
             _from_accelerator=True,
             **kwargs,
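Reviewer note (not part of the diff): a minimal sketch of how the new `torch_tp_plugin` argument is meant to be used from user code. The checkpoint name is illustrative; any transformers model that exposes a TP plan (`supports_tp_plan`) should work.

```python
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin
from transformers import AutoModelForCausalLM

# Assumes a 2-GPU launch, e.g. `accelerate launch --num_processes 2 ...`.
tp_plugin = TorchTensorParallelPlugin(tp_size=2)  # builds a ("tp",) device mesh in __post_init__
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = accelerator.prepare(model)  # prepare_model() shards supported layers across the TP mesh
```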
@@ -1461,6 +1482,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 )
                 if self.ddp_handler is not None:
                     self.ddp_handler.register_comm_hook(model)
+        elif self.distributed_type == DistributedType.TP:
+            if not model.supports_tp_plan:
+                raise NotImplementedError("Provided model does not support tensor parallelism")
+            model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -2117,6 +2142,7 @@ def prepare_data_loader(
             data_seed=self.dataloader_config.data_seed,
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
+            torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py
index 0862b9c9b09..c0385e5c871 100644
--- a/src/accelerate/commands/config/cluster.py
+++ b/src/accelerate/commands/config/cluster.py
@@ -376,6 +376,7 @@ def get_cluster_input():
         )
 
     fsdp_config = {}
+    tp_config = {}
     if distributed_type in [
         DistributedType.MULTI_GPU,
         DistributedType.MULTI_NPU,
@@ -475,7 +476,21 @@ def get_cluster_input():
                 default=False,
                 error_message="Please enter yes or no.",
             )
-
+        if not use_fsdp:
+            use_tp = _ask_field(
+                "Do you want to use TensorParallel? [yes/NO]: ",
+                _convert_yes_no_to_bool,
+                default=False,
+                error_message="Please enter yes or no.",
+            )
+            if use_tp:
+                distributed_type = DistributedType.TP
+        if distributed_type == DistributedType.TP:
+            tp_config["tp_size"] = _ask_field(
+                "What should be your Tensor Parallel degree? [1]: ",
+                int,
+                default=1,
+            )
     megatron_lm_config = {}
     if distributed_type in [DistributedType.MULTI_GPU]:
         use_megatron_lm = _ask_field(
@@ -808,6 +823,7 @@ def get_cluster_input():
         fp8_config=fp8_config,
         deepspeed_config=deepspeed_config,
         fsdp_config=fsdp_config,
+        tp_config=tp_config,
         megatron_lm_config=megatron_lm_config,
         ipex_config=ipex_config,
         mpirun_config=mpirun_config,
diff --git a/src/accelerate/commands/config/config_args.py b/src/accelerate/commands/config/config_args.py
index a3991b2808d..c9e53d9b47b 100644
--- a/src/accelerate/commands/config/config_args.py
+++ b/src/accelerate/commands/config/config_args.py
@@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
     deepspeed_config: dict = None
     # args for fsdp
     fsdp_config: dict = None
+    # args for tp
+    tp_config: dict = None
     # args for megatron_lm
     megatron_lm_config: dict = None
     # args for ipex
@@ -221,6 +223,8 @@ def __post_init__(self):
             self.deepspeed_config = {}
         if self.fsdp_config is None:
             self.fsdp_config = {}
+        if self.tp_config is None:
+            self.tp_config = {}
         if self.megatron_lm_config is None:
             self.megatron_lm_config = {}
         if self.ipex_config is None:
diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 92e27cbfd4a..da5eb1f67c8 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -74,6 +74,7 @@
     "tpu": "TPU",
     "use_deepspeed": "DeepSpeed Arguments",
     "use_fsdp": "FSDP Arguments",
+    "use_tp": "PyTorch TP Arguments",
    "use_megatron_lm": "Megatron-LM Arguments",
     "fp8_backend": "FP8 Arguments",
 }
@@ -262,6 +263,12 @@ def launch_command_parser(subparsers=None):
         action="store_true",
         help="Whether to use fsdp.",
     )
+    paradigm_args.add_argument(
+        "--use_tp",
+        default=False,
+        action="store_true",
+        help="Whether to use PyTorch TP.",
+    )
     paradigm_args.add_argument(
         "--use_megatron_lm",
         default=False,
@@ -589,6 +596,15 @@ def launch_command_parser(subparsers=None):
         help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
     )
 
+    # tp args
+    tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyTorch.")
+    tp_args.add_argument(
+        "--tp_size",
+        default=1,
+        type=int,
+        help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
+    )
+
     # megatron_lm args
     megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
     megatron_lm_args.add_argument(
@@ -965,9 +981,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
 
 def _validate_launch_command(args):
     # Sanity checks
-    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
+    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
         raise ValueError(
-            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
+            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
         )
     if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
         raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
@@ -984,6 +1000,7 @@ def _validate_launch_command(args):
             and not args.tpu_use_cluster
             and not args.use_deepspeed
             and not args.use_fsdp
+            and not args.use_tp
             and not args.use_megatron_lm
         ):
             args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
@@ -1001,6 +1018,7 @@ def _validate_launch_command(args):
             )
             args.tpu = defaults.distributed_type == DistributedType.XLA
             args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
+            args.use_tp = defaults.distributed_type == DistributedType.TP
             args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
             args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
         if args.gpu_ids is None:
@@ -1028,6 +1046,8 @@ def _validate_launch_command(args):
                     if "fsdp" not in arg_to_set:
                         arg_to_set = "fsdp_" + arg_to_set
                     setattr(args, arg_to_set, defaults.fsdp_config[k])
+                for k in defaults.tp_config:
+                    setattr(args, k, defaults.tp_config[k])
                 for k in defaults.megatron_lm_config:
                     setattr(args, k, defaults.megatron_lm_config[k])
                 for k in defaults.dynamo_config:
@@ -1153,6 +1173,8 @@ def launch_command(args):
         deepspeed_launcher(args)
     elif args.use_fsdp and not args.cpu:
         multi_gpu_launcher(args)
+    elif args.use_tp and not args.cpu:
+        multi_gpu_launcher(args)
     elif args.use_megatron_lm and not args.cpu:
         multi_gpu_launcher(args)
     elif args.multi_gpu and not args.cpu:
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index bf3f35fb7e8..7ddf5834b2a 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -528,6 +528,7 @@ def __init__(
         use_stateful_dataloader=False,
         _drop_last: bool = False,
         _non_blocking: bool = False,
+        torch_device_mesh=None,
         **kwargs,
     ):
         super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
@@ -713,6 +714,7 @@ def __init__(
         _drop_last: bool = False,
         _non_blocking: bool = False,
         slice_fn=None,
+        torch_device_mesh=None,
         **kwargs,
     ):
         shuffle = False
@@ -732,26 +734,68 @@ def __init__(
         self._drop_last = _drop_last
         self._non_blocking = _non_blocking
         self.skip_batches = skip_batches
+        self.torch_device_mesh = torch_device_mesh
         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
         self.iteration = 0
+
+        # if a device mesh is provided, extract each dimension (dp, fsdp, tp)
+        # the device mesh may hold any number of dimensions, however,
+        # the code below specifically targets dp, fsdp and tp
+
+        # the device mesh is only used when tp is involved,
+        # or for any multi-dimensional parallelism involving tp:
+        # (dp, tp), (fsdp, tp) or (dp, fsdp, tp)
+        # otherwise the default behaviour, which does not use a device mesh, is sufficient,
+        # since multi-dimensional parallelism without tp needs
+        # different batches for each process irrespective of dp or fsdp
+        self.submesh_tp = None
+        self.submesh_dp = None
+        self.submesh_fsdp = None
+        if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
+            self.submesh_tp = self.torch_device_mesh["tp"]
+            if "dp" in self.torch_device_mesh.mesh_dim_names:
+                self.submesh_dp = self.torch_device_mesh["dp"]
+            if "fsdp" in self.torch_device_mesh.mesh_dim_names:
+                self.submesh_fsdp = self.torch_device_mesh["fsdp"]
+        if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
+            raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")
+
     def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
         if self.state.process_index == 0:
+            # The TP-only procedure is simpler,
+            # since we want to dispatch the same batch of samples across all ranks;
+            # this removes the complexity of handling multiple tp rank groups when a TP + DP
+            # combination is involved.
             try:
+                # for the TP case avoid using split_batches,
+                # since it would require the dataloader to produce
+                # duplicates of batches.
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
+                    if self.submesh_tp:
+                        logger.warning(
+                            "Using split_batches with TP requires the dataloader to produce duplicate batches; "
+                            "otherwise, use dispatch_batches=True instead."
+                        )
                     self._update_state_dict()
                     batch = next(iterator)
                 else:
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
                     batches = []
-                    for _ in range(self.state.num_processes):
+                    if self.submesh_tp:
+                        # when tp is used, extract a single batch and then replicate it
                         self._update_state_dict()
-                        batches.append(next(iterator))
+                        batch = next(iterator)
+                        batches = [batch] * self.state.num_processes
+                    else:
+                        for _ in range(self.state.num_processes):
+                            self._update_state_dict()
+                            batches.append(next(iterator))
                     try:
                         batch = concatenate(batches, dim=0)
                     except RuntimeError as e:
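Reviewer note (a minimal sketch, not part of the diff): how the dispatcher-side checks above read the mesh. Assumes a `torchrun`/`accelerate launch` run with 4 CUDA processes and PyTorch >= 2.3; a mesh that also carries `"dp"` or `"fsdp"` would trip the `TP + (DP/FSDP) is not yet supported in dispatch mode` guard, while a pure `("tp",)` mesh selects the batch-replication path.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))

if "tp" in mesh.mesh_dim_names:
    submesh_tp = mesh["tp"]       # same sub-mesh lookup as in __init__ above
    print(submesh_tp.size())      # 4 -> every rank receives the same dispatched batch
```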
@@ -942,6 +986,7 @@ def prepare_data_loader(
     data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
+    torch_device_mesh: torch.distributed.DeviceMesh = None,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -1009,6 +1054,8 @@
             "If set to true, the dataloader prepared by the Accelerator will be backed by "
             "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
+        torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
+            PyTorch device mesh.
 
     Returns:
@@ -1033,8 +1080,31 @@
     state = PartialState()
     if num_processes is None:
         num_processes = state.num_processes
+
+    # when a device mesh is used, specifically one involving TP,
+    # process_index and num_processes need to be updated
+    # to bring in the effect of generating the same batch across TP ranks
+    # and a different batch across FSDP and DP ranks.
+    # Example:
+    #     if the device mesh is (dp, fsdp, tp) = (2, 2, 3)
+    #     ranks would range from 0...11
+    #     from the data-loading perspective the ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
+    #     processes with the same data rank would receive the same batch
+    submesh_fsdp_size = 1
+    submesh_dp_size = 1
+    submesh_tp_size = 1
+    if torch_device_mesh:
+        if "tp" in torch_device_mesh.mesh_dim_names:
+            submesh_tp_size = torch_device_mesh["tp"].size()
+        if "dp" in torch_device_mesh.mesh_dim_names:
+            submesh_dp_size = torch_device_mesh["dp"].size()
+        if "fsdp" in torch_device_mesh.mesh_dim_names:
+            submesh_fsdp_size = torch_device_mesh["fsdp"].size()
+        num_processes = submesh_fsdp_size * submesh_dp_size
     if process_index is None:
         process_index = state.process_index
+        if torch_device_mesh:
+            process_index = process_index // submesh_tp_size
 
     # Sanity check
     if split_batches:
@@ -1144,6 +1214,7 @@
             _non_blocking=non_blocking,
             slice_fn=slice_fn_for_dispatch,
             use_stateful_dataloader=use_stateful_dataloader,
+            torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 47d718704a6..8d226ac305b 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -850,6 +850,7 @@ def __init__(
         dynamo_plugin=None,
         deepspeed_plugin=None,
         fsdp_plugin=None,
+        torch_tp_plugin=None,
         megatron_lm_plugin=None,
         _from_accelerator: bool = False,
         **kwargs,
@@ -864,6 +865,7 @@ def __init__(
         if not self.initialized:
             self.deepspeed_plugins = None
             self.use_ipex = None
+            self.torch_tp_plugin = torch_tp_plugin
             mixed_precision = (
                 parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                 if mixed_precision is None
@@ -921,6 +923,8 @@ def __init__(
                     self.distributed_type = DistributedType.MEGATRON_LM
                     megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                     self.megatron_lm_plugin = megatron_lm_plugin
+                if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
+                    self.distributed_type = DistributedType.TP
             elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                 if is_ipex_available():
                     # check if user disables it explicitly
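Reviewer note (pure-Python illustration of the rank arithmetic above): with a `(dp, fsdp, tp) = (2, 2, 3)` mesh there are 12 processes, and the dataloader should treat the 3 TP ranks inside each `(dp, fsdp)` group as a single data consumer.

```python
submesh_dp_size, submesh_fsdp_size, submesh_tp_size = 2, 2, 3

num_processes = submesh_fsdp_size * submesh_dp_size                  # 4 distinct batches per step
data_ranks = [rank // submesh_tp_size for rank in range(2 * 2 * 3)]

print(num_processes)  # 4
print(data_ranks)     # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```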
diff --git a/src/accelerate/test_utils/scripts/external_deps/test_performance.py b/src/accelerate/test_utils/scripts/external_deps/test_performance.py
index 57fb1a01884..7b6f21350fd 100644
--- a/src/accelerate/test_utils/scripts/external_deps/test_performance.py
+++ b/src/accelerate/test_utils/scripts/external_deps/test_performance.py
@@ -43,6 +43,7 @@ def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name:
         model_name (`str`, *optional*):
     """
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
     datasets = load_dataset("glue", "mrpc")
 
     def tokenize_function(examples):
@@ -93,6 +94,10 @@ def training_function(config, args):
     # Instantiate the model (we build the model here so that the seed also control new weights initialization)
     model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True)
 
+    if args.add_pad_token:
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = 0
+
     # Instantiate optimizer
     optimizer_cls = (
         AdamW
@@ -243,6 +248,12 @@ def main():
         default=3,
         help="Number of train epochs.",
     )
+    parser.add_argument(
+        "--add_pad_token",
+        type=bool,
+        default=False,
+        help="Whether to add a pad token to the model config if one is not already set.",
+    )
     args = parser.parse_args()
     config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
     training_function(config, args)
diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index 5b9305c5c9b..1902994a625 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -339,6 +339,13 @@ def require_fsdp(test_case):
     return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case)
 
 
+def require_tp(test_case):
+    """
+    Decorator marking a test that requires PyTorch with tensor parallel support (torch >= 2.3.0); skipped otherwise.
+    """
+    return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case)
+
+
 def require_torch_min_version(test_case=None, version=None):
     """
     Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 5b8917fcd48..558cc2f4769 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -57,6 +57,7 @@
     SageMakerDistributedType,
     TensorInformation,
     TorchDynamoPlugin,
+    TorchTensorParallelPlugin,
     add_model_config_to_megatron_parser,
 )
 from .environment import (
diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py
index a6d7d262678..c1437a35b3c 100644
--- a/src/accelerate/utils/constants.py
+++ b/src/accelerate/utils/constants.py
@@ -46,6 +46,7 @@
 ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
 XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
 MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
+BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
 
 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
 
@@ -76,7 +77,7 @@
     "master_port",
 ]
 
-CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM"]
+CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
 TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
     "MULTI_NPU",
     "MULTI_MLU",
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 39e048a6039..a16205c6cc2 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -30,6 +30,7 @@
 import torch
 
 from .constants import (
+    BETA_TP_AVAILABLE_PYTORCH_VERSION,
     FSDP_AUTO_WRAP_POLICY,
     FSDP_BACKWARD_PREFETCH,
     FSDP_SHARDING_STRATEGY,
@@ -540,6 +541,7 @@ class DistributedType(str, enum.Enum):
     MULTI_XPU = "MULTI_XPU"
     DEEPSPEED = "DEEPSPEED"
     FSDP = "FSDP"
+    TP = "TP"
     XLA = "XLA"
     MEGATRON_LM = "MEGATRON_LM"
 
@@ -1810,6 +1812,36 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F
         self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)
 
 
+@dataclass
+class TorchTensorParallelPlugin:
+    """
+    This plugin is used to enable tensor parallelism using PyTorch >= 2.3.
+    """
+
+    tp_size: int = field(
+        default=1,
+        metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
+    )
+
+    # type has to be "torch.distributed.DeviceMesh"
+    torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)
+
+    def __post_init__(self):
+        self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
+        if self.tp_size == 1:
+            raise ValueError("Provide TP degree > 1.")
+
+        if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
+            raise ValueError(
+                f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
+            )
+        from torch.distributed.device_mesh import init_device_mesh
+
+        mesh_dim_name = "tp"
+        device = "cuda"  # support for other devices has to be investigated
+        self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
+
+
 @dataclass
 class MegatronLMPlugin:
     """
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index c6f3d60031d..68b6355912a 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -284,6 +284,10 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
         current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
         current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
 
+    if args.use_tp:
+        current_env["ACCELERATE_USE_TP"] = "true"
+        current_env["TP_SIZE"] = str(args.tp_size)
+
     if args.use_megatron_lm:
         prefix = "MEGATRON_LM_"
         current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"
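Reviewer note (behavior sketch, illustrative values): how the `TP_SIZE` handling in `TorchTensorParallelPlugin.__post_init__` interacts with the environment variable exported by `prepare_multi_gpu_env` above; the environment variable wins over the constructor argument unless it is `"1"`.

```python
import os

def resolve_tp_size(ctor_value: int) -> int:
    # mirrors: tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
    env = os.environ.get("TP_SIZE", "1")
    return ctor_value if env == "1" else int(env)

os.environ["TP_SIZE"] = "4"
print(resolve_tp_size(2))  # 4 -> value exported by `accelerate launch --use_tp --tp_size 4` takes precedence

del os.environ["TP_SIZE"]
print(resolve_tp_size(2))  # 2 -> falls back to TorchTensorParallelPlugin(tp_size=2)
```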
diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py
new file mode 100644
index 00000000000..4d31fc60ca0
--- /dev/null
+++ b/tests/tp/test_tp.py
@@ -0,0 +1,61 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from transformers.trainer_utils import set_seed
+
+from accelerate.test_utils.testing import (
+    TempDirTestCase,
+    execute_subprocess_async,
+    get_launch_command,
+    path_in_accelerate_package,
+    require_multi_device,
+    require_non_torch_xla,
+    require_tp,
+    slow,
+)
+from accelerate.utils import patch_environment
+
+
+set_seed(42)
+
+
+@require_non_torch_xla
+@require_tp
+@require_multi_device
+@slow
+class TPIntegrationTest(TempDirTestCase):
+    test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps")
+
+    def setUp(self):
+        super().setUp()
+        self.test_tp_size = 2
+        self.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.batch_size = 1
+
+    def test_working_of_tp(self):
+        self.test_file_path = self.test_scripts_folder / "test_performance.py"
+        cmd = get_launch_command(
+            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
+        )
+        cmd.extend(
+            [
+                self.test_file_path,
+                f"--output_dir={self.tmpdir}",
+                f"--model_name_or_path={self.model_name_or_path}",
+                "--add_pad_token=true",
+            ]
+        )
+        with patch_environment(omp_num_threads=1):
+            execute_subprocess_async(cmd)
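Reviewer note (illustrative, not part of the diff): rebuilding the command the new test launches, outside `unittest`, using the same helpers the test imports, in case you want to reproduce it locally on a 2-GPU machine. The output directory is a hypothetical scratch path.

```python
from accelerate.test_utils.testing import get_launch_command, path_in_accelerate_package

script = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_performance.py")
cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, use_tp=True, tp_size=2)
cmd += [
    str(script),
    "--output_dir=/tmp/tp_smoke_test",   # hypothetical scratch directory
    "--model_name_or_path=TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "--add_pad_token=true",
]
print(" ".join(map(str, cmd)))
```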