diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index 25d33e765..54ab551e1 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -163,6 +163,7 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]): kjt.keys(), (kjt._values, kjt._lengths, indices), test_aot_inductor=False, + test_pt2_ir_export=True, ) def test_kjt_length_per_key(self) -> None: diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index f57686f5a..6f88dc464 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1831,7 +1831,7 @@ def permute( permuted_keys: List[str] = [] permuted_stride_per_key_per_rank: List[List[int]] = [] permuted_length_per_key: List[int] = [] - permuted_lengths_sum = 0 + permuted_length_per_key_sum = 0 for index in indices: key = self.keys()[index] permuted_keys.append(key) @@ -1839,7 +1839,8 @@ def permute( self.stride_per_key_per_rank()[index] ) permuted_length_per_key.append(length_per_key[index]) - permuted_lengths_sum += length_per_key[index] + if not is_non_strict_exporting(): + permuted_length_per_key_sum += length_per_key[index] if self.variable_stride_per_key(): length_per_key_tensor = _pin_and_move( torch.tensor(self.length_per_key()), self.device() @@ -1860,6 +1861,20 @@ def permute( self.weights_or_none(), ) else: + if not torch.jit.is_scripting() and is_non_strict_exporting(): + permuted_length_per_key_sum = torch.sum( + torch._refs.tensor( + permuted_length_per_key, + dtype=torch.int32, + device=torch.device("cpu"), + pin_memory=False, + requires_grad=False, + ) + ).item() + + torch._check(permuted_length_per_key_sum >= 0) + torch._check(permuted_length_per_key_sum != 0) + ( permuted_lengths, permuted_values, @@ -1869,7 +1884,7 @@ def permute( self.lengths().view(len(self._keys), -1), self.values(), self.weights_or_none(), - permuted_lengths_sum, + permuted_length_per_key_sum, ) stride, optional_permuted_stride_per_key_per_rank = ( (None, permuted_stride_per_key_per_rank)