diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index 9245beb8f..c66bb74a4 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -48,6 +48,7 @@ try: from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling + except Exception: def is_torchdynamo_compiling() -> bool: # type: ignore[misc] @@ -63,6 +64,7 @@ def _get_recat( stagger: int = 1, device: Optional[torch.device] = None, batch_size_per_rank: Optional[List[int]] = None, + use_tensor_compute: bool = False, ) -> Optional[torch.Tensor]: """ Calculates relevant recat indices required to reorder AlltoAll collective. @@ -88,6 +90,11 @@ _recat(0, 4, 2) # None """ + if use_tensor_compute: + return _get_recat_tensor_compute( + local_split, num_splits, stagger, device, batch_size_per_rank + ) + with record_function("## all2all_data:recat_permute_gen ##"): if local_split == 0: return None @@ -151,6 +158,77 @@ def _get_recat( return torch.tensor(recat, device=device, dtype=torch.int32) +def _get_recat_tensor_compute( + local_split: int, + num_splits: int, + stagger: int = 1, + device: Optional[torch.device] = None, + batch_size_per_rank: Optional[List[int]] = None, +) -> Optional[torch.Tensor]: + """ + The list-based _get_recat produces many scalar-compute instructions in the graph. + This tensor-compute variant yields an identical result with a smaller op count. + """ + with record_function("## all2all_data:recat_permute_gen ##"): + if local_split == 0: + return None + + X: int = num_splits // stagger + Y: int = stagger + feature_order: torch.Tensor = ( + torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y) + + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y) + ).reshape(-1) + + LS: int = local_split + FO_S0: int = feature_order.size(0) + recat: torch.Tensor = ( + torch.arange(LS, dtype=torch.int32).view(LS, 1).expand(LS, FO_S0) + + (feature_order.expand(LS, FO_S0) * LS) + ).reshape(-1) + + vb_condition = batch_size_per_rank is not None and any( + bs != batch_size_per_rank[0] for bs in batch_size_per_rank + ) + + if vb_condition: + batch_size_per_rank_tensor = torch._refs.tensor( + batch_size_per_rank, dtype=torch.int32 + ) + N: int = batch_size_per_rank_tensor.size(0) + batch_size_per_feature_tensor: torch.Tensor = ( + batch_size_per_rank_tensor.view(N, 1).expand(N, LS).reshape(-1) + ) + + permuted_batch_size_per_feature_tensor: torch.Tensor = ( + batch_size_per_feature_tensor.index_select(0, recat) + ) + + input_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum( + batch_size_per_feature_tensor + ) + output_offset: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum( + permuted_batch_size_per_feature_tensor + ) + + recat_tensor = torch.tensor( + recat, + device=device, + dtype=torch.int32, + ) + input_offset_device = input_offset.to(device=device) + output_offset_device = output_offset.to(device=device) + recat = torch.ops.fbgemm.expand_into_jagged_permute( + recat_tensor, + input_offset_device, + output_offset_device, + output_offset[-1].item(), + ) + return recat + else: + return torch.tensor(recat, device=device, dtype=torch.int32) + + class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]): """ Awaitable for splits AlltoAll.
@@ -198,7 +276,14 @@ def _wait_impl(self) -> List[List[int]]: if not is_torchdynamo_compiling(): self._splits_awaitable.wait() - return self._output_tensor.view(self.num_workers, -1).T.tolist() + ret = self._output_tensor.view(self.num_workers, -1).T.tolist() + + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + for i in range(len(ret)): + for j in range(len(ret[i])): + torch._check_is_size(ret[i][j]) + + return ret class KJTAllToAllTensorsAwaitable(Awaitable[KeyedJaggedTensor]): @@ -258,6 +343,7 @@ def __init__( stagger=stagger, device=device, batch_size_per_rank=self._stride_per_rank, + use_tensor_compute=is_torchdynamo_compiling(), ) if self._workers == 1: return diff --git a/torchrec/distributed/planner/__init__.py b/torchrec/distributed/planner/__init__.py index 90fed5c29..efd06bf02 100644 --- a/torchrec/distributed/planner/__init__.py +++ b/torchrec/distributed/planner/__init__.py @@ -21,9 +21,6 @@ - automatically building and selecting an optimized sharding plan. """ -from torchrec.distributed.planner.planners import ( # noqa - EmbeddingShardingPlanner, - HeteroEmbeddingShardingPlanner, # noqa -) +from torchrec.distributed.planner.planners import EmbeddingShardingPlanner # noqa from torchrec.distributed.planner.types import ParameterConstraints, Topology # noqa from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name # noqa diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py index 370bdebc9..20906c19f 100644 --- a/torchrec/distributed/shard.py +++ b/torchrec/distributed/shard.py @@ -15,11 +15,8 @@ from torch.distributed._composable.contract import contract from torchrec.distributed.comm import get_local_size from torchrec.distributed.model_parallel import get_default_sharders -from torchrec.distributed.planner import ( - EmbeddingShardingPlanner, - HeteroEmbeddingShardingPlanner, - Topology, -) +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner from torchrec.distributed.sharding_plan import ( get_module_to_default_sharders, ParameterShardingGenerator, diff --git a/torchrec/ir/schema.py b/torchrec/ir/schema.py new file mode 100644 index 000000000..27ee9c8e6 --- /dev/null +++ b/torchrec/ir/schema.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Optional + +from torchrec.modules.embedding_configs import DataType, PoolingType + + +# Same as EmbeddingBagConfig but serializable +@dataclass +class EmbeddingBagConfigMetadata: + num_embeddings: int + embedding_dim: int + name: str + data_type: DataType + feature_names: List[str] + weight_init_max: Optional[float] + weight_init_min: Optional[float] + need_pos: bool + pooling: PoolingType + + +@dataclass +class EBCMetadata: + tables: List[EmbeddingBagConfigMetadata] + is_weighted: bool + device: Optional[str] diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py new file mode 100644 index 000000000..514ea501e --- /dev/null +++ b/torchrec/ir/serializer.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import json +from typing import Dict, Type + +import torch + +from torch import nn +from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata + +from torchrec.ir.types import SerializerInterface +from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +def embedding_bag_config_to_metadata( + table_config: EmbeddingBagConfig, +) -> EmbeddingBagConfigMetadata: + return EmbeddingBagConfigMetadata( + num_embeddings=table_config.num_embeddings, + embedding_dim=table_config.embedding_dim, + name=table_config.name, + data_type=table_config.data_type.value, + feature_names=table_config.feature_names, + weight_init_max=table_config.weight_init_max, + weight_init_min=table_config.weight_init_min, + need_pos=table_config.need_pos, + pooling=table_config.pooling.value, + ) + + +def embedding_metadata_to_config( + table_config: EmbeddingBagConfigMetadata, +) -> EmbeddingBagConfig: + return EmbeddingBagConfig( + num_embeddings=table_config.num_embeddings, + embedding_dim=table_config.embedding_dim, + name=table_config.name, + data_type=DataType(table_config.data_type), + feature_names=table_config.feature_names, + weight_init_max=table_config.weight_init_max, + weight_init_min=table_config.weight_init_min, + need_pos=table_config.need_pos, + pooling=PoolingType(table_config.pooling), + ) + + +class EBCJsonSerializer(SerializerInterface): + """ + Serializer for torch.export IR using JSON. + """ + + @classmethod + def serialize( + cls, + module: nn.Module, + ) -> torch.Tensor: + if not isinstance(module, EmbeddingBagCollection): + raise ValueError( + f"Expected module to be of type EmbeddingBagCollection, got {type(module)}" + ) + + ebc_metadata = EBCMetadata( + tables=[ + embedding_bag_config_to_metadata(table_config) + for table_config in module.embedding_bag_configs() + ], + is_weighted=module.is_weighted(), + device=str(module.device), + ) + + ebc_metadata_dict = ebc_metadata.__dict__ + ebc_metadata_dict["tables"] = [ + table_config.__dict__ for table_config in ebc_metadata_dict["tables"] + ] + + return torch.frombuffer( + json.dumps(ebc_metadata_dict).encode(), dtype=torch.uint8 + ) + + @classmethod + def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module: + if typename != "EmbeddingBagCollection": + raise ValueError( + f"Expected typename to be EmbeddingBagCollection, got {typename}" + ) + + raw_bytes = input.numpy().tobytes() + ebc_metadata_dict = json.loads(raw_bytes.decode()) + tables = [ + EmbeddingBagConfigMetadata(**table_config) + for table_config in ebc_metadata_dict["tables"] + ] + + return EmbeddingBagCollection( + tables=[ + embedding_metadata_to_config(table_config) for table_config in tables + ], + is_weighted=ebc_metadata_dict["is_weighted"], + device=( + torch.device(ebc_metadata_dict["device"]) + if ebc_metadata_dict["device"] + else None + ), + ) + + +class JsonSerializer(SerializerInterface): + """ + Serializer for torch.export IR using JSON.
+ """ + + module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = { + "EmbeddingBagCollection": EBCJsonSerializer, + } + + @classmethod + def serialize( + cls, + module: nn.Module, + ) -> torch.Tensor: + typename = type(module).__name__ + if typename not in cls.module_to_serializer_cls: + raise ValueError( + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" + ) + + return cls.module_to_serializer_cls[typename].serialize(module) + + @classmethod + def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module: + if typename not in cls.module_to_serializer_cls: + raise ValueError( + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" + ) + + return cls.module_to_serializer_cls[typename].deserialize(input, typename) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py new file mode 100644 index 000000000..6ff922c07 --- /dev/null +++ b/torchrec/ir/tests/test_serializer.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +import unittest + +import torch +from torch import nn +from torchrec.ir.serializer import JsonSerializer + +from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules + +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +class TestJsonSerializer(unittest.TestCase): + def generate_model(self) -> nn.Module: + class Model(nn.Module): + def __init__(self, ebc): + super().__init__() + self.sparse_arch = ebc + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + return self.sparse_arch(features) + + tb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1"], + ) + tb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection( + tables=[tb1_config, tb2_config], + is_weighted=False, + ) + + model = Model(ebc) + + return model + + def test_serialize_deserialize_ebc(self) -> None: + model = self.generate_model() + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3]), + offsets=torch.tensor([0, 2, 2, 3, 4]), + ) + + eager_kt = model(id_list_features) + + # Serialize PEA + model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + # Run forward on ExportedProgram + ep_output = ep.module()(id_list_features) + + self.assertTrue(isinstance(ep_output, KeyedTensor)) + self.assertEqual(eager_kt.keys(), ep_output.keys()) + self.assertEqual(eager_kt.values().shape, ep_output.values().shape) + + # Deserialize EBC + deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + + self.assertTrue( + isinstance(deserialized_model.sparse_arch, EmbeddingBagCollection) + ) + + for deserialized_config, org_config in zip( + deserialized_model.sparse_arch.embedding_bag_configs(), + 
model.sparse_arch.embedding_bag_configs(), + ): + self.assertEqual(deserialized_config.name, org_config.name) + self.assertEqual( + deserialized_config.embedding_dim, org_config.embedding_dim + ) + self.assertEqual( + deserialized_config.num_embeddings, org_config.num_embeddings + ) + self.assertEqual( + deserialized_config.feature_names, org_config.feature_names + ) + + # Run forward on deserialized model + deserialized_kt = deserialized_model(id_list_features) + + self.assertEqual(eager_kt.keys(), deserialized_kt.keys()) + self.assertEqual(eager_kt.values().shape, deserialized_kt.values().shape) diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py index 385db678a..bc89ab03b 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -5,11 +5,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-strict - #!/usr/bin/env python3 -from typing import Type +from typing import List, Tuple, Type import torch @@ -25,19 +23,34 @@ def serialize_embedding_modules( model: nn.Module, serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, -) -> nn.Module: - for _, module in model.named_modules(): +) -> Tuple[nn.Module, List[str]]: + """ + Takes all modules whose type is registered in `serializer_cls.module_to_serializer_cls` + and serializes them in the given format, storing the result as an `ir_metadata` buffer on the module. + + Returns the modified model and the list of fqns of the modules that received the buffer. + """ + preserve_fqns = [] + for fqn, module in model.named_modules(): if type(module).__name__ in serializer_cls.module_to_serializer_cls: serialized_module = serializer_cls.serialize(module) module.register_buffer("ir_metadata", serialized_module, persistent=False) + preserve_fqns.append(fqn) - return model + return model, preserve_fqns def deserialize_embedding_modules( ep: ExportedProgram, serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, ) -> nn.Module: + """ + Takes an ExportedProgram (IR) and looks for `ir_metadata` buffers. + For each buffer found, deserializes it and replaces the corresponding module with the + deserialized module. + + Returns the unflattened ExportedProgram with the deserialized modules.
+ """ model = torch.export.unflatten(ep) module_type_dict = {} for node in ep.graph.nodes: @@ -62,6 +75,11 @@ def deserialize_embedding_modules( fqn_to_new_module[fqn] = deserialized_module for fqn, new_module in fqn_to_new_module.items(): - setattr(model, fqn, new_module) + # handle nested attribute like "x.y.z" + attrs = fqn.split(".") + parent = model + for a in attrs[:-1]: + parent = getattr(parent, a) + setattr(parent, attrs[-1], new_module) return model diff --git a/torchrec/models/tests/test_dlrm.py b/torchrec/models/tests/test_dlrm.py index 43808e404..9304c00a4 100644 --- a/torchrec/models/tests/test_dlrm.py +++ b/torchrec/models/tests/test_dlrm.py @@ -14,6 +14,8 @@ from torch.testing import FileCheck # @manual from torchrec.datasets.utils import Batch from torchrec.fx import symbolic_trace +from torchrec.ir.serializer import JsonSerializer +from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules from torchrec.models.dlrm import ( choose, DenseArch, @@ -1218,3 +1220,66 @@ def test_basic(self) -> None: ) self.assertEqual(logits.size(), (B, 1)) + + def test_export_serialization(self) -> None: + B = 2 + D = 8 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + model = DLRM( + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + features = torch.rand((B, 100)) + + # 0 1 + # 0 [1,2] [4,5] + # 1 [4,3] [2,9] + # ^ + # feature + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + offsets=torch.tensor([0, 2, 4, 6, 8]), + ) + + logits = model( + dense_features=features, + sparse_features=sparse_features, + ) + + self.assertEqual(logits.size(), (B, 1)) + + model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) + + ep = torch.export.export( + model, + (features, sparse_features), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + + # Run forward on ExportedProgram + ep_output = ep.module()(features, sparse_features) + self.assertEqual(ep_output.size(), (B, 1)) + self.assertTrue(torch.allclose(logits, ep_output)) + + deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + deserialized_logits = deserialized_model(features, sparse_features) + + self.assertEqual(deserialized_logits.size(), (B, 1)) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 3670c3030..91102238b 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -2519,17 +2519,22 @@ def __str__(self) -> str: def _kt_flatten( kt: KeyedTensor, -) -> Tuple[List[torch.Tensor], List[str]]: - return [torch.tensor(kt._length_per_key, dtype=torch.int64), kt._values], kt._keys +) -> Tuple[List[torch.Tensor], Tuple[List[str], List[int]]]: + return [kt._values], (kt._keys, kt._length_per_key) -def _kt_unflatten(values: List[torch.Tensor], context: List[str]) -> KeyedTensor: - return KeyedTensor(context, values[0].tolist(), values[1]) +def _kt_unflatten( + values: List[torch.Tensor], context: Tuple[List[str], List[int]] +) -> KeyedTensor: + return KeyedTensor(context[0], context[1], values[0]) def _kt_flatten_spec(kt: 
KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]: return _kt_flatten(kt)[0] -register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten) +# The assumption when torch.export'ing a KeyedTensor is that _length_per_key is static +register_pytree_node( + KeyedTensor, _kt_flatten, _kt_unflatten, serialized_type_name="KeyedTensor" +) register_pytree_flatten_spec(KeyedTensor, _kt_flatten_spec) diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 614ce5965..67b6d77c8 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -2289,6 +2289,23 @@ def test_string_values(self) -> None: """, ) + def test_pytree(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]), + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) + + flattened, out_spec = pytree.tree_flatten(kt) + + self.assertTrue(torch.equal(flattened[0], kt.values())) + unflattened = pytree.tree_unflatten(flattened, out_spec) + + self.assertTrue(isinstance(unflattened, KeyedTensor)) + self.assertListEqual(unflattened.keys(), keys) + self.assertListEqual(unflattened._length_per_key, kt._length_per_key) + class TestComputeKJTToJTDict(unittest.TestCase): def test_key_lookup(self) -> None:
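
As a quick sanity check on the _get_recat_tensor_compute path added in dist_data.py above: the tensor-compute ordering is intended to match the list-based _get_recat exactly in the fixed-batch-size case. The sketch below recomputes both orderings with standalone helpers and asserts they agree; recat_list and recat_tensor are illustrative names that only assume torch, not code from this patch.

import torch


def recat_list(local_split: int, num_splits: int, stagger: int = 1) -> list:
    # List-based ordering (mirrors the scalar loops in the existing _get_recat,
    # fixed-batch-size path).
    feature_order = [
        x + num_splits // stagger * y
        for x in range(num_splits // stagger)
        for y in range(stagger)
    ]
    return [i + j * local_split for i in range(local_split) for j in feature_order]


def recat_tensor(local_split: int, num_splits: int, stagger: int = 1) -> torch.Tensor:
    # Tensor-compute ordering (mirrors _get_recat_tensor_compute, fixed-batch-size
    # path): the same arithmetic expressed on int32 tensors.
    X, Y, LS = num_splits // stagger, stagger, local_split
    feature_order = (
        torch.arange(X, dtype=torch.int32).view(X, 1).expand(X, Y)
        + (X * torch.arange(Y, dtype=torch.int32)).expand(X, Y)
    ).reshape(-1)
    FO = feature_order.size(0)
    return (
        torch.arange(LS, dtype=torch.int32).view(LS, 1).expand(LS, FO)
        + feature_order.expand(LS, FO) * LS
    ).reshape(-1)


# Both paths should produce identical recat indices, e.g. local_split=2,
# num_splits=4, stagger=2 gives [0, 4, 2, 6, 1, 5, 3, 7].
assert recat_tensor(2, 4, 2).tolist() == recat_list(2, 4, 2) == [0, 4, 2, 6, 1, 5, 3, 7]
assert recat_tensor(2, 4, 1).tolist() == recat_list(2, 4, 1) == [0, 2, 4, 6, 1, 3, 5, 7]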