From d9ef776dcedc1f3013cb97287d6a0949a92a4b89 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 22 Mar 2024 11:16:07 -0700 Subject: [PATCH] KJT split torch export support (#1816) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1816 torch.export support for KJT split Reviewed By: IvanKobzarev Differential Revision: D53545161 fbshipit-source-id: 7afbf8843479873251333ef7cffdb4081a12cb06 --- torchrec/distributed/tests/test_pt2.py | 8 ++- torchrec/sparse/jagged_tensor.py | 97 ++++++++++++++++++-------- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index ec35fc0c4..25d33e765 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -137,16 +137,18 @@ def _test_kjt_input_module( def test_kjt_split(self) -> None: class M(torch.nn.Module): - def forward(self, kjt: KeyedJaggedTensor, segments: List[int]): - return kjt.split(segments) + def forward(self, kjt: KeyedJaggedTensor): + return kjt.split([1, 2, 1]) kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) segments: List[int] = [1, 2, 1] self._test_kjt_input_module( M(), kjt.keys(), - (kjt._values, kjt._lengths, segments), + (kjt._values, kjt._lengths), test_aot_inductor=False, + test_dynamo=False, + test_pt2_ir_export=True, ) def test_kjt_permute(self) -> None: diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index fda71a3e6..63a1f543f 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1748,39 +1748,76 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: else: split_length_per_key = _length_per_key[start:end] - if not torch.jit.is_scripting() and is_torchdynamo_compiling(): - # Checks for dynamo dynamic shapes tracing - torch._check_is_size(start_offset) - torch._check_is_size(end_offset) - torch._check_is_size(end_offset - start_offset) + if not torch.jit.is_scripting() and is_non_strict_exporting(): + sz = sum(split_length_per_key) + + [torch._check_is_size(length) for length in split_length_per_key] torch._check(start_offset <= self._values.size(0)) - torch._check(end_offset <= self._values.size(0)) - torch._check(end_offset >= start_offset) + torch._check(sz <= self._values.size(0)) + torch._check_is_size(start_offset) - split_list.append( - KeyedJaggedTensor( - keys=keys, - values=self._values[start_offset:end_offset], - weights=( - None - if self.weights_or_none() is None - else self.weights()[start_offset:end_offset] - ), - lengths=self.lengths()[ - self.lengths_offset_per_key()[ - start - ] : self.lengths_offset_per_key()[end] - ], - offsets=None, - stride=stride, - stride_per_key_per_rank=stride_per_key_per_rank, - length_per_key=split_length_per_key, - offset_per_key=None, - index_per_key=None, - jt_dict=None, - inverse_indices=None, + torch._check(start_offset + sz <= self._values.size(0)) + + lengths_start = self.lengths_offset_per_key()[start] + lengths_sz = self.lengths_offset_per_key()[end] - lengths_start + + _lengths = torch.narrow( + self.lengths(), 0, lengths_start, lengths_sz + ) + split_list.append( + KeyedJaggedTensor( + keys=keys, + values=torch.narrow(self._values, 0, start_offset, sz), + weights=( + None + if self.weights_or_none() is None + else torch.narrow(self.weights(), 0, start_offset, sz) + ), + lengths=_lengths, + offsets=None, + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + length_per_key=split_length_per_key, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + inverse_indices=None, + ) + ) + else: + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + # Checks for dynamo dynamic shapes tracing + torch._check_is_size(start_offset) + torch._check_is_size(end_offset) + torch._check_is_size(end_offset - start_offset) + torch._check(start_offset <= self._values.size(0)) + torch._check(end_offset <= self._values.size(0)) + torch._check(end_offset >= start_offset) + + split_list.append( + KeyedJaggedTensor( + keys=keys, + values=self._values[start_offset:end_offset], + weights=( + None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset] + ), + lengths=self.lengths()[ + self.lengths_offset_per_key()[ + start + ] : self.lengths_offset_per_key()[end] + ], + offsets=None, + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + length_per_key=split_length_per_key, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + inverse_indices=None, + ) ) - ) start = end start_offset = end_offset return split_list