diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 0f3b8b439..fe0ba5457 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -230,13 +230,6 @@ def _update_embedding_configs( ) -@torch.fx.wrap -def features_to_dict( - features: KeyedJaggedTensor, -) -> Dict[str, JaggedTensor]: - return features.to_dict() - - class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin): """ EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags). @@ -463,7 +456,7 @@ def forward( KeyedTensor """ - feature_dict = features_to_dict(features) + feature_dict = features.to_dict() embeddings = [] # TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script. diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py index 3b5db8233..2dce797a9 100644 --- a/torchrec/quant/tests/test_embedding_modules.py +++ b/torchrec/quant/tests/test_embedding_modules.py @@ -29,7 +29,6 @@ from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, EmbeddingCollection as QuantEmbeddingCollection, - features_to_dict, quant_prep_enable_quant_state_dict_split_scale_bias, ) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor @@ -455,13 +454,13 @@ def test_trace_and_script(self) -> None: ) self.assertEqual( non_placeholder_nodes[0].op, - "call_function", - f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead", + "call_method", + f"First non-placeholder node must be call_method, got {non_placeholder_nodes[0].op} instead", ) self.assertEqual( non_placeholder_nodes[0].name, - features_to_dict.__name__, - f"First non-placeholder node must be features_to_dict, got {non_placeholder_nodes[0].name} instead", + "to_dict", + f"First non-placeholder node must be to_dict, got {non_placeholder_nodes[0].name} instead", ) features = KeyedJaggedTensor(