diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 315c37b90..2b3df0d1b 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -20,7 +20,12 @@
     EmbeddingConfig,
     pooling_type_to_str,
 )
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    is_non_strict_exporting,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
 
 
 lib = Library("custom", "FRAGMENT")
@@ -41,6 +46,81 @@ class OpRegistryState:
 operator_registry_state = OpRegistryState()
 
 
+def register_ebc_op(
+    ebc: "EmbeddingBagCollectionInterface",
+) -> str:
+    """
+    Register a custom operator for the given EBC and return its name.
+
+    Args:
+        ebc (EmbeddingBagCollectionInterface): EBC instance to register the op for.
+    """
+
+    global operator_registry_state
+
+    op_name = f"{type(ebc).__name__}_{hash(ebc)}"
+    with operator_registry_state.op_registry_lock:
+        if op_name in operator_registry_state.op_registry_schema:
+            return op_name
+
+    dim: int = sum(ebc._lengths_per_embedding)
+
+    def ebc_op(
+        values: List[Optional[torch.Tensor]],
+        batch_size: int,
+    ) -> torch.Tensor:
+        device = None
+        for v in values:
+            if v is not None:
+                device = v.device
+                break
+        else:
+            raise AssertionError(
+                "PooledEmbeddingArch op expects at least one of "
+                "id_list_features or id_score_list_features"
+            )
+        return torch.empty(batch_size, dim, device=device)
+
+    schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor"
+
+    with operator_registry_state.op_registry_lock:
+        if op_name in operator_registry_state.op_registry_schema:
+            return op_name
+        operator_registry_state.op_registry_schema[op_name] = schema_string
+        # Register schema
+        lib.define(schema_string)
+
+        # Register implementation
+        lib.impl(op_name, ebc_op, "CPU")
+        lib.impl(op_name, ebc_op, "CUDA")
+
+        # Register meta formula
+        lib.impl(op_name, ebc_op, "Meta")
+
+    return op_name
+
+
+def _forward_meta(
+    ebc: "EmbeddingBagCollectionInterface",
+    features: KeyedJaggedTensor,
+) -> KeyedTensor:
+    ebc_op_name: str = register_ebc_op(ebc)
+    batch_size = features.stride()
+
+    arg_list = [
+        features.values(),
+        features.weights_or_none(),
+        features.lengths_or_none(),
+        features.offsets_or_none(),
+    ]
+    outputs = getattr(torch.ops.custom, ebc_op_name)(arg_list, batch_size)
+    return KeyedTensor(
+        keys=ebc._embedding_names,
+        values=outputs,
+        length_per_key=ebc._lengths_per_embedding,
+    )
+
+
 @torch.fx.wrap
 def reorder_inverse_indices(
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
@@ -232,6 +312,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         Returns:
             KeyedTensor
         """
+        if is_non_strict_exporting() and not torch.jit.is_scripting():
+            return _forward_meta(self, features)
         flat_feature_names: List[str] = []
         for names in self._feature_names:
             flat_feature_names.extend(names)
diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py
index 62338bc10..934bdb229 100644
--- a/torchrec/modules/tests/test_embedding_modules.py
+++ b/torchrec/modules/tests/test_embedding_modules.py
@@ -128,6 +128,38 @@ def test_weighted(self) -> None:
         self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
         self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])
 
+    def test_forward_with_meta_device(self) -> None:
+        eb1_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
+        )
+        eb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=4,
+            num_embeddings=10,
+            feature_names=["f2"],
+        )
+        ebc = EmbeddingBagCollection(
+            tables=[eb1_config, eb2_config],
+            is_weighted=True,
+            device=torch.device("meta"),
+        )
+
+        features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f3", "f2"],
+            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7], device="meta"),
+            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12], device="meta"),
+            weights=torch.tensor(
+                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7],
+                device="meta",
+            ),
+        )
+
+        pooled_embeddings = ebc(features)
+        self.assertEqual(pooled_embeddings.values().size(), (2, 10))
+        self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
+        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])
+        self.assertEqual(pooled_embeddings.values().device, torch.device("meta"))
+
     def test_fx(self) -> None:
         eb1_config = EmbeddingBagConfig(
             name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
@@ -195,6 +227,75 @@ def test_device(self) -> None:
         self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device)
         self.assertEqual(torch.device("cpu"), ebc.device)
 
+    def test_exporting(self) -> None:
+        class MyModule(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                eb1_config = EmbeddingBagConfig(
+                    name="t1",
+                    embedding_dim=3,
+                    num_embeddings=10,
+                    feature_names=["f1", "f3"],
+                )
+                eb2_config = EmbeddingBagConfig(
+                    name="t2",
+                    embedding_dim=4,
+                    num_embeddings=10,
+                    feature_names=["f2"],
+                )
+                eb3_config = EmbeddingBagConfig(
+                    name="t3",
+                    embedding_dim=3,
+                    num_embeddings=10,
+                    feature_names=["f1", "f2"],
+                )
+                eb4_config = EmbeddingBagConfig(
+                    name="t4",
+                    embedding_dim=5,
+                    num_embeddings=10,
+                    feature_names=["f3"],
+                )
+                self.ebc1 = EmbeddingBagCollection(
+                    tables=[eb1_config, eb2_config], is_weighted=True
+                )
+                self.ebc2 = EmbeddingBagCollection(
+                    tables=[eb3_config, eb4_config], is_weighted=True
+                )
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> torch.Tensor:
+                embeddings1 = self.ebc1(features)
+                embeddings2 = self.ebc2(features)
+                return torch.concat([embeddings1.values(), embeddings2.values()], dim=1)
+
+        features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f3", "f2"],
+            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7]),
+            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12]),
+            weights=torch.tensor(
+                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7]
+            ),
+        )
+
+        m = MyModule()
+        ep = torch.export.export(
+            m,
+            (features,),
+            {},
+            strict=False,
+        )
+        self.assertEqual(
+            sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes),
+            0,
+        )
+        self.assertEqual(
+            sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes),
+            2,
+            "Should be exactly 2 EmbeddingBagCollection nodes in the exported graph",
+        )
+
 
 class EmbeddingCollectionTest(unittest.TestCase):
     def test_forward(self) -> None:
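For context, a minimal sketch (not part of the patch) of how the non-strict export path above is exercised, mirroring test_exporting; the Wrapper module, single-table config, and feature names below are illustrative only, and it assumes a PyTorch build with torch.export and a torchrec build containing this change.

import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class Wrapper(torch.nn.Module):
    # Illustrative wrapper (hypothetical name): returns a plain tensor so the
    # exported output spec stays simple, as in test_exporting above.
    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
                )
            ],
        )

    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        return self.ebc(features).values()


features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([1, 2, 4, 5]),
    offsets=torch.tensor([0, 2, 4]),
)

# strict=False takes the non-strict export path: EmbeddingBagCollection.forward
# detects it via is_non_strict_exporting() and dispatches to _forward_meta, so
# the exported graph records one EmbeddingBagCollection_* custom op per EBC
# instead of tracing into the underlying _embedding_bag kernels.
ep = torch.export.export(Wrapper(), (features,), strict=False)
print(ep.graph)

Note that the registered op's CPU, CUDA, and Meta implementations all return an empty tensor of the correct shape, so this path captures the graph structure rather than computing real embedding values.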