From e08c364316546709ed2c8c2ec819897cad52e382 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 13 Dec 2024 20:54:14 +0530 Subject: [PATCH] feat: add tests for cli usage of TP and plugin Signed-off-by: Mehant Kammakomati --- src/accelerate/test_utils/testing.py | 6 +++ src/accelerate/utils/dataclasses.py | 4 ++ tests/tp/test_tp.py | 77 ++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+) create mode 100644 tests/tp/test_tp.py diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 5b9305c5c9b..a051d6d8d77 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -338,6 +338,12 @@ def require_fsdp(test_case): """ return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case) +def require_tp(test_case): + """ + Decorator marking a test that requires FSDP installed. These tests are skipped when FSDP isn't installed + """ + return unittest.skipUnless(is_torch_version(">=", "2.3.0"), "test requires torch version >= 2.3.0")(test_case) + def require_torch_min_version(test_case=None, version=None): """ diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 0869de8eb0f..028418ea926 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -1827,6 +1827,10 @@ class TorchTensorParallelPlugin: torch_device_mesh: torch.distributed.DeviceMesh = field(default=None) def __post_init__(self): + self.tp_size = self.tp_size if os.environ.get("TP_SIZE", 1) == 1 else os.environ.get("TP_SIZE", 1) + if self.tp_size == 1: + raise ValueError("Provide TP degree > 1.") + if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION): raise ValueError( f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel." diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py new file mode 100644 index 00000000000..84b7c418316 --- /dev/null +++ b/tests/tp/test_tp.py @@ -0,0 +1,77 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from transformers.testing_utils import mockenv_context +from transformers.trainer_utils import set_seed + + +from accelerate.test_utils.testing import ( + AccelerateTestCase, + TempDirTestCase, + execute_subprocess_async, + get_launch_command, + path_in_accelerate_package, + require_tp, + require_multi_device, + require_non_cpu, + require_non_torch_xla, + slow, +) +from accelerate.utils import patch_environment +from accelerate.utils.dataclasses import TorchTensorParallelPlugin + + +set_seed(42) + + +@require_tp +@require_non_cpu +@require_non_torch_xla +class TPPluginIntegration(AccelerateTestCase): + def setUp(self): + super().setUp() + + self.dist_env = dict( + MASTER_ADDR="localhost", + MASTER_PORT="10999", + RANK="0", + LOCAL_RANK="0", + WORLD_SIZE="1", + ) + + self.tp_env = dict(ACCELERATE_USE_TP="true", TP_SIZE=2, **self.dist_env) + + def test_device_mesh_init(self): + with mockenv_context(**self.tp_env): + tp_plugin = TorchTensorParallelPlugin() + assert tp_plugin.torch_device_mesh["tp"].size() == self.tp_env["TP_SIZE"] + + +@require_non_torch_xla +@require_tp +@require_multi_device +@slow +class TPIntegrationTest(TempDirTestCase): + test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps") + + def setUp(self): + super().setUp() + self.test_tp_size = 2 + + def test_working_of_tp(self): + self.test_file_path = self.test_scripts_folder / "test_performance.py" + cmd = get_launch_command(num_processes=self.test_tp_size, num_machines=1, machine_rank=0, use_tp=True, tp_size=self.test_tp_size) + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd)