diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index 54ab551e1..21a2f351c 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -216,6 +216,35 @@ def forward(self, kjt: KeyedJaggedTensor): test_pt2_ir_export=True, ) + # pyre-ignore + def test_kjt__getitem__batch_size_lengths(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor, sz: torch.Tensor): + # Have to manipulate the size here instead of the input + # for torch.export to record as SymInt + sz_int = sz.item() + new_kjt = KeyedJaggedTensor( + keys=kjt.keys(), + # batch_size sz, doesn't matter what it is, get's turned into SymInt + values=torch.ones(sz_int, dtype=torch.int32).repeat(256), + lengths=torch.ones(sz_int, dtype=torch.int32).repeat(256), + ) + + return new_kjt["key0"] + + # 3 keys for a compound expression for stride, (2 * length_numel) / 7 + kjt: KeyedJaggedTensor = make_kjt([1] * 447, [1] * 447) + sz: torch.Tensor = torch.tensor([447], dtype=torch.int32) + + self._test_kjt_input_module( + M(), + kjt.keys(), + (kjt._values, kjt._lengths, sz), + test_dynamo=False, + test_aot_inductor=False, + test_pt2_ir_export=True, + ) + # pyre-ignores @unittest.skipIf( torch.cuda.device_count() <= 1, diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 0e969b8e4..eb923de0e 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -14,6 +14,7 @@ import torch from torch.autograd.profiler import record_function from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node from torchrec.streamable import Pipelineable @@ -682,13 +683,14 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten def _assert_tensor_has_no_elements_or_has_integers( tensor: torch.Tensor, tensor_name: str ) -> None: - assert tensor.numel() == 0 or tensor.dtype in [ - torch.long, - torch.int, - torch.short, - torch.int8, - torch.uint8, - ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype) + if not is_non_strict_exporting(): + assert tensor.numel() == 0 or tensor.dtype in [ + torch.long, + torch.int, + torch.short, + torch.int8, + torch.uint8, + ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype) def _maybe_compute_index_per_key( @@ -1298,6 +1300,12 @@ def __init__( _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") if lengths is not None: _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") + if ( + not torch.jit.is_scripting() + and is_non_strict_exporting() + and len(keys) > 0 + ): + torch._check(lengths.numel() % len(keys) == 0) self._lengths: Optional[torch.Tensor] = lengths self._offsets: Optional[torch.Tensor] = offsets @@ -1427,7 +1435,9 @@ def concat( elif stride is None: stride = kjt.stride() else: - assert stride == kjt.stride(), "strides must be consistent for all KJTs" + torch._check( + stride == kjt.stride(), "strides must be consistent for all KJTs" + ) return KeyedJaggedTensor( keys=keys,