batch_size * lengths for export (#1864)
Summary: Pull Request resolved: #1864

Differential Revision: D55602984
PaulZhang12 authored and facebook-github-bot committed Apr 10, 2024
1 parent 579fe9f commit a7f9b3f
Showing 2 changed files with 47 additions and 8 deletions.
29 changes: 29 additions & 0 deletions torchrec/distributed/tests/test_pt2.py
@@ -216,6 +216,35 @@ def forward(self, kjt: KeyedJaggedTensor):
test_pt2_ir_export=True,
)

+    # pyre-ignore
+    def test_kjt__getitem__batch_size_lengths(self) -> None:
+        class M(torch.nn.Module):
+            def forward(self, kjt: KeyedJaggedTensor, sz: torch.Tensor):
+                # Have to manipulate the size here instead of the input
+                # for torch.export to record it as a SymInt
+                sz_int = sz.item()
+                new_kjt = KeyedJaggedTensor(
+                    keys=kjt.keys(),
+                    # batch_size sz; its value doesn't matter, it gets turned into a SymInt
+                    values=torch.ones(sz_int, dtype=torch.int32).repeat(256),
+                    lengths=torch.ones(sz_int, dtype=torch.int32).repeat(256),
+                )
+
+                return new_kjt["key0"]
+
+        # 3 keys for a compound expression for stride, (2 * length_numel) / 7
+        kjt: KeyedJaggedTensor = make_kjt([1] * 447, [1] * 447)
+        sz: torch.Tensor = torch.tensor([447], dtype=torch.int32)
+
+        self._test_kjt_input_module(
+            M(),
+            kjt.keys(),
+            (kjt._values, kjt._lengths, sz),
+            test_dynamo=False,
+            test_aot_inductor=False,
+            test_pt2_ir_export=True,
+        )
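A note on the sz.item() pattern above: under non-strict torch.export, reading a scalar out of an input tensor produces an unbacked SymInt, so any size derived from it stays symbolic in the exported graph instead of being burned in as a constant. A minimal standalone sketch of the same trick, assuming a recent torch with non-strict export (SymIntFromItem is a hypothetical module, not part of this commit):

import torch

class SymIntFromItem(torch.nn.Module):
    def forward(self, sz: torch.Tensor) -> torch.Tensor:
        n = sz.item()            # data-dependent scalar -> unbacked SymInt
        torch._check_is_size(n)  # record that n is a valid, non-negative size
        return torch.ones(n, dtype=torch.int32).repeat(4)

ep = torch.export.export(
    SymIntFromItem(), (torch.tensor([7], dtype=torch.int32),), strict=False
)
print(ep)  # output sizes print as expressions in u0, not as constants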

# pyre-ignores
@unittest.skipIf(
torch.cuda.device_count() <= 1,
26 changes: 18 additions & 8 deletions torchrec/sparse/jagged_tensor.py
@@ -14,6 +14,7 @@
import torch
from torch.autograd.profiler import record_function
from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
+from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node

from torchrec.streamable import Pipelineable
@@ -682,13 +683,14 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Tensor]]:
def _assert_tensor_has_no_elements_or_has_integers(
tensor: torch.Tensor, tensor_name: str
) -> None:
-    assert tensor.numel() == 0 or tensor.dtype in [
-        torch.long,
-        torch.int,
-        torch.short,
-        torch.int8,
-        torch.uint8,
-    ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype)
+    if not is_non_strict_exporting():
+        assert tensor.numel() == 0 or tensor.dtype in [
+            torch.long,
+            torch.int,
+            torch.short,
+            torch.int8,
+            torch.uint8,
+        ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype)


def _maybe_compute_index_per_key(
@@ -1298,6 +1300,12 @@ def __init__(
_assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
if lengths is not None:
_assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
+            if (
+                not torch.jit.is_scripting()
+                and is_non_strict_exporting()
+                and len(keys) > 0
+            ):
+                torch._check(lengths.numel() % len(keys) == 0)
self._lengths: Optional[torch.Tensor] = lengths
self._offsets: Optional[torch.Tensor] = offsets
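The point of the new divisibility check: a KJT's stride is derived by dividing the number of lengths by the number of keys, and with an unbacked batch size that division only simplifies if the shape environment knows the numerator is divisible by len(keys). torch._check records exactly that fact. A sketch of the mechanism in isolation, assuming non-strict export (StrideFromLengths is hypothetical, not torchrec code):

import torch

class StrideFromLengths(torch.nn.Module):
    NUM_KEYS = 3  # stand-in for len(keys)

    def forward(self, n: torch.Tensor) -> torch.Tensor:
        numel = n.item()                          # unbacked SymInt
        torch._check_is_size(numel)
        torch._check(numel % self.NUM_KEYS == 0)  # the fact recorded above
        stride = numel // self.NUM_KEYS           # now a clean symbolic division
        return torch.zeros(stride)

ep = torch.export.export(
    StrideFromLengths(), (torch.tensor([12]),), strict=False
)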

@@ -1427,7 +1435,9 @@ def concat(
elif stride is None:
stride = kjt.stride()
else:
-                assert stride == kjt.stride(), "strides must be consistent for all KJTs"
+                torch._check(
+                    stride == kjt.stride(), "strides must be consistent for all KJTs"
+                )

return KeyedJaggedTensor(
keys=keys,
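The same idea drives the change in concat: a Python assert comparing symbolic strides must be decided at trace time, where it either specializes the shapes or fails, while torch._check lowers the condition into the exported program as a runtime assertion. A hedged sketch of the swap (SameLength is illustrative, not torchrec code):

import torch

class SameLength(torch.nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # `assert a.shape[0] == b.shape[0]` would force a trace-time
        # decision; torch._check defers it to a graph-level assertion.
        torch._check(a.shape[0] == b.shape[0])
        return torch.cat([a, b])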
