From 49170d7f472d52a1e6df7311a51eb9c916d7138b Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Sun, 14 Apr 2024 06:23:02 -0700
Subject: [PATCH] Replace to_dict with permute in QEBC

Differential Revision: D56069966
---
 torchrec/quant/embedding_modules.py       | 45 ++++++++-----------
 .../quant/tests/test_embedding_modules.py |  7 ++++---
 2 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 431ce406e..023d80113 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -320,6 +320,8 @@ def __init__(
         self._key_to_tables: Dict[
             Tuple[PoolingType, DataType], List[EmbeddingBagConfig]
         ] = defaultdict(list)
+        self._feature_names: List[str] = []
+        self._feature_splits: List[int] = []
         self._length_per_key: List[int] = []
         # Registering in a List instead of ModuleList because we want don't want them to be auto-registered.
         # Their states will be modified via self.embedding_bags
@@ -389,6 +391,11 @@ def __init__(
             if weight_lists is None:
                 emb_module.initialize_weights()
             self._emb_modules.append(emb_module)
+            for table in emb_configs:
+                self._feature_names.extend(table.feature_names)
+            self._feature_splits.append(
+                sum(table.num_features() for table in emb_configs)
+            )
 
         ordered_tables = list(itertools.chain(*self._key_to_tables.values()))
         self._embedding_names: List[str] = list(
@@ -462,47 +469,31 @@ def forward(
             KeyedTensor
         """
 
-        feature_dict = self._kjt_to_jt_dict(features)
         embeddings = []
+        kjt_keys = features.keys()
+        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+        kjt_permute = features.permute(kjt_permute_order)
+        kjts_per_key = kjt_permute.split(self._feature_splits)
 
-        # TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script.
-        # Once torchscript is no longer a requirement, we should revisit this.
-
-        for emb_op, (_key, tables) in zip(
-            self._emb_modules, self._key_to_tables.items()
+        for i, (emb_op, _) in enumerate(
+            zip(self._emb_modules, self._key_to_tables.keys())
         ):
-            indices = []
-            lengths = []
-            offsets = []
-            weights = []
-
-            for table in tables:
-                for feature in table.feature_names:
-                    f = feature_dict[feature]
-                    indices.append(f.values())
-                    lengths.append(f.lengths())
-                    if self._is_weighted:
-                        weights.append(f.weights())
-
-            indices = torch.cat(indices)
-            lengths = torch.cat(lengths)
-
-            offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            if self._is_weighted:
-                weights = torch.cat(weights)
+            f = kjts_per_key[i]
+            indices = f.values()
+            offsets = f.offsets()
 
             embeddings.append(
                 # Syntax for FX to generate call_module instead of call_function to keep TBE copied unchanged to fx.GraphModule, can be done only for registered module
                 emb_op(
                     indices=indices,
                     offsets=offsets,
-                    per_sample_weights=weights if self._is_weighted else None,
+                    per_sample_weights=f.weights() if self._is_weighted else None,
                 )
                 if self.register_tbes
                 else emb_op.forward(
                     indices=indices,
                     offsets=offsets,
-                    per_sample_weights=weights if self._is_weighted else None,
+                    per_sample_weights=f.weights() if self._is_weighted else None,
                 )
             )
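The new forward path leans on two KeyedJaggedTensor methods, permute and split, in place of the old per-feature dict lookup. The following standalone sketch shows the same pattern outside the module; the feature names, the expected order, and the split sizes are invented for illustration and do not come from this patch.

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Hypothetical features "f1".."f3" with batch size 2; these names are
    # made up for the example, not taken from the patch above.
    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([1, 2, 3, 4, 5, 6]),
        lengths=torch.tensor([2, 1, 1, 1, 0, 1]),
    )

    # Suppose the module grouped its tables so it expects ["f2", "f3"] for
    # the first TBE group and ["f1"] for the second (again, invented).
    feature_names = ["f2", "f3", "f1"]
    feature_splits = [2, 1]

    # Reorder the incoming KJT to the expected feature order, then split it
    # into one KJT per TBE group, mirroring the new forward() above.
    permute_order = [features.keys().index(k) for k in feature_names]
    per_group = features.permute(permute_order).split(feature_splits)

    for kjt in per_group:
        print(kjt.keys(), kjt.values(), kjt.offsets())

Compared with the removed _kjt_to_jt_dict path, this does one permute per forward call and reads values/offsets directly from each split KJT, avoiding the per-feature torch.cat calls and the fbgemm asynchronous_complete_cumsum for offsets.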
diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py
index 5d3db13a1..73742b894 100644
--- a/torchrec/quant/tests/test_embedding_modules.py
+++ b/torchrec/quant/tests/test_embedding_modules.py
@@ -457,14 +457,15 @@ def test_trace_and_script(self) -> None:
         self.assertTrue(
             len(non_placeholder_nodes) > 0, "Graph must have non-placeholder nodes"
         )
+
         self.assertEqual(
             non_placeholder_nodes[0].op,
-            "call_module",
+            "call_method",
             f"First non-placeholder node must be call_method, got {non_placeholder_nodes[0].op} instead",
         )
         self.assertEqual(
             non_placeholder_nodes[0].name,
-            "_kjt_to_jt_dict",
-            f"First non-placeholder node must be _kjt_to_jt_dict, got {non_placeholder_nodes[0].name} instead",
+            "keys",
+            f"First non-placeholder node must be keys, got {non_placeholder_nodes[0].name} instead",
         )
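The test update tracks the tracing change: with the dict conversion gone, the first non-placeholder node in the FX graph is now the call_method node for keys() on the input KJT rather than a call_module node for _kjt_to_jt_dict. A minimal sketch of that style of graph assertion on a toy module, assuming plain torch.fx symbolic tracing rather than whatever custom tracer the torchrec test uses:

    import torch
    import torch.fx


    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # A tensor method call is recorded as a "call_method" node,
            # analogous to features.keys() in the patched forward().
            y = x.contiguous()
            return y + 1


    gm = torch.fx.symbolic_trace(Toy())
    nodes = [n for n in gm.graph.nodes if n.op != "placeholder"]
    assert nodes[0].op == "call_method"
    assert nodes[0].name == "contiguous"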