# Patch 76cdb5c4383ec93a069dd57c2a4f5c191f48f7dd
# From: Paul Zhang
# Date: Fri, 19 Apr 2024 09:44:12 -0700
# Subject: [PATCH] OSS serialization with dataclasses (#1860)
#
# Summary: Support unsharded TorchRec module serialization/deserialization in
# OSS with JSON serializer for EBC
#
# Differential Revision: D55901896
#
# New files introduced by this patch:
#   torchrec/ir/schema.py
#   torchrec/ir/serializer.py
#   torchrec/ir/tests/test_serializer.py

# ---------------------------------------------------------------------------
# torchrec/ir/schema.py (new file)
# ---------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional

# DataType / PoolingType are no longer used in annotations below; the import
# is retained so the set of names importable from this module is unchanged.
from torchrec.modules.embedding_configs import DataType, PoolingType  # noqa: F401


@dataclass
class EmbeddingBagConfigMetadata:
    """Serializable counterpart of ``EmbeddingBagConfig``.

    Every field holds a plain value so the instance can be dumped with
    ``json``. The two enum-typed config fields are stored by value, not as
    enum members: the serializer writes ``table_config.data_type.value`` and
    ``table_config.pooling.value`` and rebuilds the enums on load with
    ``DataType(...)`` / ``PoolingType(...)``.
    """

    num_embeddings: int
    embedding_dim: int
    name: str
    # ``DataType(...).value`` of the source config, not the enum member.
    data_type: str
    feature_names: List[str]
    weight_init_max: Optional[float]
    weight_init_min: Optional[float]
    need_pos: bool
    # ``PoolingType(...).value`` of the source config, not the enum member.
    pooling: str


@dataclass
class EBCMetadata:
    """Serializable description of an ``EmbeddingBagCollection``."""

    # One entry per embedding table of the collection.
    tables: List[EmbeddingBagConfigMetadata]
    is_weighted: bool
    # ``str(module.device)`` at serialization time; a falsy value is turned
    # into ``device=None`` when the collection is rebuilt.
    device: Optional[str]


# ---------------------------------------------------------------------------
# torchrec/ir/serializer.py (new file)
# ---------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from dataclasses import asdict
from typing import Dict, Type

import torch

from torch import nn
from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata

from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection


def embedding_bag_config_to_metadata(
    table_config: EmbeddingBagConfig,
) -> EmbeddingBagConfigMetadata:
    """Convert an ``EmbeddingBagConfig`` into its serializable metadata form.

    The enum-typed fields (``data_type``, ``pooling``) are stored as their
    ``.value`` so the result can be passed to ``json.dumps``; all other
    fields are copied through unchanged.
    """
    return EmbeddingBagConfigMetadata(
        num_embeddings=table_config.num_embeddings,
        embedding_dim=table_config.embedding_dim,
        name=table_config.name,
        data_type=table_config.data_type.value,
        feature_names=table_config.feature_names,
        weight_init_max=table_config.weight_init_max,
        weight_init_min=table_config.weight_init_min,
        need_pos=table_config.need_pos,
        pooling=table_config.pooling.value,
    )


def embedding_metadata_to_config(
    table_config: EmbeddingBagConfigMetadata,
) -> EmbeddingBagConfig:
    """Rebuild an ``EmbeddingBagConfig`` from its serializable metadata form.

    Inverse of ``embedding_bag_config_to_metadata``: the stored enum values
    are turned back into ``DataType`` / ``PoolingType`` members.
    """
    return EmbeddingBagConfig(
        num_embeddings=table_config.num_embeddings,
        embedding_dim=table_config.embedding_dim,
        name=table_config.name,
        data_type=DataType(table_config.data_type),
        feature_names=table_config.feature_names,
        weight_init_max=table_config.weight_init_max,
        weight_init_min=table_config.weight_init_min,
        need_pos=table_config.need_pos,
        pooling=PoolingType(table_config.pooling),
    )


class EBCJsonSerializer(SerializerInterface):
    """
    JSON serializer for ``EmbeddingBagCollection`` modules.

    The module's configuration is encoded as UTF-8 JSON and carried in a
    1-D ``torch.uint8`` tensor.
    """

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        """Encode the configuration of ``module`` as a ``uint8`` tensor.

        Args:
            module: the ``EmbeddingBagCollection`` to describe.

        Returns:
            A 1-D ``torch.uint8`` tensor holding the UTF-8 JSON document.

        Raises:
            ValueError: if ``module`` is not an ``EmbeddingBagCollection``.
        """
        if not isinstance(module, EmbeddingBagCollection):
            raise ValueError(
                f"Expected module to be of type EmbeddingBagCollection, got {type(module)}"
            )

        ebc_metadata = EBCMetadata(
            tables=[
                embedding_bag_config_to_metadata(table_config)
                for table_config in module.embedding_bag_configs()
            ],
            is_weighted=module.is_weighted(),
            device=str(module.device),
        )

        # asdict() recursively converts the nested table dataclasses into
        # fresh dicts, so the metadata object itself is never mutated.
        payload = json.dumps(asdict(ebc_metadata)).encode()

        # torch.frombuffer shares memory with the buffer it is given; a
        # bytearray is writable, which avoids the non-writable-buffer warning
        # that an immutable bytes object triggers.
        return torch.frombuffer(bytearray(payload), dtype=torch.uint8)

    @classmethod
    def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module:
        """Rebuild an ``EmbeddingBagCollection`` from a serialized tensor.

        Args:
            input: tensor produced by ``serialize``.
            typename: must be ``"EmbeddingBagCollection"``.

        Returns:
            A new ``EmbeddingBagCollection`` built from the stored table
            configs, ``is_weighted`` flag and device.

        Raises:
            ValueError: if ``typename`` is not ``"EmbeddingBagCollection"``.
        """
        if typename != "EmbeddingBagCollection":
            raise ValueError(
                f"Expected typename to be EmbeddingBagCollection, got {typename}"
            )

        # .numpy() only works on CPU tensors; .cpu() is a no-op when the
        # tensor is already there.
        raw_bytes = input.cpu().numpy().tobytes()
        ebc_metadata_dict = json.loads(raw_bytes.decode())

        device_name = ebc_metadata_dict["device"]
        return EmbeddingBagCollection(
            tables=[
                embedding_metadata_to_config(EmbeddingBagConfigMetadata(**raw_table))
                for raw_table in ebc_metadata_dict["tables"]
            ],
            is_weighted=ebc_metadata_dict["is_weighted"],
            device=torch.device(device_name) if device_name else None,
        )


