KJT permute torch export support (pytorch#1850)
Summary:

Support non-strict torch.export for the KeyedJaggedTensor (KJT) permute method used by sharded TorchRec modules.
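
As context (not part of this commit), here is a minimal sketch of the call pattern the change targets: tracing KeyedJaggedTensor.permute under non-strict torch.export. The PermuteModule wrapper, feature keys, and tensor values below are illustrative assumptions, and the sketch presumes KJT inputs are pytree-registered for export in this torchrec version.

from typing import List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class PermuteModule(torch.nn.Module):
    # Hypothetical wrapper: permute the KJT's keys and return the values tensor.
    def forward(self, kjt: KeyedJaggedTensor, indices: List[int]) -> torch.Tensor:
        return kjt.permute(indices).values()


kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([0, 1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1, 0, 1]),
)

# strict=False selects the non-strict export path, the mode that
# is_non_strict_exporting() detects inside permute().
ep = torch.export.export(PermuteModule(), (kjt, [2, 0, 1]), strict=False)
print(ep.graph_module.code)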

Differential Revision: D55040353
PaulZhang12 authored and facebook-github-bot committed Apr 8, 2024
1 parent f43660b commit c5af320
Showing 2 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/tests/test_pt2.py
@@ -163,6 +163,7 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
             kjt.keys(),
             (kjt._values, kjt._lengths, indices),
             test_aot_inductor=False,
+            test_pt2_ir_export=True,
         )

     def test_kjt_length_per_key(self) -> None:
21 changes: 18 additions & 3 deletions torchrec/sparse/jagged_tensor.py
@@ -1831,15 +1831,16 @@ def permute(
         permuted_keys: List[str] = []
         permuted_stride_per_key_per_rank: List[List[int]] = []
         permuted_length_per_key: List[int] = []
-        permuted_lengths_sum = 0
+        permuted_length_per_key_sum = 0
         for index in indices:
             key = self.keys()[index]
             permuted_keys.append(key)
             permuted_stride_per_key_per_rank.append(
                 self.stride_per_key_per_rank()[index]
             )
             permuted_length_per_key.append(length_per_key[index])
-            permuted_lengths_sum += length_per_key[index]
+            if not is_non_strict_exporting():
+                permuted_length_per_key_sum += length_per_key[index]
         if self.variable_stride_per_key():
             length_per_key_tensor = _pin_and_move(
                 torch.tensor(self.length_per_key()), self.device()
@@ -1860,6 +1861,20 @@
                 self.weights_or_none(),
             )
         else:
+            if not torch.jit.is_scripting() and is_non_strict_exporting():
+                permuted_length_per_key_sum = torch.sum(
+                    torch._refs.tensor(
+                        permuted_length_per_key,
+                        dtype=torch.int32,
+                        device=torch.device("cpu"),
+                        pin_memory=False,
+                        requires_grad=False,
+                    )
+                ).item()
+
+                torch._check(permuted_length_per_key_sum >= 0)
+                torch._check(permuted_length_per_key_sum != 0)
+
             (
                 permuted_lengths,
                 permuted_values,
@@ -1869,7 +1884,7 @@
                 self.lengths().view(len(self._keys), -1),
                 self.values(),
                 self.weights_or_none(),
-                permuted_lengths_sum,
+                permuted_length_per_key_sum,
             )
             stride, optional_permuted_stride_per_key_per_rank = (
                 (None, permuted_stride_per_key_per_rank)
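
The key move in the second hunk is computing permuted_length_per_key_sum as a tensor reduction followed by .item(), which yields an unbacked symbolic int under non-strict export rather than a Python constant burned into the graph; the two torch._check calls then bound that symbol (non-negative and non-zero, i.e. strictly positive) so export can validate its use as the output-size argument downstream. Below is a standalone sketch of that pattern, using plain torch.tensor where the commit uses torch._refs.tensor (presumably so the constructor traces cleanly with fake tensors; that reading is an assumption) and illustrative lengths.

import torch

# Illustrative per-key lengths; in the commit this is permuted_length_per_key.
permuted_length_per_key = [2, 1, 3]

# Summing via a tensor and calling .item() produces an unbacked SymInt under
# non-strict export instead of a hard-coded Python int.
permuted_length_per_key_sum = torch.sum(
    torch.tensor(permuted_length_per_key, dtype=torch.int32)
).item()

# The checks attach bounds to the data-dependent value so the exporter can
# reason about it; run eagerly, they are plain assertions.
torch._check(permuted_length_per_key_sum >= 0)
torch._check(permuted_length_per_key_sum != 0)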
