2024-04-20 nightly release (303e852)
pytorchbot committed Apr 20, 2024
1 parent 1490666 commit 9b4853b
Showing 10 changed files with 493 additions and 22 deletions.
88 changes: 87 additions & 1 deletion torchrec/distributed/dist_data.py
@@ -48,6 +48,7 @@

try:
    from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling

except Exception:

    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
@@ -63,6 +64,7 @@ def _get_recat(
    stagger: int = 1,
    device: Optional[torch.device] = None,
    batch_size_per_rank: Optional[List[int]] = None,
    use_tensor_compute: bool = False,
) -> Optional[torch.Tensor]:
    """
    Calculates relevant recat indices required to reorder AlltoAll collective.
@@ -88,6 +90,11 @@
        _recat(0, 4, 2)
        # None
    """
    if use_tensor_compute:
        return _get_recat_tensor_compute(
            local_split, num_splits, stagger, device, batch_size_per_rank
        )

    with record_function("## all2all_data:recat_permute_gen ##"):
        if local_split == 0:
            return None
@@ -151,6 +158,77 @@ def _get_recat(
        return torch.tensor(recat, device=device, dtype=torch.int32)


def _get_recat_tensor_compute(
    local_split: int,
    num_splits: int,
    stagger: int = 1,
    device: Optional[torch.device] = None,
    batch_size_per_rank: Optional[List[int]] = None,
) -> Optional[torch.Tensor]:
    """
    The list-based _get_recat emits many scalar-compute instructions into the graph.
    This tensor-compute variant produces an identical result with a smaller op count.
    """
    with record_function("## all2all_data:recat_permute_gen ##"):
        if local_split == 0:
            return None

        X: int = num_splits // stagger
        Y: int = stagger
        feature_order: torch.Tensor = (
            torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y)
            + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y)
        ).reshape(-1)

        LS: int = local_split
        FO_S0: int = feature_order.size(0)
        recat: torch.Tensor = (
            torch.arange(LS, dtype=torch.int32).view(LS, 1).expand(LS, FO_S0)
            + (feature_order.expand(LS, FO_S0) * LS)
        ).reshape(-1)

        vb_condition = batch_size_per_rank is not None and any(
            bs != batch_size_per_rank[0] for bs in batch_size_per_rank
        )

        if vb_condition:
            batch_size_per_rank_tensor = torch._refs.tensor(
                batch_size_per_rank, dtype=torch.int32
            )
            N: int = batch_size_per_rank_tensor.size(0)
            batch_size_per_feature_tensor: torch.Tensor = (
                batch_size_per_rank_tensor.view(N, 1).expand(N, LS).reshape(-1)
            )

            permuted_batch_size_per_feature_tensor: torch.Tensor = (
                batch_size_per_feature_tensor.index_select(0, recat)
            )

            input_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
                batch_size_per_feature_tensor
            )
            output_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
                permuted_batch_size_per_feature_tensor
            )

            recat_tensor = torch.tensor(
                recat,
                device=device,
                dtype=torch.int32,
            )
            input_offset_device = input_offset.to(device=device)
            output_offset_device = output_offset.to(device=device)
            recat = torch.ops.fbgemm.expand_into_jagged_permute(
                recat_tensor,
                input_offset_device,
                output_offset_device,
                output_offset[-1].item(),
            )
            return recat
        else:
            return torch.tensor(recat, device=device, dtype=torch.int32)


class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]):
"""
Awaitable for splits AlltoAll.
Expand Down Expand Up @@ -198,7 +276,14 @@ def _wait_impl(self) -> List[List[int]]:
        if not is_torchdynamo_compiling():
            self._splits_awaitable.wait()

        return self._output_tensor.view(self.num_workers, -1).T.tolist()
        ret = self._output_tensor.view(self.num_workers, -1).T.tolist()

        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
            for i in range(len(ret)):
                for j in range(len(ret[i])):
                    torch._check_is_size(ret[i][j])

        return ret


class KJTAllToAllTensorsAwaitable(Awaitable[KeyedJaggedTensor]):
@@ -258,6 +343,7 @@ def __init__(
            stagger=stagger,
            device=device,
            batch_size_per_rank=self._stride_per_rank,
            use_tensor_compute=is_torchdynamo_compiling(),
        )
        if self._workers == 1:
            return
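The _get_recat_tensor_compute function added above builds the recat permutation from a handful of tensor ops instead of Python-level loops. As a minimal sketch (not part of this commit), the same arithmetic can be evaluated for assumed example values local_split=2, num_splits=4, stagger=2 to make the resulting permutation concrete:

import torch

# Assumed example values, chosen only for illustration.
local_split, num_splits, stagger = 2, 4, 2
X, Y = num_splits // stagger, stagger

# Staggered feature order; evaluates to [0, 2, 1, 3] for X=2, Y=2.
feature_order = (
    torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y)
    + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y)
).reshape(-1)

# Recat indices interleave the local splits across the staggered feature order.
recat = (
    torch.arange(local_split, dtype=torch.int32)
    .view(local_split, 1)
    .expand(local_split, feature_order.size(0))
    + feature_order.expand(local_split, feature_order.size(0)) * local_split
).reshape(-1)

print(feature_order.tolist())  # [0, 2, 1, 3]
print(recat.tolist())          # [0, 4, 2, 6, 1, 5, 3, 7]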
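The torch._check_is_size calls added to SplitsAllToAllAwaitable._wait_impl hint the compiler that each int coming out of .tolist() is a valid, non-negative size, so later size-dependent ops do not hit data-dependent guard failures under torch.compile. A hedged sketch of the same pattern, using a hypothetical split_by_lengths helper that is not part of this commit:

import torch

def split_by_lengths(values: torch.Tensor, lengths: torch.Tensor):
    # Data-dependent Python ints; unbacked symbolic ints under torch.compile.
    sizes = lengths.tolist()
    if not torch.jit.is_scripting() and torch.compiler.is_dynamo_compiling():
        for s in sizes:
            # Tell the compiler each value can be treated as a (>= 0) size.
            torch._check_is_size(s)
    return torch.split(values, sizes)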
5 changes: 1 addition & 4 deletions torchrec/distributed/planner/__init__.py
@@ -21,9 +21,6 @@
    - automatically building and selecting an optimized sharding plan.
"""

from torchrec.distributed.planner.planners import (  # noqa
    EmbeddingShardingPlanner,
    HeteroEmbeddingShardingPlanner,  # noqa
)
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner # noqa
from torchrec.distributed.planner.types import ParameterConstraints, Topology # noqa
from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name # noqa
7 changes: 2 additions & 5 deletions torchrec/distributed/shard.py
@@ -15,11 +15,8 @@
from torch.distributed._composable.contract import contract
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_parallel import get_default_sharders
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    HeteroEmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.sharding_plan import (
    get_module_to_default_sharders,
    ParameterShardingGenerator,
32 changes: 32 additions & 0 deletions torchrec/ir/schema.py
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional

from torchrec.modules.embedding_configs import DataType, PoolingType


# Same as EmbeddingBagConfig but serializable
@dataclass
class EmbeddingBagConfigMetadata:
    num_embeddings: int
    embedding_dim: int
    name: str
    data_type: DataType
    feature_names: List[str]
    weight_init_max: Optional[float]
    weight_init_min: Optional[float]
    need_pos: bool
    pooling: PoolingType


@dataclass
class EBCMetadata:
    tables: List[EmbeddingBagConfigMetadata]
    is_weighted: bool
    device: Optional[str]
142 changes: 142 additions & 0 deletions torchrec/ir/serializer.py
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from typing import Dict, Type

import torch

from torch import nn
from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata

from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection


def embedding_bag_config_to_metadata(
    table_config: EmbeddingBagConfig,
) -> EmbeddingBagConfigMetadata:
    return EmbeddingBagConfigMetadata(
        num_embeddings=table_config.num_embeddings,
        embedding_dim=table_config.embedding_dim,
        name=table_config.name,
        data_type=table_config.data_type.value,
        feature_names=table_config.feature_names,
        weight_init_max=table_config.weight_init_max,
        weight_init_min=table_config.weight_init_min,
        need_pos=table_config.need_pos,
        pooling=table_config.pooling.value,
    )


def embedding_metadata_to_config(
    table_config: EmbeddingBagConfigMetadata,
) -> EmbeddingBagConfig:
    return EmbeddingBagConfig(
        num_embeddings=table_config.num_embeddings,
        embedding_dim=table_config.embedding_dim,
        name=table_config.name,
        data_type=DataType(table_config.data_type),
        feature_names=table_config.feature_names,
        weight_init_max=table_config.weight_init_max,
        weight_init_min=table_config.weight_init_min,
        need_pos=table_config.need_pos,
        pooling=PoolingType(table_config.pooling),
    )


class EBCJsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using thrift.
"""

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        if not isinstance(module, EmbeddingBagCollection):
            raise ValueError(
                f"Expected module to be of type EmbeddingBagCollection, got {type(module)}"
            )

        ebc_metadata = EBCMetadata(
            tables=[
                embedding_bag_config_to_metadata(table_config)
                for table_config in module.embedding_bag_configs()
            ],
            is_weighted=module.is_weighted(),
            device=str(module.device),
        )

        ebc_metadata_dict = ebc_metadata.__dict__
        ebc_metadata_dict["tables"] = [
            table_config.__dict__ for table_config in ebc_metadata_dict["tables"]
        ]

        return torch.frombuffer(
            json.dumps(ebc_metadata_dict).encode(), dtype=torch.uint8
        )

    @classmethod
    def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module:
        if typename != "EmbeddingBagCollection":
            raise ValueError(
                f"Expected typename to be EmbeddingBagCollection, got {typename}"
            )

        raw_bytes = input.numpy().tobytes()
        ebc_metadata_dict = json.loads(raw_bytes.decode())
        tables = [
            EmbeddingBagConfigMetadata(**table_config)
            for table_config in ebc_metadata_dict["tables"]
        ]

        return EmbeddingBagCollection(
            tables=[
                embedding_metadata_to_config(table_config) for table_config in tables
            ],
            is_weighted=ebc_metadata_dict["is_weighted"],
            device=(
                torch.device(ebc_metadata_dict["device"])
                if ebc_metadata_dict["device"]
                else None
            ),
        )


class JsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using thrift.
"""

module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
"EmbeddingBagCollection": EBCJsonSerializer,
}

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        typename = type(module).__name__
        if typename not in cls.module_to_serializer_cls:
            raise ValueError(
                f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
            )

        return cls.module_to_serializer_cls[typename].serialize(module)

    @classmethod
    def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module:
        if typename not in cls.module_to_serializer_cls:
            raise ValueError(
                f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
            )

        return cls.module_to_serializer_cls[typename].deserialize(input, typename)
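Taken together, JsonSerializer and EBCJsonSerializer let an EmbeddingBagCollection's configuration round-trip through a uint8 tensor of JSON bytes. A hedged usage sketch, where the table parameters are assumed example values rather than anything from this commit:

import torch
from torchrec.ir.serializer import JsonSerializer
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Assumed example table; any EmbeddingBagConfig works here.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=16,
            name="t1",
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),
)

buffer = JsonSerializer.serialize(ebc)  # uint8 tensor holding JSON metadata
restored = JsonSerializer.deserialize(buffer, "EmbeddingBagCollection")
assert restored.embedding_bag_configs()[0].name == "t1"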