class JsonSerializer(SerializerInterface):
    """
    JSON serializer that dispatches on the module's class name.

    ``module_to_serializer_cls`` maps a module type name to the serializer
    class that handles it.
    """

    module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
        "EmbeddingBagCollection": EBCJsonSerializer,
    }

    @classmethod
    def _serializer_for(cls, typename: str) -> Type[SerializerInterface]:
        """Return the serializer registered for ``typename`` or raise ValueError."""
        if typename not in cls.module_to_serializer_cls:
            raise ValueError(
                f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
            )
        return cls.module_to_serializer_cls[typename]

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        """Serialize ``module`` with the serializer registered for its type name.

        Raises:
            ValueError: if ``type(module).__name__`` has no registered serializer.
        """
        typename = type(module).__name__
        return cls._serializer_for(typename).serialize(module)

    @classmethod
    def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module:
        """Deserialize ``input`` with the serializer registered for ``typename``.

        Raises:
            ValueError: if ``typename`` has no registered serializer.
        """
        return cls._serializer_for(typename).deserialize(input, typename)


# ---------------------------------------------------------------------------
# torchrec/ir/tests/test_serializer.py (new file)
# ---------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch import nn
from torchrec.ir.serializer import JsonSerializer

from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


class TestJsonSerializer(unittest.TestCase):
    """Round-trip test: serialize an EBC, export the model, deserialize it."""

    def generate_model(self) -> nn.Module:
        """Build a small model whose only submodule is a two-table EBC."""

        class Model(nn.Module):
            def __init__(self, ebc: EmbeddingBagCollection) -> None:
                super().__init__()
                self.sparse_arch = ebc

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> KeyedTensor:
                return self.sparse_arch(features)

        tb1_config = EmbeddingBagConfig(
            name="t1",
            embedding_dim=3,
            num_embeddings=10,
            feature_names=["f1"],
        )
        tb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f2"],
        )

        ebc = EmbeddingBagCollection(
            tables=[tb1_config, tb2_config],
            is_weighted=False,
        )

        return Model(ebc)

    def test_serialize_deserialize_ebc(self) -> None:
        model = self.generate_model()
        id_list_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f2"],
            values=torch.tensor([0, 1, 2, 3]),
            offsets=torch.tensor([0, 2, 2, 3, 4]),
        )

        eager_kt = model(id_list_features)

        # Serialize the EBC (the model's only embedding module) before export.
        serialize_embedding_modules(model, JsonSerializer)
        ep = torch.export.export(
            model,
            (id_list_features,),
            {},
            strict=False,
            # Allows KJT to not be unflattened and run a forward on unflattened EP
            preserve_module_call_signature=("sparse_arch",),
        )

        # Run forward on ExportedProgram
        ep_output = ep.module()(id_list_features)

        self.assertIsInstance(ep_output, KeyedTensor)
        self.assertEqual(eager_kt.keys(), ep_output.keys())
        self.assertEqual(eager_kt.values().shape, ep_output.values().shape)

        # Deserialize EBC
        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

        self.assertIsInstance(deserialized_model.sparse_arch, EmbeddingBagCollection)

        deserialized_configs = deserialized_model.sparse_arch.embedding_bag_configs()
        org_configs = model.sparse_arch.embedding_bag_configs()

        # zip() stops at the shorter sequence, so a dropped table would go
        # unnoticed without an explicit length check.
        self.assertEqual(len(deserialized_configs), len(org_configs))

        for deserialized_config, org_config in zip(deserialized_configs, org_configs):
            self.assertEqual(deserialized_config.name, org_config.name)
            self.assertEqual(
                deserialized_config.embedding_dim, org_config.embedding_dim
            )
            self.assertEqual(
                deserialized_config.num_embeddings, org_config.num_embeddings
            )
            self.assertEqual(
                deserialized_config.feature_names, org_config.feature_names
            )
            # These two fields go through an enum -> value -> enum conversion
            # in the serializer, so check that they survive the round trip.
            self.assertEqual(deserialized_config.data_type, org_config.data_type)
            self.assertEqual(deserialized_config.pooling, org_config.pooling)

        # Run forward on deserialized model
        deserialized_kt = deserialized_model(id_list_features)

        self.assertEqual(eager_kt.keys(), deserialized_kt.keys())
        self.assertEqual(eager_kt.values().shape, deserialized_kt.values().shape)