From 0fe41ddcd80ed8f90011a553a8cf9f6850df7dae Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 13 Dec 2024 20:54:14 +0530 Subject: [PATCH] feat: add tests for cli usage of TP and plugin Signed-off-by: Mehant Kammakomati --- src/accelerate/accelerator.py | 12 +++-- src/accelerate/commands/launch.py | 2 +- src/accelerate/data_loader.py | 18 ++++--- src/accelerate/test_utils/testing.py | 7 +++ src/accelerate/utils/dataclasses.py | 4 ++ tests/tp/test_tp.py | 78 ++++++++++++++++++++++++++++ 6 files changed, 108 insertions(+), 13 deletions(-) create mode 100644 tests/tp/test_tp.py diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index c15dd1624c7..796c22dfc63 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -108,7 +108,7 @@ save_fsdp_optimizer, wait_for_everyone, ) -from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME, BETA_TP_AVAILABLE_PYTORCH_VERSION +from .utils.constants import BETA_TP_AVAILABLE_PYTORCH_VERSION, FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME from .utils.modeling import get_state_dict_offloaded_model from .utils.other import is_compiled_module @@ -349,7 +349,9 @@ def __init__( if not is_torch_version(">=", FSDP_PYTORCH_VERSION): raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") - if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(torch_tp_plugin, TorchTensorParallelPlugin): + if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance( + torch_tp_plugin, TorchTensorParallelPlugin + ): if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION): raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}") @@ -363,12 +365,14 @@ def __init__( os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided if torch_tp_plugin is None: - torch_tp_plugin = (TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None) + torch_tp_plugin = ( + TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None + ) else: if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin): raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.") os.environ["ACCELERATE_USE_TP"] = "true" - + if megatron_lm_plugin is None: # init from env variables megatron_lm_plugin = ( MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index 61a5b99c30a..da5eb1f67c8 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -595,7 +595,7 @@ def launch_command_parser(subparsers=None): type=str, help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).", ) - + # tp args tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyToch.") tp_args.add_argument( diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 0e0d85c54fd..7ddf5834b2a 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -740,11 +740,11 @@ def __init__( self.iteration = 0 # if a device mesh is provided extract each dimension (dp, fsdp, tp) - # device mesh may hold any number of dimensions, however, + # device mesh may hold any number of dimensions, however, # below code is for targetted support for dp, fsdp and tp - - # device mesh will be used only if there is tp involved - # or any multi-dimensional parallelism involving tp + + # device mesh will be used only if there is tp involved + # or any multi-dimensional parallelism involving tp # (dp, tp) (fsdp, tp) (dp, fsdp, tp) # otherwise the default behavour not using device mesh should be sufficient # since multi dimensional parallelism devoid of tp would anyway need @@ -777,8 +777,10 @@ def _fetch_batches(self, iterator): if self.split_batches: # One batch of the main iterator is dispatched and split. if self.submesh_tp: - logger.warning("Use of split_batches for TP would need the dataloader to produce duplicate batches," - "otherwise, use dispatch_batches=True instead.") + logger.warning( + "Use of split_batches for TP would need the dataloader to produce duplicate batches," + "otherwise, use dispatch_batches=True instead." + ) self._update_state_dict() batch = next(iterator) else: @@ -1078,7 +1080,7 @@ def prepare_data_loader( state = PartialState() if num_processes is None: num_processes = state.num_processes - + # when device mesh is used, specifically with TP # then there is need to update process_index and num_processes # to bring in the effect of generating same batch across TP ranks @@ -1098,7 +1100,7 @@ def prepare_data_loader( submesh_dp_size = torch_device_mesh["dp"].size() if "fsdp" in torch_device_mesh.mesh_dim_names: submesh_fsdp_size = torch_device_mesh["fsdp"].size() - num_processes = (submesh_fsdp_size * submesh_dp_size) + num_processes = submesh_fsdp_size * submesh_dp_size if process_index is None: process_index = state.process_index if torch_device_mesh: diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 5b9305c5c9b..cde84c1b752 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -339,6 +339,13 @@ def require_fsdp(test_case): return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case) +def require_tp(test_case): + """ + Decorator marking a test that requires FSDP installed. These tests are skipped when FSDP isn't installed + """ + return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case) + + def require_torch_min_version(test_case=None, version=None): """ Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 0869de8eb0f..028418ea926 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -1827,6 +1827,10 @@ class TorchTensorParallelPlugin: torch_device_mesh: torch.distributed.DeviceMesh = field(default=None) def __post_init__(self): + self.tp_size = self.tp_size if os.environ.get("TP_SIZE", 1) == 1 else os.environ.get("TP_SIZE", 1) + if self.tp_size == 1: + raise ValueError("Provide TP degree > 1.") + if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION): raise ValueError( f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel." diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py new file mode 100644 index 00000000000..d39f635fc18 --- /dev/null +++ b/tests/tp/test_tp.py @@ -0,0 +1,78 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from transformers.testing_utils import mockenv_context +from transformers.trainer_utils import set_seed + +from accelerate.test_utils.testing import ( + AccelerateTestCase, + TempDirTestCase, + execute_subprocess_async, + get_launch_command, + path_in_accelerate_package, + require_multi_device, + require_non_cpu, + require_non_torch_xla, + require_tp, + slow, +) +from accelerate.utils import patch_environment +from accelerate.utils.dataclasses import TorchTensorParallelPlugin + + +set_seed(42) + + +@require_tp +@require_non_cpu +@require_non_torch_xla +class TPPluginIntegration(AccelerateTestCase): + def setUp(self): + super().setUp() + + self.dist_env = dict( + MASTER_ADDR="localhost", + MASTER_PORT="10999", + RANK="0", + LOCAL_RANK="0", + WORLD_SIZE="1", + ) + + self.tp_env = dict(ACCELERATE_USE_TP="true", TP_SIZE=2, **self.dist_env) + + def test_device_mesh_init(self): + with mockenv_context(**self.tp_env): + tp_plugin = TorchTensorParallelPlugin() + assert tp_plugin.torch_device_mesh["tp"].size() == self.tp_env["TP_SIZE"] + + +@require_non_torch_xla +@require_tp +@require_multi_device +@slow +class TPIntegrationTest(TempDirTestCase): + test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps") + + def setUp(self): + super().setUp() + self.test_tp_size = 2 + + def test_working_of_tp(self): + self.test_file_path = self.test_scripts_folder / "test_performance.py" + cmd = get_launch_command( + num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size + ) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd)