From 16bb1d4f2e605bce04a863a2c1164c36ee29ebad Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Wed, 10 Apr 2024 14:57:38 -0700
Subject: [PATCH] Support auto planner in _shard_modules (#1839)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1839

As titled

Reviewed By: henrylhtsang

Differential Revision: D55423537

fbshipit-source-id: 167c92b44bda4b103caa4ac820f35b11850f91bd
---
 torchrec/distributed/planner/__init__.py      |  5 +-
 torchrec/distributed/planner/planners.py      | 24 ++++--
 torchrec/distributed/shard.py                 | 46 ++++++++----
 .../tests/test_infer_hetero_shardings.py      | 75 ++++++++++++++++++-
 4 files changed, 127 insertions(+), 23 deletions(-)

diff --git a/torchrec/distributed/planner/__init__.py b/torchrec/distributed/planner/__init__.py
index efd06bf02..90fed5c29 100644
--- a/torchrec/distributed/planner/__init__.py
+++ b/torchrec/distributed/planner/__init__.py
@@ -21,6 +21,9 @@
     - automatically building and selecting an optimized sharding plan.
 """
 
-from torchrec.distributed.planner.planners import EmbeddingShardingPlanner  # noqa
+from torchrec.distributed.planner.planners import (  # noqa
+    EmbeddingShardingPlanner,
+    HeteroEmbeddingShardingPlanner,  # noqa
+)
 from torchrec.distributed.planner.types import ParameterConstraints, Topology  # noqa
 from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name  # noqa
diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 725a7fbdc..02eedaf89 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -453,7 +453,9 @@ def __init__(
         }
         self._topology_groups: Dict[str, Topology] = topology_groups
         self._batch_size: int = batch_size if batch_size else BATCH_SIZE
-        self._constraints = constraints
+        self._constraints: Dict[str, ParameterConstraints] = (
+            constraints if constraints else {}
+        )
         # pyre-ignore
         self._enumerators: Dict[str, Enumerator] = (
             enumerators
@@ -584,17 +586,27 @@ def plan(
                 sharders=sharders,
             )
 
+            # If no device_group is assigned in the constraints, use the current device group
+            # TODO: Create a util function to assign device_group according to empirical rules to cuda or cpu
+            for sharding_option in search_space:
+                if sharding_option.name not in self._constraints:
+                    self._constraints[sharding_option.name] = ParameterConstraints(
+                        device_group=group
+                    )
+                elif not self._constraints[sharding_option.name].device_group:
+                    self._constraints[sharding_option.name].device_group = group
+
             # filter by device group
             search_space = [
-                s_o
-                for s_o in search_space
-                # pyre-ignore [16]
-                if self._constraints[s_o.name].device_group == group
+                sharding_option
+                for sharding_option in search_space
+                if self._constraints[sharding_option.name].device_group == group
             ]
 
             if not search_space:
                 # No shardable parameters
-                return ShardingPlan({})
+                best_plans.append(ShardingPlan({}))
+                continue
 
             proposal_cache: Dict[
                 Tuple[int, ...],
diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py
index 9a3c22b44..370bdebc9 100644
--- a/torchrec/distributed/shard.py
+++ b/torchrec/distributed/shard.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Callable, Dict, List, Optional, Type, Union
+from typing import Callable, cast, Dict, List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -15,7 +15,11 @@
 from torch.distributed._composable.contract import contract
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.model_parallel import get_default_sharders
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.planner import (
+    EmbeddingShardingPlanner,
+    HeteroEmbeddingShardingPlanner,
+    Topology,
+)
 from torchrec.distributed.sharding_plan import (
     get_module_to_default_sharders,
     ParameterShardingGenerator,
@@ -223,21 +227,33 @@ def _shard_modules(  # noqa: C901
     }
 
     if plan is None:
-        assert isinstance(
-            env, ShardingEnv
-        ), "Currently hybrid sharding only support use manual sharding plan"
-        planner = EmbeddingShardingPlanner(
-            topology=Topology(
-                local_world_size=get_local_size(env.world_size),
-                world_size=env.world_size,
-                compute_device=device.type,
+        if isinstance(env, dict) and len(env) > 1:  # use heterogenous sharding
+            planner = HeteroEmbeddingShardingPlanner(
+                topology_groups={
+                    group: Topology(
+                        local_world_size=get_local_size(cur_env.world_size),
+                        world_size=cur_env.world_size,
+                        compute_device=group,
+                    )
+                    for group, cur_env in env.items()
+                }
             )
-        )
-        pg = env.process_group
-        if pg is not None:
-            plan = planner.collective_plan(module, sharders, pg)
-        else:
+            # For heterogenous sharding, only generate the plan, no broadcast
             plan = planner.plan(module, sharders)
+        else:
+            env = cast(ShardingEnv, env)
+            planner = EmbeddingShardingPlanner(
+                topology=Topology(
+                    local_world_size=get_local_size(env.world_size),
+                    world_size=env.world_size,
+                    compute_device=device.type,
+                )
+            )
+            pg = env.process_group
+            if pg is not None:
+                plan = planner.collective_plan(module, sharders, pg)
+            else:
+                plan = planner.plan(module, sharders)
 
     if type(module) in sharder_map:
         # If the top level module is itself a shardable module, return the sharded variant.
diff --git a/torchrec/distributed/tests/test_infer_hetero_shardings.py b/torchrec/distributed/tests/test_infer_hetero_shardings.py
index 7c481eb86..748defefa 100755
--- a/torchrec/distributed/tests/test_infer_hetero_shardings.py
+++ b/torchrec/distributed/tests/test_infer_hetero_shardings.py
@@ -16,7 +16,10 @@
 from torchrec.distributed.planner import ParameterConstraints
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.planner.types import Topology
-from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
+from torchrec.distributed.quant_embedding import (
+    QuantEmbeddingCollectionSharder,
+    ShardedQuantEmbeddingCollection,
+)
 from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
 from torchrec.distributed.shard import _shard_modules
 from torchrec.distributed.sharding_plan import (
@@ -196,6 +199,11 @@ def test_sharder_different_world_sizes_for_qebc(self) -> None:
             == env.world_size
         )
 
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 3,
+        "Not enough GPUs available",
+    )
     def test_cpu_gpu_sharding_autoplanner(self) -> None:
         num_embeddings = 10
         emb_dim = 16
@@ -266,3 +274,68 @@ def test_cpu_gpu_sharding_autoplanner(self) -> None:
             .type,
             "cuda",
         )
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 3,
+        "Not enough GPUs available",
+    )
+    def test_cpu_gpu_sharding_shard_modules(self) -> None:
+        num_embeddings = 10
+        emb_dim = 16
+        tables = [
+            EmbeddingConfig(
+                num_embeddings=num_embeddings,
+                embedding_dim=emb_dim,
+                name=f"table_{i}",
+                feature_names=[f"feature_{i}"],
+            )
+            for i in range(3)
+        ]
+        model = KJTInputWrapper(
+            module_kjt_input=torch.nn.Sequential(
+                EmbeddingCollection(
+                    tables=tables,
+                    device=torch.device("cpu"),
+                )
+            )
+        )
+        non_sharded_model = quantize(
+            model,
+            inplace=False,
+            quant_state_dict_split_scale_bias=True,
+            weight_dtype=torch.qint8,
+        )
+        sharder = QuantEmbeddingCollectionSharder()
+        env_dict = {
+            "cpu": ShardingEnv.from_local(
+                3,
+                0,
+            ),
+            "cuda": ShardingEnv.from_local(
+                2,
+                0,
+            ),
+        }
+
+        shard_model = _shard_modules(
+            module=non_sharded_model,
+            env=env_dict,
+            # pyre-ignore
+            sharders=[sharder],
+            device=torch.device("cpu"),
+        )
+
+        self.assertTrue(
+            isinstance(
+                shard_model._module_kjt_input[0], ShardedQuantEmbeddingCollection
+            )
+        )
+
+        self.assertEqual(len(shard_model._module_kjt_input[0]._lookups), 1)
+        self.assertEqual(
+            len(
+                shard_model._module_kjt_input[0]._lookups[0]._embedding_lookups_per_rank
+            ),
+            env_dict["cpu"].world_size,
+        )
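
Usage sketch (not part of the patch): the snippet below mirrors the new test above and shows how the auto-planner path in _shard_modules is reached, namely when a dict of per-device-group ShardingEnvs is passed without a precomputed plan. Table sizes and world sizes are illustrative, and the quantize/KJTInputWrapper helpers are assumed to be the torchrec test utilities (torchrec.distributed.test_utils.infer_utils) that the test file already relies on.

import torch

from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
from torchrec.distributed.shard import _shard_modules
# Assumed location of the test helpers used by the test above.
from torchrec.distributed.test_utils.infer_utils import KJTInputWrapper, quantize
from torchrec.distributed.types import ShardingEnv
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

# Small embedding collection on CPU (sizes are illustrative).
tables = [
    EmbeddingConfig(
        num_embeddings=10,
        embedding_dim=16,
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
    )
    for i in range(3)
]
model = KJTInputWrapper(
    module_kjt_input=torch.nn.Sequential(
        EmbeddingCollection(tables=tables, device=torch.device("cpu"))
    )
)
quant_model = quantize(
    model,
    inplace=False,
    quant_state_dict_split_scale_bias=True,
    weight_dtype=torch.qint8,
)

# A dict of per-device-group ShardingEnvs with more than one entry, and no
# explicit plan, routes _shard_modules through HeteroEmbeddingShardingPlanner,
# which calls plan() directly (no collective_plan / broadcast).
env = {
    "cpu": ShardingEnv.from_local(3, 0),
    "cuda": ShardingEnv.from_local(2, 0),
}
sharded_model = _shard_modules(
    module=quant_model,
    env=env,
    sharders=[QuantEmbeddingCollectionSharder()],
    device=torch.device("cpu"),
)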