Skip to content

Commit

Permalink
tolist() support for FunctionalTensor
Browse files Browse the repository at this point in the history
Summary: Support tolist() for FunctionalTensor for KJT in torch.export

Reviewed By: ezyang

Differential Revision: D53731064
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Feb 24, 2024
1 parent 9d59061 commit bbd63a9
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,18 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
# TODO: turn on AOT Inductor test once the support is ready
test_aot_inductor=False,
)

def test_tensor_tolist(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor):
return kjt.values().tolist()

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths),
test_dynamo=False,
test_aot_inductor=False,
test_pt2_ir_export=True,
)

0 comments on commit bbd63a9

Please sign in to comment.