Skip to content

Commit

Permalink
remove features_to_dict
Browse files Browse the repository at this point in the history
Summary:
Use _maybe_compute_kjt_to_jt_dict instead of features_to_dict in the tagging rule.

fx wrapping features_to_dict lead to KJT storing _jt_dict inside its body, and causing JaggedTensor deserialization issue for MRS models

Differential Revision: D55842584
  • Loading branch information
seanx92 authored and facebook-github-bot committed Apr 7, 2024
1 parent f43660b commit 306d724
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 13 deletions.
9 changes: 1 addition & 8 deletions torchrec/quant/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,6 @@ def _update_embedding_configs(
)


@torch.fx.wrap
def features_to_dict(
features: KeyedJaggedTensor,
) -> Dict[str, JaggedTensor]:
return features.to_dict()


class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin):
"""
EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).
Expand Down Expand Up @@ -463,7 +456,7 @@ def forward(
KeyedTensor
"""

feature_dict = features_to_dict(features)
feature_dict = features.to_dict()
embeddings = []

# TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script.
Expand Down
9 changes: 4 additions & 5 deletions torchrec/quant/tests/test_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
EmbeddingCollection as QuantEmbeddingCollection,
features_to_dict,
quant_prep_enable_quant_state_dict_split_scale_bias,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
Expand Down Expand Up @@ -455,13 +454,13 @@ def test_trace_and_script(self) -> None:
)
self.assertEqual(
non_placeholder_nodes[0].op,
"call_function",
f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead",
"call_method",
f"First non-placeholder node must be call_method, got {non_placeholder_nodes[0].op} instead",
)
self.assertEqual(
non_placeholder_nodes[0].name,
features_to_dict.__name__,
f"First non-placeholder node must be features_to_dict, got {non_placeholder_nodes[0].name} instead",
"to_dict",
f"First non-placeholder node must be to_dict, got {non_placeholder_nodes[0].name} instead",
)

features = KeyedJaggedTensor(
Expand Down

0 comments on commit 306d724

Please sign in to comment.