Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add size and stride for empty shard DT #2662

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_sharding import (
EmbeddingSharding,
EmbeddingShardingContext,
Expand Down Expand Up @@ -73,6 +74,7 @@
add_params_from_parameter_sharding,
append_prefix,
convert_to_fbgemm_types,
create_global_tensor_shape_stride_from_metadata,
maybe_annotate_embedding_event,
merge_fused_params,
none_throws,
Expand Down Expand Up @@ -918,6 +920,14 @@ def _initialize_torch_state(self) -> None: # noqa
)
)
else:
shape, stride = create_global_tensor_shape_stride_from_metadata(
none_throws(self.module_sharding_plan[table_name]),
(
self._env.node_group_size
if isinstance(self._env, ShardingEnv2D)
else get_local_size(self._env.world_size)
),
)
# empty shard case
self._model_parallel_name_to_dtensor[table_name] = (
DTensor.from_local(
Expand All @@ -927,6 +937,8 @@ def _initialize_torch_state(self) -> None: # noqa
),
device_mesh=self._env.device_mesh,
run_check=False,
shape=shape,
stride=stride,
)
)
else:
Expand Down
43 changes: 42 additions & 1 deletion torchrec/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections import OrderedDict
from contextlib import AbstractContextManager, nullcontext
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
Expand Down Expand Up @@ -511,3 +511,44 @@ def interaction(self, *args, **kwargs) -> None:
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin


def create_global_tensor_shape_stride_from_metadata(
parameter_sharding: ParameterSharding, devices_per_node: Optional[int] = None
) -> Tuple[torch.Size, Tuple[int, int]]:
"""
Create a global tensor shape and stride from shard metadata.

Returns:
torch.Size: global tensor shape.
tuple: global tensor stride.
"""
size = None
if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value:
row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0] # pyre-ignore[16]
col_dim = 0
for shard in parameter_sharding.sharding_spec.shards:
col_dim += shard.shard_sizes[1]
size = torch.Size([row_dim, col_dim])
elif (
parameter_sharding.sharding_type == ShardingType.ROW_WISE.value
or parameter_sharding.sharding_type == ShardingType.TABLE_ROW_WISE.value
):
row_dim = 0
col_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
for shard in parameter_sharding.sharding_spec.shards:
row_dim += shard.shard_sizes[0]
size = torch.Size([row_dim, col_dim])
elif parameter_sharding.sharding_type == ShardingType.TABLE_WISE.value:
size = torch.Size(parameter_sharding.sharding_spec.shards[0].shard_sizes)
elif parameter_sharding.sharding_type == ShardingType.GRID_SHARD.value:
# we need node group size to appropriately calculate global shape from shard
assert devices_per_node is not None
row_dim, col_dim = 0, 0
num_cw_shards = len(parameter_sharding.sharding_spec.shards) // devices_per_node
for _ in range(num_cw_shards):
col_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
for _ in range(devices_per_node):
row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0]
size = torch.Size([row_dim, col_dim])
return size, (size[1], 1) if size else (torch.Size([0, 0]), (0, 1)) # pyre-ignore[7]
Loading