make collective_plan() follow c10d convention (#1842)
Summary:
Pull Request resolved: #1842

When `pg` is None, `collective_plan()` should fall back to the default pg rather than throwing.
This matches the generally accepted behavior in OSS c10d.
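
The fallback described above can be sketched without torch. This is a minimal illustration under stated assumptions, not the torchrec implementation: `FakeProcessGroup`, `init_default_group`, `is_initialized`, and `resolve_pg` are hypothetical stand-ins, and the local `none_throws` only imitates the role of the helper imported in this commit.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


class FakeProcessGroup:
    """Hypothetical stand-in for a process group (not a torch type)."""

    def __init__(self, name: str) -> None:
        self.name = name


# Hypothetical module-level state playing the role of the default group.
_default_pg: Optional[FakeProcessGroup] = None


def init_default_group() -> None:
    """Stand-in for initializing the default process group."""
    global _default_pg
    _default_pg = FakeProcessGroup("world")


def is_initialized() -> bool:
    """Report whether the default group has been set up."""
    return _default_pg is not None


def none_throws(value: Optional[T], message: str = "Unexpected None") -> T:
    """Narrow Optional[T] to T, raising if the value is None."""
    if value is None:
        raise AssertionError(message)
    return value


def resolve_pg(pg: Optional[FakeProcessGroup] = None) -> FakeProcessGroup:
    """Return the caller's group, or fall back to the default one."""
    if pg is None:
        # Fail with an actionable message instead of a bare None check.
        if not is_initialized():
            raise RuntimeError(
                "The default process group is not yet initialized."
            )
        pg = none_throws(_default_pg)
    return pg
```

With this shape, callers that pass nothing get the default group once it exists, callers that pass an explicit group keep it, and only the uninitialized case fails.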

Reviewed By: colin2328

Differential Revision: D55613401

fbshipit-source-id: ec1502c468d8e52d2d56e0472ef4bde558e171d8
xunnanxu authored and facebook-github-bot committed Apr 3, 2024
1 parent 2bd3703 commit c523c19
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchrec/distributed/planner/planners.py
@@ -61,6 +61,7 @@
ShardingType,
ShardMetadata,
)
+ from torchrec.distributed.utils import none_throws


def _to_sharding_plan(
@@ -181,12 +182,19 @@ def collective_plan(
self,
module: nn.Module,
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
- pg: Optional[dist.ProcessGroup] = dist.GroupMember.WORLD,
+ pg: Optional[dist.ProcessGroup] = None,
) -> ShardingPlan:
"""
Call self.plan(...) on rank 0 and broadcast
"""
- assert pg is not None, "Process group is not initialized"
+ if pg is None:
+     assert dist.is_initialized(), (
+         "The default process group is not yet initialized. "
+         "Please call torch.distributed.init_process_group() first before invoking this. "
+         "If you are not within a distributed environment, use the single rank version plan() instead."
+     )
+     pg = none_throws(dist.GroupMember.WORLD)
+
if sharders is None:
sharders = get_default_sharders()
return invoke_on_rank_and_broadcast_result(