From bd26adc846832b670f3594c4f0d365270de1e6ce Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Fri, 19 Apr 2024 13:03:49 -0700 Subject: [PATCH] fix torch package issue caused by heterogenous planner (#1905) Summary: Fix torch package issue caused context S410933 Reviewed By: IvanKobzarev Differential Revision: D56360099 --- torchrec/distributed/planner/__init__.py | 5 +---- torchrec/distributed/shard.py | 7 ++----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/torchrec/distributed/planner/__init__.py b/torchrec/distributed/planner/__init__.py index 90fed5c29..efd06bf02 100644 --- a/torchrec/distributed/planner/__init__.py +++ b/torchrec/distributed/planner/__init__.py @@ -21,9 +21,6 @@ - automatically building and selecting an optimized sharding plan. """ -from torchrec.distributed.planner.planners import ( # noqa - EmbeddingShardingPlanner, - HeteroEmbeddingShardingPlanner, # noqa -) +from torchrec.distributed.planner.planners import EmbeddingShardingPlanner # noqa from torchrec.distributed.planner.types import ParameterConstraints, Topology # noqa from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name # noqa diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py index 370bdebc9..20906c19f 100644 --- a/torchrec/distributed/shard.py +++ b/torchrec/distributed/shard.py @@ -15,11 +15,8 @@ from torch.distributed._composable.contract import contract from torchrec.distributed.comm import get_local_size from torchrec.distributed.model_parallel import get_default_sharders -from torchrec.distributed.planner import ( - EmbeddingShardingPlanner, - HeteroEmbeddingShardingPlanner, - Topology, -) +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner from torchrec.distributed.sharding_plan import ( get_module_to_default_sharders, ParameterShardingGenerator,