Support auto planner in _shard_modules (#1839)
Summary:
Pull Request resolved: #1839

As titled: when _shard_modules is called without an explicit sharding plan and env is a dict of per-device-group ShardingEnvs, build a HeteroEmbeddingShardingPlanner from those envs and generate the plan automatically, instead of requiring a manually supplied plan.

Reviewed By: henrylhtsang

Differential Revision: D55423537

fbshipit-source-id: 167c92b44bda4b103caa4ac820f35b11850f91bd
gnahzg authored and facebook-github-bot committed Apr 10, 2024
1 parent 579fe9f commit 16bb1d4
Showing 4 changed files with 127 additions and 23 deletions.
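For context, a minimal usage sketch of what this change enables, adapted from the new test added in this commit (test_cpu_gpu_sharding_shard_modules). The helper name and the world sizes are illustrative; the module passed in would be a quantized EmbeddingCollection wrapper built as in the test.

import torch

from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.types import ShardingEnv


def shard_with_auto_plan(module: torch.nn.Module) -> torch.nn.Module:
    # One ShardingEnv per device group; with more than one group and no
    # explicit plan, _shard_modules now auto-plans via
    # HeteroEmbeddingShardingPlanner instead of asserting that a manual
    # sharding plan must be provided.
    env_dict = {
        "cpu": ShardingEnv.from_local(3, 0),
        "cuda": ShardingEnv.from_local(2, 0),
    }
    return _shard_modules(
        module=module,
        env=env_dict,
        sharders=[QuantEmbeddingCollectionSharder()],
        device=torch.device("cpu"),
    )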
5 changes: 4 additions & 1 deletion torchrec/distributed/planner/__init__.py
@@ -21,6 +21,9 @@
 - automatically building and selecting an optimized sharding plan.
 """
 
-from torchrec.distributed.planner.planners import EmbeddingShardingPlanner  # noqa
+from torchrec.distributed.planner.planners import (  # noqa
+    EmbeddingShardingPlanner,
+    HeteroEmbeddingShardingPlanner,  # noqa
+)
 from torchrec.distributed.planner.types import ParameterConstraints, Topology  # noqa
 from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name  # noqa
24 changes: 18 additions & 6 deletions torchrec/distributed/planner/planners.py
@@ -453,7 +453,9 @@ def __init__(
             }
         self._topology_groups: Dict[str, Topology] = topology_groups
         self._batch_size: int = batch_size if batch_size else BATCH_SIZE
-        self._constraints = constraints
+        self._constraints: Dict[str, ParameterConstraints] = (
+            constraints if constraints else {}
+        )
         # pyre-ignore
         self._enumerators: Dict[str, Enumerator] = (
             enumerators
@@ -584,17 +586,27 @@ def plan(
                 sharders=sharders,
             )
 
+            # If no device_group is assigned in the constraints, use the current device group
+            # TODO: Create a util function to assign device_group according to empirical rules to cuda or cpu
+            for sharding_option in search_space:
+                if sharding_option.name not in self._constraints:
+                    self._constraints[sharding_option.name] = ParameterConstraints(
+                        device_group=group
+                    )
+                elif not self._constraints[sharding_option.name].device_group:
+                    self._constraints[sharding_option.name].device_group = group
+
             # filter by device group
             search_space = [
-                s_o
-                for s_o in search_space
-                # pyre-ignore [16]
-                if self._constraints[s_o.name].device_group == group
+                sharding_option
+                for sharding_option in search_space
+                if self._constraints[sharding_option.name].device_group == group
             ]
 
             if not search_space:
                 # No shardable parameters
-                return ShardingPlan({})
+                best_plans.append(ShardingPlan({}))
+                continue
 
             proposal_cache: Dict[
                 Tuple[int, ...],
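The hunk above is what makes the auto path work when callers do not pre-assign every table to a device group: missing or empty device_group constraints are filled in with the group currently being planned, and only then is the search space filtered. It also stops returning an empty ShardingPlan for the whole model when one device group has nothing to shard; that group now just contributes an empty per-group plan. A small sketch of the resulting behavior (table names, world sizes, and the commented-out plan call are illustrative, mirroring the tests in this commit):

from torchrec.distributed.planner import ParameterConstraints
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.planner.types import Topology

planner = HeteroEmbeddingShardingPlanner(
    topology_groups={
        "cpu": Topology(world_size=3, compute_device="cpu"),
        "cuda": Topology(world_size=2, compute_device="cuda"),
    },
    constraints={
        # Only table_0 is pinned explicitly; previously every sharding option
        # was expected to already have a constraint with device_group set.
        "table_0": ParameterConstraints(device_group="cpu"),
    },
)
# plan = planner.plan(module, sharders)  # module and sharders as in the tests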
46 changes: 31 additions & 15 deletions torchrec/distributed/shard.py
@@ -7,15 +7,19 @@
 
 # pyre-strict
 
-from typing import Callable, Dict, List, Optional, Type, Union
+from typing import Callable, cast, Dict, List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
 from torch import nn
 from torch.distributed._composable.contract import contract
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.model_parallel import get_default_sharders
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.planner import (
+    EmbeddingShardingPlanner,
+    HeteroEmbeddingShardingPlanner,
+    Topology,
+)
 from torchrec.distributed.sharding_plan import (
     get_module_to_default_sharders,
     ParameterShardingGenerator,
@@ -223,21 +227,33 @@ def _shard_modules(  # noqa: C901
     }
 
     if plan is None:
-        assert isinstance(
-            env, ShardingEnv
-        ), "Currently hybrid sharding only support use manual sharding plan"
-        planner = EmbeddingShardingPlanner(
-            topology=Topology(
-                local_world_size=get_local_size(env.world_size),
-                world_size=env.world_size,
-                compute_device=device.type,
+        if isinstance(env, dict) and len(env) > 1:  # use heterogenous sharding
+            planner = HeteroEmbeddingShardingPlanner(
+                topology_groups={
+                    group: Topology(
+                        local_world_size=get_local_size(cur_env.world_size),
+                        world_size=cur_env.world_size,
+                        compute_device=group,
+                    )
+                    for group, cur_env in env.items()
+                }
             )
-        )
-        pg = env.process_group
-        if pg is not None:
-            plan = planner.collective_plan(module, sharders, pg)
-        else:
+            # For heterogenous sharding, only generate the plan, no broadcast
             plan = planner.plan(module, sharders)
+        else:
+            env = cast(ShardingEnv, env)
+            planner = EmbeddingShardingPlanner(
+                topology=Topology(
+                    local_world_size=get_local_size(env.world_size),
+                    world_size=env.world_size,
+                    compute_device=device.type,
+                )
+            )
+            pg = env.process_group
+            if pg is not None:
+                plan = planner.collective_plan(module, sharders, pg)
+            else:
+                plan = planner.plan(module, sharders)
 
     if type(module) in sharder_map:
         # If the top level module is itself a shardable module, return the sharded variant.
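The dispatch in this hunk can be summarized as follows; the helper below is illustrative (not part of the commit) and mirrors the branch condition used when no plan is supplied. The heterogeneous branch calls planner.plan() rather than collective_plan(), i.e. the plan is generated locally with no broadcast, as noted in the code comment.

from typing import Dict, Union

from torchrec.distributed.types import ShardingEnv


def uses_hetero_planner(env: Union[ShardingEnv, Dict[str, ShardingEnv]]) -> bool:
    # Mirrors the check in _shard_modules: isinstance(env, dict) and len(env) > 1.
    return isinstance(env, dict) and len(env) > 1


assert uses_hetero_planner(
    {"cpu": ShardingEnv.from_local(3, 0), "cuda": ShardingEnv.from_local(2, 0)}
)
assert not uses_hetero_planner({"cpu": ShardingEnv.from_local(3, 0)})
assert not uses_hetero_planner(ShardingEnv.from_local(2, 0))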
75 changes: 74 additions & 1 deletion torchrec/distributed/tests/test_infer_hetero_shardings.py
@@ -16,7 +16,10 @@
 from torchrec.distributed.planner import ParameterConstraints
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.planner.types import Topology
-from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
+from torchrec.distributed.quant_embedding import (
+    QuantEmbeddingCollectionSharder,
+    ShardedQuantEmbeddingCollection,
+)
 from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
 from torchrec.distributed.shard import _shard_modules
 from torchrec.distributed.sharding_plan import (
@@ -196,6 +199,11 @@ def test_sharder_different_world_sizes_for_qebc(self) -> None:
             == env.world_size
         )
 
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 3,
+        "Not enough GPUs available",
+    )
     def test_cpu_gpu_sharding_autoplanner(self) -> None:
         num_embeddings = 10
         emb_dim = 16
@@ -266,3 +274,68 @@ def test_cpu_gpu_sharding_autoplanner(self) -> None:
             .type,
             "cuda",
         )
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 3,
+        "Not enough GPUs available",
+    )
+    def test_cpu_gpu_sharding_shard_modules(self) -> None:
+        num_embeddings = 10
+        emb_dim = 16
+        tables = [
+            EmbeddingConfig(
+                num_embeddings=num_embeddings,
+                embedding_dim=emb_dim,
+                name=f"table_{i}",
+                feature_names=[f"feature_{i}"],
+            )
+            for i in range(3)
+        ]
+        model = KJTInputWrapper(
+            module_kjt_input=torch.nn.Sequential(
+                EmbeddingCollection(
+                    tables=tables,
+                    device=torch.device("cpu"),
+                )
+            )
+        )
+        non_sharded_model = quantize(
+            model,
+            inplace=False,
+            quant_state_dict_split_scale_bias=True,
+            weight_dtype=torch.qint8,
+        )
+        sharder = QuantEmbeddingCollectionSharder()
+        env_dict = {
+            "cpu": ShardingEnv.from_local(
+                3,
+                0,
+            ),
+            "cuda": ShardingEnv.from_local(
+                2,
+                0,
+            ),
+        }
+
+        shard_model = _shard_modules(
+            module=non_sharded_model,
+            env=env_dict,
+            # pyre-ignore
+            sharders=[sharder],
+            device=torch.device("cpu"),
+        )
+
+        self.assertTrue(
+            isinstance(
+                shard_model._module_kjt_input[0], ShardedQuantEmbeddingCollection
+            )
+        )
+
+        self.assertEqual(len(shard_model._module_kjt_input[0]._lookups), 1)
+        self.assertEqual(
+            len(
+                shard_model._module_kjt_input[0]._lookups[0]._embedding_lookups_per_rank
+            ),
+            env_dict["cpu"].world_size,
+        )
