From 86a965c3402e23dc87ae401c7347581df8dfc26c Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Wed, 10 Apr 2024 12:18:03 -0700 Subject: [PATCH] batch_size * lengths for export (#1864) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1864 Differential Revision: D55602984 --- torchrec/distributed/tests/test_pt2.py | 29 ++++++++++++++++++++++++++ torchrec/sparse/jagged_tensor.py | 24 +++++++++++++-------- 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index 54ab551e1..21a2f351c 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -216,6 +216,35 @@ def forward(self, kjt: KeyedJaggedTensor): test_pt2_ir_export=True, ) + # pyre-ignore + def test_kjt__getitem__batch_size_lengths(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor, sz: torch.Tensor): + # Have to manipulate the size here instead of the input + # for torch.export to record as SymInt + sz_int = sz.item() + new_kjt = KeyedJaggedTensor( + keys=kjt.keys(), + # batch_size sz, doesn't matter what it is, gets turned into SymInt + values=torch.ones(sz_int, dtype=torch.int32).repeat(256), + lengths=torch.ones(sz_int, dtype=torch.int32).repeat(256), + ) + + return new_kjt["key0"] + + # 3 keys for a compound expression for stride, (2 * length_numel) / 7 + kjt: KeyedJaggedTensor = make_kjt([1] * 447, [1] * 447) + sz: torch.Tensor = torch.tensor([447], dtype=torch.int32) + + self._test_kjt_input_module( + M(), + kjt.keys(), + (kjt._values, kjt._lengths, sz), + test_dynamo=False, + test_aot_inductor=False, + test_pt2_ir_export=True, + ) + # pyre-ignores @unittest.skipIf( torch.cuda.device_count() <= 1, diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 0e969b8e4..2b3c96f02 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -14,6 +14,7 @@ import torch
from torch.autograd.profiler import record_function from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node from torchrec.streamable import Pipelineable @@ -682,13 +683,14 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten def _assert_tensor_has_no_elements_or_has_integers( tensor: torch.Tensor, tensor_name: str ) -> None: - assert tensor.numel() == 0 or tensor.dtype in [ - torch.long, - torch.int, - torch.short, - torch.int8, - torch.uint8, - ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype) + if not is_non_strict_exporting(): + assert tensor.numel() == 0 or tensor.dtype in [ + torch.long, + torch.int, + torch.short, + torch.int8, + torch.uint8, + ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype) def _maybe_compute_index_per_key( @@ -803,7 +805,7 @@ def _maybe_compute_length_per_key( if variable_stride_per_key else ( torch.sum(lengths.view(-1, stride), dim=1).tolist() - if lengths.numel() != 0 + if guard_size_oblivious(lengths.numel() != 0) else [0] * len(keys) ) ) @@ -1298,6 +1300,8 @@ def __init__( _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") if lengths is not None: _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") + if not torch.jit.is_scripting() and is_non_strict_exporting() and len(keys) > 0: + torch._check(lengths.numel() % len(keys) == 0) self._lengths: Optional[torch.Tensor] = lengths self._offsets: Optional[torch.Tensor] = offsets @@ -1427,7 +1431,9 @@ def concat( elif stride is None: stride = kjt.stride() else: - assert stride == kjt.stride(), "strides must be consistent for all KJTs" + torch._check( + stride == kjt.stride(), "strides must be consistent for all KJTs" + ) return KeyedJaggedTensor( keys=keys,