Skip to content

Commit

Permalink
OSS serialization with dataclasses (#1860)
Browse files Browse the repository at this point in the history
Summary:

Support unsharded TorchRec module serialization/deserialization in OSS with JSON serializer for EBC

Differential Revision: D55901896
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Apr 19, 2024
1 parent 8fbb128 commit 76cdb5c
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 0 deletions.
32 changes: 32 additions & 0 deletions torchrec/ir/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional

from torchrec.modules.embedding_configs import DataType, PoolingType


# Same as EmbeddingBagConfig but serializable
@dataclass
class EmbeddingBagConfigMetadata:
    """JSON-friendly mirror of ``EmbeddingBagConfig``.

    Every field holds a plain value that ``json.dumps`` can encode. The
    enum-typed settings of ``EmbeddingBagConfig`` (``data_type`` and
    ``pooling``) are stored as the enum's raw ``.value`` rather than as the
    enum member, and are turned back into enum members when a config is
    rebuilt from this metadata.
    """

    # Number of rows in the embedding table.
    num_embeddings: int
    # Width of each embedding vector.
    embedding_dim: int
    # Table name.
    name: str
    # Raw value of a ``DataType`` enum member (``DataType.<X>.value``).
    data_type: str
    # Names of the features that look up into this table.
    feature_names: List[str]
    # Weight-initialisation bounds, copied verbatim from the config.
    weight_init_max: Optional[float]
    weight_init_min: Optional[float]
    # Copied verbatim from ``EmbeddingBagConfig.need_pos``.
    need_pos: bool
    # Raw value of a ``PoolingType`` enum member (``PoolingType.<X>.value``).
    pooling: str


@dataclass
class EBCMetadata:
    """Serializable description of an ``EmbeddingBagCollection``.

    Carries what is needed to construct the collection again: its table
    configurations, whether it is weighted, and the device it lives on.
    """

    # One metadata record per embedding table in the collection.
    tables: List[EmbeddingBagConfigMetadata]
    # Mirrors the collection's ``is_weighted()`` flag.
    is_weighted: bool
    # Device rendered as a string; ``None`` means no device is specified.
    device: Optional[str]
142 changes: 142 additions & 0 deletions torchrec/ir/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from dataclasses import asdict
from typing import Dict, Type

import torch

from torch import nn
from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata

from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection


def embedding_bag_config_to_metadata(
    table_config: EmbeddingBagConfig,
) -> EmbeddingBagConfigMetadata:
    """Convert an ``EmbeddingBagConfig`` into its serializable metadata form.

    Args:
        table_config: the table configuration to convert.

    Returns:
        An ``EmbeddingBagConfigMetadata`` carrying the same settings, with
        ``data_type`` and ``pooling`` reduced to their enum ``.value``.
    """
    # Enum members are replaced by their underlying values; every other
    # field is copied across unchanged.
    raw_data_type = table_config.data_type.value
    raw_pooling = table_config.pooling.value

    metadata = EmbeddingBagConfigMetadata(
        name=table_config.name,
        num_embeddings=table_config.num_embeddings,
        embedding_dim=table_config.embedding_dim,
        feature_names=table_config.feature_names,
        data_type=raw_data_type,
        pooling=raw_pooling,
        weight_init_min=table_config.weight_init_min,
        weight_init_max=table_config.weight_init_max,
        need_pos=table_config.need_pos,
    )
    return metadata


def embedding_metadata_to_config(
    table_config: EmbeddingBagConfigMetadata,
) -> EmbeddingBagConfig:
    """Rebuild an ``EmbeddingBagConfig`` from its serializable metadata form.

    Args:
        table_config: metadata record describing one embedding table.

    Returns:
        An ``EmbeddingBagConfig`` with the same settings, where the raw
        ``data_type`` and ``pooling`` values are turned back into
        ``DataType`` and ``PoolingType`` members.
    """
    # Look the enum members up by value; an unknown value raises ValueError
    # from the Enum constructor.
    data_type = DataType(table_config.data_type)
    pooling = PoolingType(table_config.pooling)

    config = EmbeddingBagConfig(
        name=table_config.name,
        num_embeddings=table_config.num_embeddings,
        embedding_dim=table_config.embedding_dim,
        feature_names=table_config.feature_names,
        data_type=data_type,
        pooling=pooling,
        weight_init_min=table_config.weight_init_min,
        weight_init_max=table_config.weight_init_max,
        need_pos=table_config.need_pos,
    )
    return config


class EBCJsonSerializer(SerializerInterface):
    """
    JSON serializer for ``EmbeddingBagCollection`` modules.

    The metadata needed to construct the module (table configs, weighted
    flag, device) is encoded as JSON and carried as the raw bytes of a
    ``torch.uint8`` tensor.
    """

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        """
        Encode an ``EmbeddingBagCollection``'s metadata as a uint8 tensor.

        Args:
            module: the module to serialize; must be an
                ``EmbeddingBagCollection``.

        Returns:
            A ``torch.uint8`` tensor whose bytes are the JSON document.

        Raises:
            ValueError: if ``module`` is not an ``EmbeddingBagCollection``.
        """
        if not isinstance(module, EmbeddingBagCollection):
            raise ValueError(
                f"Expected module to be of type EmbeddingBagCollection, got {type(module)}"
            )

        ebc_metadata = EBCMetadata(
            tables=[
                embedding_bag_config_to_metadata(table_config)
                for table_config in module.embedding_bag_configs()
            ],
            is_weighted=module.is_weighted(),
            device=str(module.device),
        )

        # asdict() recursively converts the nested dataclasses into fresh
        # dicts, so the metadata object itself is never mutated.
        payload = json.dumps(asdict(ebc_metadata)).encode()

        # torch.frombuffer warns when handed a read-only buffer such as
        # bytes; copy the payload into a writable bytearray first.
        return torch.frombuffer(bytearray(payload), dtype=torch.uint8)

    @classmethod
    def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module:
        """
        Rebuild an ``EmbeddingBagCollection`` from a serialized tensor.

        Args:
            input: tensor whose bytes are the JSON document produced by
                ``serialize``.
            typename: class name of the module to rebuild; must be
                ``"EmbeddingBagCollection"``.

        Returns:
            A newly constructed ``EmbeddingBagCollection``.

        Raises:
            ValueError: if ``typename`` is not ``"EmbeddingBagCollection"``.
        """
        if typename != "EmbeddingBagCollection":
            raise ValueError(
                f"Expected typename to be EmbeddingBagCollection, got {typename}"
            )

        # Tensor.numpy() only works on CPU tensors, so move the buffer to
        # host memory before reading its bytes.
        raw_bytes = input.detach().cpu().numpy().tobytes()
        ebc_metadata_dict = json.loads(raw_bytes.decode())
        tables = [
            EmbeddingBagConfigMetadata(**table_config)
            for table_config in ebc_metadata_dict["tables"]
        ]

        # A missing/empty device entry means "let the module pick".
        device_str = ebc_metadata_dict["device"]

        return EmbeddingBagCollection(
            tables=[
                embedding_metadata_to_config(table_config) for table_config in tables
            ],
            is_weighted=ebc_metadata_dict["is_weighted"],
            device=torch.device(device_str) if device_str else None,
        )


class JsonSerializer(SerializerInterface):
    """
    JSON serializer entry point that dispatches on the module's class name.

    Each supported module type is mapped to the serializer class that
    handles it; both ``serialize`` and ``deserialize`` forward to that class.
    """

    # Registry: module class name -> serializer handling that module type.
    module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
        "EmbeddingBagCollection": EBCJsonSerializer,
    }

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        """
        Serialize ``module`` with the serializer registered for its type.

        Raises:
            ValueError: if no serializer is registered for the module's
                class name.
        """
        registry = cls.module_to_serializer_cls
        typename = type(module).__name__
        if typename in registry:
            return registry[typename].serialize(module)

        raise ValueError(
            f"Expected typename to be one of {list(registry.keys())}, got {typename}"
        )

    @classmethod
    def deserialize(cls, input: torch.Tensor, typename: str) -> nn.Module:
        """
        Deserialize ``input`` with the serializer registered for ``typename``.

        Raises:
            ValueError: if no serializer is registered for ``typename``.
        """
        registry = cls.module_to_serializer_cls
        if typename in registry:
            return registry[typename].deserialize(input, typename)

        raise ValueError(
            f"Expected typename to be one of {list(registry.keys())}, got {typename}"
        )
112 changes: 112 additions & 0 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python3

import unittest

import torch
from torch import nn
from torchrec.ir.serializer import JsonSerializer

from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


class TestJsonSerializer(unittest.TestCase):
    """Round-trip test for ``JsonSerializer`` on an ``EmbeddingBagCollection``."""

    def generate_model(self) -> nn.Module:
        """Build a small wrapper module holding a two-table, unweighted EBC."""

        class Model(nn.Module):
            def __init__(self, ebc: EmbeddingBagCollection) -> None:
                super().__init__()
                self.sparse_arch = ebc

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> KeyedTensor:
                return self.sparse_arch(features)

        tb1_config = EmbeddingBagConfig(
            name="t1",
            embedding_dim=3,
            num_embeddings=10,
            feature_names=["f1"],
        )
        tb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f2"],
        )

        ebc = EmbeddingBagCollection(
            tables=[tb1_config, tb2_config],
            is_weighted=False,
        )

        return Model(ebc)

    def test_serialize_deserialize_ebc(self) -> None:
        """Serialize, export, deserialize, and compare against the eager model."""
        model = self.generate_model()
        id_list_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f2"],
            values=torch.tensor([0, 1, 2, 3]),
            offsets=torch.tensor([0, 2, 2, 3, 4]),
        )

        eager_kt = model(id_list_features)

        # Serialize the embedding modules before export. The return value is
        # unused, so this presumably annotates `model` in place.
        serialize_embedding_modules(model, JsonSerializer)
        ep = torch.export.export(
            model,
            (id_list_features,),
            {},
            strict=False,
            # Allows KJT to not be unflattened and run a forward on unflattened EP
            preserve_module_call_signature=("sparse_arch",),
        )

        # Run forward on ExportedProgram
        ep_output = ep.module()(id_list_features)

        self.assertIsInstance(ep_output, KeyedTensor)
        self.assertEqual(eager_kt.keys(), ep_output.keys())
        self.assertEqual(eager_kt.values().shape, ep_output.values().shape)

        # Deserialize EBC
        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

        self.assertIsInstance(deserialized_model.sparse_arch, EmbeddingBagCollection)

        deserialized_configs = deserialized_model.sparse_arch.embedding_bag_configs()
        org_configs = model.sparse_arch.embedding_bag_configs()

        # zip() stops at the shorter input, so a dropped table would go
        # unnoticed without an explicit length check.
        self.assertEqual(len(deserialized_configs), len(org_configs))

        for deserialized_config, org_config in zip(deserialized_configs, org_configs):
            self.assertEqual(deserialized_config.name, org_config.name)
            self.assertEqual(
                deserialized_config.embedding_dim, org_config.embedding_dim
            )
            self.assertEqual(
                deserialized_config.num_embeddings, org_config.num_embeddings
            )
            self.assertEqual(
                deserialized_config.feature_names, org_config.feature_names
            )

        # Run forward on deserialized model. Only keys and shapes are
        # compared, not the embedding values themselves.
        deserialized_kt = deserialized_model(id_list_features)

        self.assertEqual(eager_kt.keys(), deserialized_kt.keys())
        self.assertEqual(eager_kt.values().shape, deserialized_kt.values().shape)

0 comments on commit 76cdb5c

Please sign in to comment.