From c523c19f6c4ce7902857a6e73be085ab194c5733 Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Tue, 2 Apr 2024 17:24:58 -0700 Subject: [PATCH] make collective_plan() follow c10d convention (#1842) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1842 When `pg` is None, rather than throwing, we should fall back to using the default pg. This is the general behavior accepted in OSS c10d. Reviewed By: colin2328 Differential Revision: D55613401 fbshipit-source-id: ec1502c468d8e52d2d56e0472ef4bde558e171d8 --- torchrec/distributed/planner/planners.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index c974f29e2..59cfd758f 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -61,6 +61,7 @@ ShardingType, ShardMetadata, ) +from torchrec.distributed.utils import none_throws def _to_sharding_plan( @@ -181,12 +182,19 @@ def collective_plan( self, module: nn.Module, sharders: Optional[List[ModuleSharder[nn.Module]]] = None, - pg: Optional[dist.ProcessGroup] = dist.GroupMember.WORLD, + pg: Optional[dist.ProcessGroup] = None, ) -> ShardingPlan: """ Call self.plan(...) on rank 0 and broadcast """ - assert pg is not None, "Process group is not initialized" + if pg is None: + assert dist.is_initialized(), ( + "The default process group is not yet initialized. " + "Please call torch.distributed.init_process_group() first before invoking this. " + "If you are not within a distributed environment, use the single rank version plan() instead." + ) + pg = none_throws(dist.GroupMember.WORLD) + if sharders is None: sharders = get_default_sharders() return invoke_on_rank_and_broadcast_result(