From 45ecdbf59799e4ccea634b09fb91ad83f29baae2 Mon Sep 17 00:00:00 2001
From: Shuao Xiong
Date: Thu, 11 Apr 2024 11:34:27 -0700
Subject: [PATCH] remove features_to_dict (#1851)

Summary:
Use ComputeKJTToJTDict instead of features_to_dict in the tagging rule.

Differential Revision: D55842584
---
 torchrec/quant/embedding_modules.py            | 17 ++++++++---------
 torchrec/quant/tests/test_embedding_modules.py | 18 +++++++++++-------
 2 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 0f3b8b439..431ce406e 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -46,7 +46,12 @@
 )
 from torchrec.modules.utils import construct_jagged_tensors_inference
 
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    ComputeKJTToJTDict,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
 from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
 from torchrec.types import ModuleNoCopyMixin
 
@@ -230,13 +235,6 @@ def _update_embedding_configs(
     )
 
 
-@torch.fx.wrap
-def features_to_dict(
-    features: KeyedJaggedTensor,
-) -> Dict[str, JaggedTensor]:
-    return features.to_dict()
-
-
 class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin):
     """
     EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).
@@ -332,6 +330,7 @@ def __init__(
             Dict[str, Tuple[Tensor, Tensor]]
         ] = None
         self.row_alignment = row_alignment
+        self._kjt_to_jt_dict = ComputeKJTToJTDict()
 
         table_names = set()
         for table in self._embedding_bag_configs:
@@ -463,7 +462,7 @@ def forward(
             KeyedTensor
         """
 
-        feature_dict = features_to_dict(features)
+        feature_dict = self._kjt_to_jt_dict(features)
         embeddings = []
 
         # TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script.
diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py
index 3b5db8233..5d3db13a1 100644
--- a/torchrec/quant/tests/test_embedding_modules.py
+++ b/torchrec/quant/tests/test_embedding_modules.py
@@ -29,10 +29,14 @@
 from torchrec.quant.embedding_modules import (
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     EmbeddingCollection as QuantEmbeddingCollection,
-    features_to_dict,
     quant_prep_enable_quant_state_dict_split_scale_bias,
 )
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    ComputeKJTToJTDict,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
 
 
 class EmbeddingBagCollectionTest(unittest.TestCase):
@@ -445,7 +449,7 @@ def test_trace_and_script(self) -> None:
 
         from torchrec.fx import symbolic_trace
 
-        gm = symbolic_trace(qebc)
+        gm = symbolic_trace(qebc, leaf_modules=[ComputeKJTToJTDict.__name__])
 
         non_placeholder_nodes = [
             node for node in gm.graph.nodes if node.op != "placeholder"
@@ -455,13 +459,13 @@ def test_trace_and_script(self) -> None:
         )
         self.assertEqual(
             non_placeholder_nodes[0].op,
-            "call_function",
-            f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead",
+            "call_module",
+            f"First non-placeholder node must be call_module, got {non_placeholder_nodes[0].op} instead",
         )
         self.assertEqual(
             non_placeholder_nodes[0].name,
-            features_to_dict.__name__,
-            f"First non-placeholder node must be features_to_dict, got {non_placeholder_nodes[0].name} instead",
+            "_kjt_to_jt_dict",
+            f"First non-placeholder node must be _kjt_to_jt_dict, got {non_placeholder_nodes[0].name} instead",
        )
 
         features = KeyedJaggedTensor(
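
Note: a minimal sketch, not part of the patch, of why the swap flips the
traced node from call_function to call_module. The fx-wrapped
features_to_dict traced as a call_function node; ComputeKJTToJTDict is an
nn.Module registered as a submodule, so torch.fx records the call as
call_module and the test can keep it opaque via leaf_modules, as in the
updated test above. ToyModule below is a hypothetical stand-in for the
quantized EmbeddingBagCollection forward path, assuming only the torchrec
APIs already used in this patch.

    import torch
    from torchrec.fx import symbolic_trace
    from torchrec.sparse.jagged_tensor import ComputeKJTToJTDict, KeyedJaggedTensor

    class ToyModule(torch.nn.Module):
        """Illustrative stand-in for the patched EmbeddingBagCollection."""

        def __init__(self) -> None:
            super().__init__()
            # Registered as a submodule, so fx traces the call as call_module.
            self._kjt_to_jt_dict = ComputeKJTToJTDict()

        def forward(self, features: KeyedJaggedTensor):
            # Same KJT -> Dict[str, JaggedTensor] conversion the patched
            # forward() performs before the embedding lookup.
            return self._kjt_to_jt_dict(features)

    # Keep the conversion module opaque, mirroring the updated test.
    gm = symbolic_trace(ToyModule(), leaf_modules=[ComputeKJTToJTDict.__name__])
    ops = [node.op for node in gm.graph.nodes if node.op != "placeholder"]
    assert ops[0] == "call_module"  # was call_function with features_to_dict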