diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 3670c3030..5f8963905 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -2519,17 +2519,21 @@ def __str__(self) -> str: def _kt_flatten( kt: KeyedTensor, -) -> Tuple[List[torch.Tensor], List[str]]: - return [torch.tensor(kt._length_per_key, dtype=torch.int64), kt._values], kt._keys +) -> Tuple[List[torch.Tensor], Tuple[List[str], List[int]]]: + return [kt._values], (kt._keys, kt._length_per_key) -def _kt_unflatten(values: List[torch.Tensor], context: List[str]) -> KeyedTensor: - return KeyedTensor(context, values[0].tolist(), values[1]) +def _kt_unflatten( + values: List[torch.Tensor], context: Tuple[List[str], List[int]] +) -> KeyedTensor: + return KeyedTensor(context[0], context[1], values[0]) def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]: return _kt_flatten(kt)[0] -register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten) +register_pytree_node( + KeyedTensor, _kt_flatten, _kt_unflatten, serialized_type_name="KeyedTensor" +) register_pytree_flatten_spec(KeyedTensor, _kt_flatten_spec) diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 614ce5965..67b6d77c8 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -2289,6 +2289,23 @@ def test_string_values(self) -> None: """, ) + def test_pytree(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]), + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) + + flattened, out_spec = pytree.tree_flatten(kt) + + self.assertTrue(torch.equal(flattened[0], kt.values())) + unflattened = pytree.tree_unflatten(flattened, out_spec) + + self.assertTrue(isinstance(unflattened, KeyedTensor)) + self.assertListEqual(unflattened.keys(), keys) + self.assertListEqual(unflattened._length_per_key, kt._length_per_key) + class TestComputeKJTToJTDict(unittest.TestCase): def test_key_lookup(self) -> None: