Commit
register a custom op to keep the EBC module unflattened when torch.export

Differential Revision: D56339251
TroyGarden authored and facebook-github-bot committed Apr 19, 2024
1 parent ef59127 commit 92b3ba3
Showing 2 changed files with 184 additions and 1 deletion.
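Background on the mechanism: the diff below uses the torch.library API to expose the EmbeddingBagCollection forward as a single opaque operator, so that non-strict torch.export records one call node per EBC instead of flattening the module into low-level embedding-bag ops. A minimal, generic sketch of that registration pattern, assuming a throwaway op name (toy_pooled_lookup) and a fixed output width of 10, neither of which appears in this commit:

import torch
from torch.library import Library

# "FRAGMENT" lets several call sites contribute ops to the same "custom" namespace.
toy_lib = Library("custom", "FRAGMENT")

# 1) Declare the operator schema once.
toy_lib.define("toy_pooled_lookup(Tensor values, int batch_size) -> Tensor")


def toy_pooled_lookup(values: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Only the output shape and device matter to the exporter, so an empty
    # placeholder of the pooled shape is enough.
    return torch.empty(batch_size, 10, device=values.device)


# 2) Register the same function for eager CPU/CUDA and for Meta (fake) tensors.
for dispatch_key in ("CPU", "CUDA", "Meta"):
    toy_lib.impl("toy_pooled_lookup", toy_pooled_lookup, dispatch_key)

# 3) The op is now reachable through the generated namespace.
out = torch.ops.custom.toy_pooled_lookup(torch.zeros(4), 2)
assert out.shape == (2, 10)

The real register_ebc_op below follows the same three steps, but derives the op name from the EBC instance and the output width from its embedding dims.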
84 changes: 83 additions & 1 deletion torchrec/modules/embedding_modules.py
@@ -20,7 +20,12 @@
    EmbeddingConfig,
    pooling_type_to_str,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import (
    is_non_strict_exporting,
    JaggedTensor,
    KeyedJaggedTensor,
    KeyedTensor,
)

lib = Library("custom", "FRAGMENT")

@@ -41,6 +46,81 @@ class OpRegistryState:
operator_registry_state = OpRegistryState()


def register_ebc_op(
    ebc: "EmbeddingBagCollectionInterface",
) -> str:
    """
    Register EBC operator.

    Args:
        ebc (EmbeddingBagCollection): EBC instance.
    """

    global operator_registry_state

    op_name = f"{type(ebc).__name__}_{hash(ebc)}"
    with operator_registry_state.op_registry_lock:
        if op_name in operator_registry_state.op_registry_schema:
            return op_name

    dim: int = sum(ebc._lengths_per_embedding)

    def ebc_op(
        values: List[Optional[torch.Tensor]],
        batch_size: int,
    ) -> torch.Tensor:
        device = None
        for v in values:
            if v is not None:
                device = v.device
                break
        else:
            raise AssertionError(
                "PooledEmbeddingArch op expects at least one of "
                "id_list_features or id_score_list_features"
            )
        return torch.empty(batch_size, dim, device=device)

    schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor"

    with operator_registry_state.op_registry_lock:
        if op_name in operator_registry_state.op_registry_schema:
            return op_name
        operator_registry_state.op_registry_schema[op_name] = schema_string
        # Register schema
        lib.define(schema_string)

        # Register implementation
        lib.impl(op_name, ebc_op, "CPU")
        lib.impl(op_name, ebc_op, "CUDA")

        # Register meta formula
        lib.impl(op_name, ebc_op, "Meta")

    return op_name


def _forward_meta(
    ebc: "EmbeddingBagCollectionInterface",
    features: KeyedJaggedTensor,
) -> KeyedTensor:
    ebc_op_name: str = register_ebc_op(ebc)
    batch_size = features.stride()

    arg_list = [
        features.values(),
        features.weights_or_none(),
        features.lengths_or_none(),
        features.offsets_or_none(),
    ]
    outputs = getattr(torch.ops.custom, ebc_op_name)(arg_list, batch_size)
    return KeyedTensor(
        keys=ebc._embedding_names,
        values=outputs,
        length_per_key=ebc._lengths_per_embedding,
    )


@torch.fx.wrap
def reorder_inverse_indices(
    inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
@@ -232,6 +312,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
        Returns:
            KeyedTensor
        """
        if is_non_strict_exporting() and not torch.jit.is_scripting():
            return _forward_meta(self, features)
        flat_feature_names: List[str] = []
        for names in self._feature_names:
            flat_feature_names.extend(names)
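The net effect of the forward change above: in non-strict export the EBC skips its real pooling kernels and routes through the registered custom op, whose only contract is to return a tensor of shape (batch_size, sum of the per-feature embedding dims). A rough worked example of that shape math, with numbers lifted from the tests below (features f1 and f3 share a dim-3 table, f2 uses a dim-4 table, and the input KeyedJaggedTensor has stride 2):

import torch

# Hypothetical values matching the test configs below.
lengths_per_embedding = [3, 3, 4]  # f1 -> 3, f3 -> 3, f2 -> 4
batch_size = 2  # stride of the KeyedJaggedTensor input

# The registered op only needs to return a placeholder of this shape for
# torch.export to propagate shapes through the rest of the graph.
placeholder = torch.empty(batch_size, sum(lengths_per_embedding))
assert placeholder.shape == (2, 10)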
101 changes: 101 additions & 0 deletions torchrec/modules/tests/test_embedding_modules.py
@@ -128,6 +128,38 @@ def test_weighted(self) -> None:
        self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])

    def test_forward_with_meta_device(self) -> None:
        eb1_config = EmbeddingBagConfig(
            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
        )
        eb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f2"],
        )
        ebc = EmbeddingBagCollection(
            tables=[eb1_config, eb2_config],
            is_weighted=True,
            device=torch.device("meta"),
        )

        features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f3", "f2"],
            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7], device="meta"),
            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12], device="meta"),
            weights=torch.tensor(
                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7],
                device="meta",
            ),
        )

        pooled_embeddings = ebc(features)
        self.assertEqual(pooled_embeddings.values().size(), (2, 10))
        self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])
        self.assertEqual(pooled_embeddings.values().device, torch.device("meta"))

    def test_fx(self) -> None:
        eb1_config = EmbeddingBagConfig(
            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
@@ -195,6 +227,75 @@ def test_device(self) -> None:
        self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device)
        self.assertEqual(torch.device("cpu"), ebc.device)

    def test_exporting(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                eb1_config = EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=3,
                    num_embeddings=10,
                    feature_names=["f1", "f3"],
                )
                eb2_config = EmbeddingBagConfig(
                    name="t2",
                    embedding_dim=4,
                    num_embeddings=10,
                    feature_names=["f2"],
                )
                eb3_config = EmbeddingBagConfig(
                    name="t3",
                    embedding_dim=3,
                    num_embeddings=10,
                    feature_names=["f1", "f2"],
                )
                eb4_config = EmbeddingBagConfig(
                    name="t4",
                    embedding_dim=5,
                    num_embeddings=10,
                    feature_names=["f3"],
                )
                self.ebc1 = EmbeddingBagCollection(
                    tables=[eb1_config, eb2_config], is_weighted=True
                )
                self.ebc2 = EmbeddingBagCollection(
                    tables=[eb3_config, eb4_config], is_weighted=True
                )

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> torch.Tensor:
                embeddings1 = self.ebc1(features)
                embeddings2 = self.ebc2(features)
                return torch.concat([embeddings1.values(), embeddings2.values()], dim=1)

        features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f3", "f2"],
            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7]),
            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12]),
            weights=torch.tensor(
                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7]
            ),
        )

        m = MyModule()
        ep = torch.export.export(
            m,
            (features,),
            {},
            strict=False,
        )
        self.assertEqual(
            sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes),
            0,
        )
        self.assertEqual(
            sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes),
            2,
            "Should be exactly 2 EmbeddingBagCollection nodes in the exported graph",
        )


class EmbeddingCollectionTest(unittest.TestCase):
    def test_forward(self) -> None:
