diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 59cfd758f..725a7fbdc 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -108,6 +108,31 @@ def _to_sharding_plan(
     return ShardingPlan(plan)
 
 
+def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
+    if len(best_plans) == 1:
+        return best_plans[0]
+    else:
+        merged_plan = ShardingPlan({})
+        for plan in best_plans:
+            for name, ps in plan.plan.items():
+                ps = cast(EmbeddingModuleShardingPlan, ps)
+                if name not in merged_plan.plan:
+                    merged_plan.plan[name] = ps
+                else:
+                    for k, v in ps.items():
+                        cur_plan = cast(
+                            EmbeddingModuleShardingPlan, merged_plan.plan[name]
+                        )
+                        if k not in cur_plan:
+                            cur_plan[k] = v
+                        else:
+                            raise PlannerError(
+                                "table cannot be sharded between two device groups"
+                            )
+
+        return merged_plan
+
+
 class EmbeddingShardingPlanner(ShardingPlanner):
     """
     Provides an optimized sharding plan for a given module with shardable parameters
@@ -210,7 +235,6 @@ def plan(
         module: nn.Module,
         sharders: List[ModuleSharder[nn.Module]],
     ) -> ShardingPlan:
-
         self._num_proposals = 0
         self._num_plans = 0
         start_time = perf_counter()
@@ -397,3 +421,350 @@ def plan(
                 + no_plan_solution
                 + last_planner_error_info,
             )
+
+
+class HeteroEmbeddingShardingPlanner(ShardingPlanner):
+    """
+    Provides an optimized sharding plan for a given module with shardable parameters
+    according to the provided sharders, topology groups, and constraints.
+    """
+
+    def __init__(
+        self,
+        topology_groups: Optional[Dict[str, Topology]] = None,
+        batch_size: Optional[int] = None,
+        enumerators: Optional[Dict[str, Enumerator]] = None,
+        storage_reservations: Optional[Dict[str, StorageReservation]] = None,
+        proposers: Optional[Dict[str, Union[Proposer, List[Proposer]]]] = None,
+        partitioners: Optional[Dict[str, Partitioner]] = None,
+        performance_models: Optional[Dict[str, PerfModel]] = None,
+        stats: Optional[Dict[str, Union[Stats, List[Stats]]]] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        debug: bool = True,
+    ) -> None:
+        default_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if topology_groups is None:
+            topology_groups = {
+                default_device: Topology(
+                    local_world_size=get_local_size(),
+                    world_size=dist.get_world_size(),
+                    compute_device=default_device,
+                )
+            }
+        self._topology_groups: Dict[str, Topology] = topology_groups
+        self._batch_size: int = batch_size if batch_size else BATCH_SIZE
+        self._constraints = constraints
+        # pyre-ignore
+        self._enumerators: Dict[str, Enumerator] = (
+            enumerators
+            if enumerators
+            else {
+                group: EmbeddingEnumerator(
+                    topology=self._topology_groups[group],
+                    batch_size=self._batch_size,
+                    constraints=constraints,
+                )
+                for group in self._topology_groups.keys()
+            }
+        )
+        # pyre-ignore
+        self._storage_reservations: Dict[str, StorageReservation] = (
+            storage_reservations
+            if storage_reservations
+            else {
+                group: HeuristicalStorageReservation(percentage=0.15)
+                for group in self._topology_groups.keys()
+            }
+        )
+
+        # pyre-ignore
+        self._partitioners: Dict[str, Partitioner] = (
+            partitioners
+            if partitioners
+            else {
+                group: GreedyPerfPartitioner() for group in self._topology_groups.keys()
+            }
+        )
+
+        if proposers:
+            # pyre-ignore
+            self._proposers: Dict[str, List[Proposer]] = proposers
+        else:
+            # pyre-ignore
+            self._proposers = {
+                group: [
+                    GridSearchProposer(),
+                    GreedyProposer(),
+                    GreedyProposer(use_depth=False),
+                    UniformProposer(),
+                ]
+                for group in self._topology_groups.keys()
+            }
+
+        # pyre-ignore
+        self._perf_models: Dict[str, PerfModel] = (
+            performance_models
+            if performance_models
+            else {
+                group: NoopPerfModel(topology=self._topology_groups[group])
+                for group in self._topology_groups
+            }
+        )
+
+        self._stats: Dict[str, List[Stats]] = {}
+
+        if stats is not None:
+            # pyre-ignore [8]
+            self._stats = stats
+        else:
+            # pyre-ignore [8]
+            self._stats = {
+                group: [EmbeddingStats()] for group in self._topology_groups.keys()
+            }
+
+        self._debug = debug
+        self._num_proposals: int = 0
+        self._num_plans: int = 0
+        self._best_plan: Optional[List[ShardingOption]] = None
+
+    def collective_plan(
+        self,
+        module: nn.Module,
+        sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
+        pg: Optional[dist.ProcessGroup] = dist.GroupMember.WORLD,
+    ) -> ShardingPlan:
+        """
+        Call self.plan(...) on rank 0 and broadcast
+        """
+        if pg is None:
+            assert dist.is_initialized(), (
+                "The default process group is not yet initialized. "
+                "Please call torch.distributed.init_process_group() first before invoking this. "
+                "If you are not within a distributed environment, use the single rank version plan() instead."
+            )
+            pg = none_throws(dist.GroupMember.WORLD)
+        assert len(self._topology_groups) == 1, "Only single topology is supported"
+
+        if sharders is None:
+            sharders = get_default_sharders()
+        return invoke_on_rank_and_broadcast_result(
+            pg,
+            0,
+            self.plan,
+            module,
+            sharders,
+        )
+
+    def plan(
+        self,
+        module: nn.Module,
+        sharders: List[ModuleSharder[nn.Module]],
+    ) -> ShardingPlan:
+        best_plans: List[ShardingPlan] = []
+        for group, topology in self._topology_groups.items():
+            self._num_proposals = 0
+            self._num_plans = 0
+            start_time = perf_counter()
+            best_plan = None
+            lowest_storage = Storage(MAX_SIZE, MAX_SIZE)
+            last_planner_error: Optional[PlannerError] = None
+            last_proposal: List[ShardingOption] = []
+            best_perf_rating = MAX_SIZE
+
+            storage_constraint: Topology = self._storage_reservations[group].reserve(
+                topology=topology,
+                batch_size=self._batch_size,
+                module=module,
+                sharders=sharders,
+                constraints=self._constraints,
+            )
+
+            search_space = self._enumerators[group].enumerate(
+                module=module,
+                sharders=sharders,
+            )
+
+            # filter by device group
+            search_space = [
+                s_o
+                for s_o in search_space
+                # pyre-ignore [16]
+                if self._constraints[s_o.name].device_group == group
+            ]
+
+            if not search_space:
+                # No shardable parameters
+                return ShardingPlan({})
+
+            proposal_cache: Dict[
+                Tuple[int, ...],
+                Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
+            ] = {}
+
+            for proposer in self._proposers[group]:
+                proposer.load(
+                    search_space=search_space, enumerator=self._enumerators[group]
+                )
+
+            for proposer in self._proposers[group]:
+                proposal = proposer.propose()
+
+                while proposal:
+                    proposal_key = tuple(sorted(map(hash, proposal)))
+                    if proposal_key in proposal_cache:
+                        partitionable, plan, perf_rating = proposal_cache[proposal_key]
+                        proposer.feedback(
+                            partitionable=partitionable,
+                            plan=plan,
+                            perf_rating=perf_rating,
+                            storage_constraint=storage_constraint,
+                        )
+                        proposal = proposer.propose()
+                        continue
+
+                    self._num_proposals += 1
+                    try:
+                        # plan is just proposal where shard.rank is populated
+                        plan = self._partitioners[group].partition(
+                            proposal=proposal,
+                            storage_constraint=storage_constraint,
+                        )
+                        self._num_plans += 1
+                        perf_rating = self._perf_models[group].rate(plan=plan)
+                        if perf_rating < best_perf_rating:
+                            best_perf_rating = perf_rating
+                            best_plan = copy.deepcopy(plan)
+                        proposal_cache[proposal_key] = (True, plan, perf_rating)
+                        proposer.feedback(
+                            partitionable=True,
+                            plan=plan,
+                            perf_rating=perf_rating,
+                            storage_constraint=storage_constraint,
+                        )
+                    except PlannerError as planner_error:
+                        last_planner_error = planner_error
+                        # shallow copy of the proposal
+                        last_proposal: List[ShardingOption] = copy.copy(proposal)
+                        current_storage = cast(
+                            Storage,
+                            reduce(
+                                lambda x, y: x + y,
+                                [
+                                    shard.storage
+                                    for option in proposal
+                                    for shard in option.shards
+                                ],
+                            ),
+                        )
+                        if current_storage < lowest_storage:
+                            lowest_storage = current_storage
+                        proposal_cache[proposal_key] = (False, proposal, None)
+                        proposer.feedback(
+                            partitionable=False,
+                            plan=proposal,
+                            storage_constraint=storage_constraint,
+                        )
+
+                    # clear shard.rank for each sharding_option
+                    reset_shard_rank(proposal)
+                    proposal = proposer.propose()
+
+            if best_plan:
+                self._best_plan = best_plan
+                sharding_plan = _to_sharding_plan(
+                    best_plan, self._topology_groups[group]
+                )
+                best_plans.append(sharding_plan)
+
+                end_time = perf_counter()
+                for stats in self._stats[group]:
+                    stats.log(
+                        sharding_plan=sharding_plan,
+                        topology=self._topology_groups[group],
+                        batch_size=self._batch_size,
+                        storage_reservation=self._storage_reservations[group],
+                        num_proposals=self._num_proposals,
+                        num_plans=self._num_plans,
+                        run_time=end_time - start_time,
+                        best_plan=best_plan,
+                        constraints=self._constraints,
+                        sharders=sharders,
+                        debug=self._debug,
+                    )
+            else:
+                global_storage_capacity = reduce(
+                    lambda x, y: x + y,
+                    [device.storage for device in self._topology_groups[group].devices],
+                )
+                global_storage_constraints = reduce(
+                    lambda x, y: x + y,
+                    [device.storage for device in storage_constraint.devices],
+                )
+                storage_reservation_solution = (
+                    (
+                        # pyre-ignore [16]
+                        f"\n\t Storage reservation percentage: {self._storage_reservations[group]._percentage}, "
+                        f"\n\t Per rank reservation for dense storage: {storage_repr_in_gb(self._storage_reservations[group]._dense_storage)}, "
+                        f"\n\t Per rank reservation for kjt storage: {storage_repr_in_gb(self._storage_reservations[group]._kjt_storage)}, "  # pyre-ignore[16]
+                    )
+                    if isinstance(
+                        self._storage_reservations[group], HeuristicalStorageReservation
+                    )
+                    else f"\n\t Storage reservation percentage: {self._storage_reservations[group]._percentage}, "  # pyre-ignore[16]
+                )
+                no_plan_solution = (
+                    f"Planner evaluated {self._num_proposals} proposals."
+ "\nPossible solutions:" + f"\n 1) Increase the number of devices ({self._topology_groups[group].world_size})" + f"\n 2) Reduce the model size (" + f"\n\t Global storage: {round(bytes_to_gb(global_storage_capacity.hbm), 3)} GB, " + f"\n\t Per rank hardware memory: {storage_repr_in_gb(self._topology_groups[group].devices[0].storage)}, " + f"{storage_reservation_solution}" + f"\n\t Global storage available for model parallel: {storage_repr_in_gb(global_storage_constraints)}, " + f"\n\t Global storage requirement for model parallel: {storage_repr_in_gb(lowest_storage)})" + f"\n 3) Reduce local batch size ({self._batch_size})" + "\n 4) Remove planner constraints that might be reducing search space or available storage\n" + ) + last_planner_error_info = ( + f"Last planner error: \n\t{last_planner_error}\n" + ) + + # printout stats for no plan situation + end_time = perf_counter() + sharding_plan = ShardingPlan(plan={}) + # force all shards to have rank= -1 + for sharding_option in last_proposal: + for shard in sharding_option.shards: + shard.rank = -1 + + for stats in self._stats[group]: + stats.log( + sharding_plan=sharding_plan, + topology=self._topology_groups[group], + batch_size=self._batch_size, + storage_reservation=self._storage_reservation, + num_proposals=self._num_proposals, + num_plans=self._num_plans, + run_time=end_time - start_time, + best_plan=last_proposal, + constraints=self._constraints, + sharders=sharders, + debug=self._debug, + ) + + if not lowest_storage.fits_in(global_storage_constraints): + raise PlannerError( + error_type=PlannerErrorType.INSUFFICIENT_STORAGE, + message="Unable to find a plan for this model because of insufficient storage. \n" + + no_plan_solution + + last_planner_error_info, + ) + else: + raise PlannerError( + error_type=PlannerErrorType.STRICT_CONSTRAINTS, + message="Unable to find a plan for this model because of the strict constraints. 
\n" + + no_plan_solution + + last_planner_error_info, + ) + + return _merge_plans(best_plans) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 6c16fb4bc..8bc071425 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -445,6 +445,7 @@ class ParameterConstraints: bounds_check_mode: Optional[BoundsCheckMode] = None feature_names: Optional[List[str]] = None output_dtype: Optional[DataType] = None + device_group: Optional[str] = None class PlannerErrorType(Enum): diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py index 30568df5f..9ed75fb4f 100644 --- a/torchrec/distributed/test_utils/infer_utils.py +++ b/torchrec/distributed/test_utils/infer_utils.py @@ -656,7 +656,6 @@ def shard_qebc( kernel_type=EmbeddingComputeKernel.QUANT.value, shardable_params=[table.name for table in mi.tables], ) - if not plan: # pyre-ignore plan = mi.planner.plan( diff --git a/torchrec/distributed/tests/test_infer_hetero_shardings.py b/torchrec/distributed/tests/test_infer_hetero_shardings.py index f3c5000ee..7c481eb86 100755 --- a/torchrec/distributed/tests/test_infer_hetero_shardings.py +++ b/torchrec/distributed/tests/test_infer_hetero_shardings.py @@ -13,6 +13,9 @@ import torch from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner +from torchrec.distributed.planner.types import Topology from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder from torchrec.distributed.shard import _shard_modules @@ -192,3 +195,74 @@ def test_sharder_different_world_sizes_for_qebc(self) -> None: ) == env.world_size ) + + def test_cpu_gpu_sharding_autoplanner(self) -> None: + num_embeddings = 10 + emb_dim = 16 + tables = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(3) + ] + model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + EmbeddingCollection( + tables=tables, + device=torch.device("cpu"), + ) + ) + ) + non_sharded_model = quantize( + model, + inplace=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=torch.qint8, + ) + sharder = QuantEmbeddingCollectionSharder() + topo_cpu = Topology(world_size=3, compute_device="cpu") + topo_gpu = Topology(world_size=2, compute_device="cuda") + topo_groups = { + "cpu": topo_cpu, + "cuda": topo_gpu, + } + constraints = { + "table_0": ParameterConstraints(device_group="cpu"), + "table_1": ParameterConstraints(device_group="cuda"), + "table_2": ParameterConstraints(device_group="cuda"), + } + planner = HeteroEmbeddingShardingPlanner( + topology_groups=topo_groups, constraints=constraints + ) + module_plan = planner.plan( + non_sharded_model, + # pyre-ignore + sharders=[sharder], + ) + print(module_plan) + + self.assertTrue( + # pyre-ignore + module_plan.plan["_module_kjt_input.0"]["table_0"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cpu", + ) + self.assertTrue( + module_plan.plan["_module_kjt_input.0"]["table_1"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cuda", + ) + self.assertTrue( + module_plan.plan["_module_kjt_input.0"]["table_2"] + .sharding_spec.shards[0] + .placement.device() + .type, + "cuda", + 
+        )
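
For reviewers, below is a condensed usage sketch of the new API (not part of the diff), assembled from test_cpu_gpu_sharding_autoplanner above. It assumes the KJTInputWrapper and quantize helpers live in torchrec.distributed.test_utils.infer_utils, as in the existing inference tests; everything else mirrors the test: each device group gets its own Topology, tables are pinned to a group via the new ParameterConstraints.device_group field, and the planner plans each group independently before merging the per-group plans with _merge_plans.

# Usage sketch only; the helper import path below is an assumption.
import torch

from torchrec import EmbeddingCollection, EmbeddingConfig
from torchrec.distributed.planner import ParameterConstraints
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
from torchrec.distributed.test_utils.infer_utils import KJTInputWrapper, quantize

# Three small tables wrapped in a quantized EmbeddingCollection, as in the test.
tables = [
    EmbeddingConfig(
        num_embeddings=10,
        embedding_dim=16,
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
    )
    for i in range(3)
]
model = KJTInputWrapper(
    module_kjt_input=torch.nn.Sequential(
        EmbeddingCollection(tables=tables, device=torch.device("cpu"))
    )
)
quant_model = quantize(
    model,
    inplace=False,
    quant_state_dict_split_scale_bias=True,
    weight_dtype=torch.qint8,
)

# One Topology per device group; table_0 is pinned to the "cpu" group and the
# other tables to the "cuda" group through the new device_group constraint.
topology_groups = {
    "cpu": Topology(world_size=3, compute_device="cpu"),
    "cuda": Topology(world_size=2, compute_device="cuda"),
}
constraints = {
    "table_0": ParameterConstraints(device_group="cpu"),
    "table_1": ParameterConstraints(device_group="cuda"),
    "table_2": ParameterConstraints(device_group="cuda"),
}

planner = HeteroEmbeddingShardingPlanner(
    topology_groups=topology_groups, constraints=constraints
)
plan = planner.plan(quant_model, sharders=[QuantEmbeddingCollectionSharder()])
print(plan)  # table_0 shards land on CPU ranks, table_1/table_2 on CUDA ranks

The resulting ShardingPlan places table_0's shards on CPU ranks and table_1/table_2's shards on CUDA ranks, which is exactly what the new assertions in the test check.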