diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index c15dd1624c7..796c22dfc63 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -108,7 +108,7 @@
     save_fsdp_optimizer,
     wait_for_everyone,
 )
-from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, BETA_TP_AVAILABLE_PYTORCH_VERSION
+from .utils.constants import BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
 
@@ -349,7 +349,9 @@ def __init__(
             if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
                 raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
 
-        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
+        if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
+            torch_tp_plugin, TorchTensorParallelPlugin
+        ):
             if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
                 raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
 
@@ -363,12 +365,14 @@ def __init__(
             os.environ["ACCELERATE_USE_FSDP"] = "true"  # use FSDP if plugin is provided
 
         if torch_tp_plugin is None:
-            torch_tp_plugin = (TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None)
+            torch_tp_plugin = (
+                TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
+            )
         else:
             if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
                 raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
             os.environ["ACCELERATE_USE_TP"] = "true"
-
+
         if megatron_lm_plugin is None:  # init from env variables
             megatron_lm_plugin = (
                 MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 61a5b99c30a..da5eb1f67c8 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -595,7 +595,7 @@ def launch_command_parser(subparsers=None):
         type=str,
         help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
     )
-
+
     # tp args
     tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyTorch.")
     tp_args.add_argument(
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 0e0d85c54fd..7ddf5834b2a 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -740,11 +740,11 @@ def __init__(
         self.iteration = 0
 
         # if a device mesh is provided extract each dimension (dp, fsdp, tp)
-        # device mesh may hold any number of dimensions, however, 
+        # device mesh may hold any number of dimensions, however,
        # below code is for targeted support for dp, fsdp and tp
-
-        # device mesh will be used only if there is tp involved 
-        # or any multi-dimensional parallelism involving tp 
+
+        # device mesh will be used only if there is tp involved
+        # or any multi-dimensional parallelism involving tp
         # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
         # otherwise the default behavior not using device mesh should be sufficient
         # since multi dimensional parallelism devoid of tp would anyway need
@@ -777,8 +777,10 @@ def _fetch_batches(self, iterator):
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
                    if self.submesh_tp:
-                        logger.warning("Use of split_batches for TP would need the dataloader to produce duplicate batches,"
-                                       "otherwise, use dispatch_batches=True instead.")
+                        logger.warning(
+                            "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
+                            "otherwise, use dispatch_batches=True instead."
+                        )
                     self._update_state_dict()
                     batch = next(iterator)
                 else:
@@ -1078,7 +1080,7 @@ def prepare_data_loader(
         state = PartialState()
     if num_processes is None:
         num_processes = state.num_processes
-
+
     # when device mesh is used, specifically with TP
     # then there is need to update process_index and num_processes
     # to bring in the effect of generating same batch across TP ranks
@@ -1098,7 +1100,7 @@ def prepare_data_loader(
             submesh_dp_size = torch_device_mesh["dp"].size()
         if "fsdp" in torch_device_mesh.mesh_dim_names:
             submesh_fsdp_size = torch_device_mesh["fsdp"].size()
-        num_processes = (submesh_fsdp_size * submesh_dp_size)
+        num_processes = submesh_fsdp_size * submesh_dp_size
     if process_index is None:
         process_index = state.process_index
     if torch_device_mesh:
diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index 5b9305c5c9b..cde84c1b752 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -339,6 +339,13 @@ def require_fsdp(test_case):
     return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case)
 
 
+def require_tp(test_case):
+    """
+    Decorator marking a test that requires TP (tensor parallelism). These tests are skipped when torch < 2.3.0 is installed.
+    """
+    return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case)
+
+
 def require_torch_min_version(test_case=None, version=None):
     """
     Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 0869de8eb0f..a16205c6cc2 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -1827,6 +1827,10 @@ class TorchTensorParallelPlugin:
     torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)
 
     def __post_init__(self):
+        self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
+        if self.tp_size == 1:
+            raise ValueError("Provide TP degree > 1.")
+
         if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
             raise ValueError(
                 f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index 2dd6edbfb46..68b6355912a 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -286,6 +286,7 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
 
     if args.use_tp:
         current_env["ACCELERATE_USE_TP"] = "true"
+        current_env["TP_SIZE"] = str(args.tp_size)
 
     if args.use_megatron_lm:
         prefix = "MEGATRON_LM_"
diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py
new file mode 100644
index 00000000000..2fffd45644c
--- /dev/null
+++ b/tests/tp/test_tp.py
@@ -0,0 +1,57 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from transformers.trainer_utils import set_seed
+
+from accelerate.test_utils.testing import (
+    TempDirTestCase,
+    execute_subprocess_async,
+    get_launch_command,
+    path_in_accelerate_package,
+    require_multi_device,
+    require_non_torch_xla,
+    require_tp,
+    slow,
+)
+from accelerate.utils import patch_environment
+
+
+set_seed(42)
+
+
+@require_non_torch_xla
+@require_tp
+@require_multi_device
+@slow
+class TPIntegrationTest(TempDirTestCase):
+    test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps")
+
+    def setUp(self):
+        super().setUp()
+        self.test_tp_size = 2
+
+    def test_working_of_tp(self):
+        self.test_file_path = self.test_scripts_folder / "test_performance.py"
+        cmd = get_launch_command(
+            num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size
+        )
+        cmd.extend(
+            [
+                self.test_file_path,
+                f"--output_dir={self.tmpdir}",
+            ]
+        )
+        with patch_environment(omp_num_threads=1):
+            execute_subprocess_async(cmd)
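
Usage sketch (illustrative, not part of the patch): a minimal driver for the TP plugin added above, assuming `TorchTensorParallelPlugin` is importable from `accelerate.utils`, that it exposes the `tp_size` field referenced in `__post_init__`, and that the script runs on two GPUs with PyTorch >= 2.3.0, e.g. launched with the new flags `accelerate launch --use_tp --tp_size 2 train.py`, which export `ACCELERATE_USE_TP=true` and `TP_SIZE=2` via `prepare_multi_gpu_env`.

# Hypothetical sketch; names and import paths are inferred from the diff above.
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

# tp_size must be > 1, otherwise TorchTensorParallelPlugin.__post_init__ raises ValueError.
tp_plugin = TorchTensorParallelPlugin(tp_size=2)
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

# Prepare objects as usual; the data_loader.py changes above aim to have all ranks in the
# same TP group draw identical batches when a device mesh with a "tp" dimension is present.
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

The new `tests/tp/test_tp.py` exercises the same path end to end through `get_launch_command(use_tp=True, tp_size=2)`.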