From ef7e511267260b9eb8ed3b4c4338455dd54931d5 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 18 Apr 2024 13:34:57 -0700 Subject: [PATCH] Fix pytree flatten/unflatten for KeyedTensor (#1899) Summary: KeyedTensor length_per_key represents size of embeddings after lookup, which are static. Reflecting that in the pytree case to help torch.export understand constraints instead of allocating new unbacked SymInts Differential Revision: D56320582 --- torchrec/sparse/jagged_tensor.py | 14 +++++++++----- torchrec/sparse/tests/test_jagged_tensor.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 3670c3030..5f8963905 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -2519,17 +2519,21 @@ def __str__(self) -> str: def _kt_flatten( kt: KeyedTensor, -) -> Tuple[List[torch.Tensor], List[str]]: - return [torch.tensor(kt._length_per_key, dtype=torch.int64), kt._values], kt._keys +) -> Tuple[List[torch.Tensor], Tuple[List[str], List[int]]]: + return [kt._values], (kt._keys, kt._length_per_key) -def _kt_unflatten(values: List[torch.Tensor], context: List[str]) -> KeyedTensor: - return KeyedTensor(context, values[0].tolist(), values[1]) +def _kt_unflatten( + values: List[torch.Tensor], context: Tuple[List[str], List[int]] +) -> KeyedTensor: + return KeyedTensor(context[0], context[1], values[0]) def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]: return _kt_flatten(kt)[0] -register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten) +register_pytree_node( + KeyedTensor, _kt_flatten, _kt_unflatten, serialized_type_name="KeyedTensor" +) register_pytree_flatten_spec(KeyedTensor, _kt_flatten_spec) diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 614ce5965..67b6d77c8 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -2289,6 +2289,23 @@ def test_string_values(self) -> None: """, ) + def test_pytree(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]), + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) + + flattened, out_spec = pytree.tree_flatten(kt) + + self.assertTrue(torch.equal(flattened[0], kt.values())) + unflattened = pytree.tree_unflatten(flattened, out_spec) + + self.assertTrue(isinstance(unflattened, KeyedTensor)) + self.assertListEqual(unflattened.keys(), keys) + self.assertListEqual(unflattened._length_per_key, kt._length_per_key) + class TestComputeKJTToJTDict(unittest.TestCase): def test_key_lookup(self) -> None: