Skip to content

Commit

Permalink
KJT split torch export support (#1816)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1816

torch.export support for KJT split

Reviewed By: IvanKobzarev

Differential Revision: D53545161

fbshipit-source-id: 7afbf8843479873251333ef7cffdb4081a12cb06
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Mar 22, 2024
1 parent 63823f7 commit d9ef776
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 33 deletions.
8 changes: 5 additions & 3 deletions torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,18 @@ def _test_kjt_input_module(

def test_kjt_split(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor, segments: List[int]):
return kjt.split(segments)
def forward(self, kjt: KeyedJaggedTensor):
return kjt.split([1, 2, 1])

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
segments: List[int] = [1, 2, 1]
self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths, segments),
(kjt._values, kjt._lengths),
test_aot_inductor=False,
test_dynamo=False,
test_pt2_ir_export=True,
)

def test_kjt_permute(self) -> None:
Expand Down
97 changes: 67 additions & 30 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,39 +1748,76 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
else:
split_length_per_key = _length_per_key[start:end]

if not torch.jit.is_scripting() and is_torchdynamo_compiling():
# Checks for dynamo dynamic shapes tracing
torch._check_is_size(start_offset)
torch._check_is_size(end_offset)
torch._check_is_size(end_offset - start_offset)
if not torch.jit.is_scripting() and is_non_strict_exporting():
sz = sum(split_length_per_key)

[torch._check_is_size(length) for length in split_length_per_key]
torch._check(start_offset <= self._values.size(0))
torch._check(end_offset <= self._values.size(0))
torch._check(end_offset >= start_offset)
torch._check(sz <= self._values.size(0))
torch._check_is_size(start_offset)

split_list.append(
KeyedJaggedTensor(
keys=keys,
values=self._values[start_offset:end_offset],
weights=(
None
if self.weights_or_none() is None
else self.weights()[start_offset:end_offset]
),
lengths=self.lengths()[
self.lengths_offset_per_key()[
start
] : self.lengths_offset_per_key()[end]
],
offsets=None,
stride=stride,
stride_per_key_per_rank=stride_per_key_per_rank,
length_per_key=split_length_per_key,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
inverse_indices=None,
torch._check(start_offset + sz <= self._values.size(0))

lengths_start = self.lengths_offset_per_key()[start]
lengths_sz = self.lengths_offset_per_key()[end] - lengths_start

_lengths = torch.narrow(
self.lengths(), 0, lengths_start, lengths_sz
)
split_list.append(
KeyedJaggedTensor(
keys=keys,
values=torch.narrow(self._values, 0, start_offset, sz),
weights=(
None
if self.weights_or_none() is None
else torch.narrow(self.weights(), 0, start_offset, sz)
),
lengths=_lengths,
offsets=None,
stride=stride,
stride_per_key_per_rank=stride_per_key_per_rank,
length_per_key=split_length_per_key,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
inverse_indices=None,
)
)
else:
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
# Checks for dynamo dynamic shapes tracing
torch._check_is_size(start_offset)
torch._check_is_size(end_offset)
torch._check_is_size(end_offset - start_offset)
torch._check(start_offset <= self._values.size(0))
torch._check(end_offset <= self._values.size(0))
torch._check(end_offset >= start_offset)

split_list.append(
KeyedJaggedTensor(
keys=keys,
values=self._values[start_offset:end_offset],
weights=(
None
if self.weights_or_none() is None
else self.weights()[start_offset:end_offset]
),
lengths=self.lengths()[
self.lengths_offset_per_key()[
start
] : self.lengths_offset_per_key()[end]
],
offsets=None,
stride=stride,
stride_per_key_per_rank=stride_per_key_per_rank,
length_per_key=split_length_per_key,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
inverse_indices=None,
)
)
)
start = end
start_offset = end_offset
return split_list
Expand Down

0 comments on commit d9ef776

Please sign in to comment.