batch_size * lengths for export #1864

Status: Open · wants to merge 1 commit into base: main
torchrec/distributed/tests/test_pt2.py — 29 additions, 0 deletions

```diff
@@ -216,6 +216,35 @@ def forward(self, kjt: KeyedJaggedTensor):
             test_pt2_ir_export=True,
         )
 
+    # pyre-ignore
+    def test_kjt__getitem__batch_size_lengths(self) -> None:
+        class M(torch.nn.Module):
+            def forward(self, kjt: KeyedJaggedTensor, sz: torch.Tensor):
+                # Have to manipulate the size here instead of the input
+                # for torch.export to record it as a SymInt
+                sz_int = sz.item()
+                new_kjt = KeyedJaggedTensor(
+                    keys=kjt.keys(),
+                    # batch_size sz; its value doesn't matter, it gets turned into a SymInt
+                    values=torch.ones(sz_int, dtype=torch.int32).repeat(256),
+                    lengths=torch.ones(sz_int, dtype=torch.int32).repeat(256),
+                )
+
+                return new_kjt["key0"]
+
+        # 3 keys for a compound expression for stride, (2 * length_numel) / 7
+        kjt: KeyedJaggedTensor = make_kjt([1] * 447, [1] * 447)
+        sz: torch.Tensor = torch.tensor([447], dtype=torch.int32)
+
+        self._test_kjt_input_module(
+            M(),
+            kjt.keys(),
+            (kjt._values, kjt._lengths, sz),
+            test_dynamo=False,
+            test_aot_inductor=False,
+            test_pt2_ir_export=True,
+        )
+
     # pyre-ignores
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
```
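The test leans on a `torch.export` behavior worth spelling out: calling `.item()` on a tensor input during export records the result as a data-dependent SymInt rather than a concrete Python int, so every shape derived from it (here the KJT `values`/`lengths`) stays symbolic in the exported graph. A minimal standalone sketch of that mechanism, assuming PyTorch 2.2+ non-strict export semantics (the `Repeat` module and the `torch._check_is_size` hint are illustrative, not part of the PR):

```python
import torch


class Repeat(torch.nn.Module):
    def forward(self, sz: torch.Tensor) -> torch.Tensor:
        n = sz.item()  # recorded as an unbacked SymInt (u0) under export
        torch._check_is_size(n)  # hint: n is a valid, non-negative size
        # Shapes derived from n stay symbolic, mirroring how the test
        # builds the KJT values/lengths from sz instead of fixed ints.
        return torch.ones(n, dtype=torch.int32).repeat(4)


# Non-strict export, matching the test's test_pt2_ir_export path.
ep = torch.export.export(Repeat(), (torch.tensor([3]),), strict=False)
print(ep)  # graph carries u0-sized tensors rather than size-3 constants
```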
torchrec/sparse/jagged_tensor.py — 42 additions, 16 deletions

```diff
@@ -14,6 +14,7 @@
 import torch
 from torch.autograd.profiler import record_function
 from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
+from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
 
 from torchrec.streamable import Pipelineable
```
```diff
@@ -682,13 +683,14 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Tensor]]:
 def _assert_tensor_has_no_elements_or_has_integers(
     tensor: torch.Tensor, tensor_name: str
 ) -> None:
-    assert tensor.numel() == 0 or tensor.dtype in [
-        torch.long,
-        torch.int,
-        torch.short,
-        torch.int8,
-        torch.uint8,
-    ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype)
+    if not is_non_strict_exporting():
+        assert tensor.numel() == 0 or tensor.dtype in [
+            torch.long,
+            torch.int,
+            torch.short,
+            torch.int8,
+            torch.uint8,
+        ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype)
 
 
 def _maybe_compute_index_per_key(
```
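The gating matters because `tensor.numel() == 0` gets evaluated on tensors whose sizes may be unbacked SymInts during non-strict export, and converting that comparison to a Python bool forces a guard the compiler cannot resolve either way. A hedged sketch of the same skip-validation-while-tracing pattern, using the public `torch.compiler.is_compiling()` (PyTorch 2.2+) as a stand-in for torchrec's `is_non_strict_exporting()` helper:

```python
import torch


def validate_lengths(lengths: torch.Tensor) -> None:
    # Skip data-dependent validation while tracing/exporting: evaluating
    # "numel == 0" would ask the symbolic shape system to decide a
    # condition on an unbacked size that it cannot prove either way.
    if torch.compiler.is_compiling():
        return
    assert lengths.numel() == 0 or lengths.dtype in (
        torch.long,
        torch.int,
        torch.short,
        torch.int8,
        torch.uint8,
    ), f"lengths must be of integer type, but got {lengths.dtype}"
```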
```diff
@@ -798,15 +800,26 @@ def _maybe_compute_length_per_key(
                 else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
             )
         elif len(keys) and lengths is not None:
-            _length: List[int] = (
-                _length_per_key_from_stride_per_key(lengths, stride_per_key)
-                if variable_stride_per_key
-                else (
-                    torch.sum(lengths.view(-1, stride), dim=1).tolist()
-                    if lengths.numel() != 0
-                    else [0] * len(keys)
-                )
-            )
+            if not torch.jit.is_scripting() and is_non_strict_exporting():
+                _length: List[int] = (
+                    _length_per_key_from_stride_per_key(lengths, stride_per_key)
+                    if variable_stride_per_key
+                    else (
+                        torch.sum(lengths.view(-1, stride), dim=1).tolist()
+                        if guard_size_oblivious(lengths.numel() != 0)
+                        else [0] * len(keys)
+                    )
+                )
+            else:
+                _length: List[int] = (
+                    _length_per_key_from_stride_per_key(lengths, stride_per_key)
+                    if variable_stride_per_key
+                    else (
+                        torch.sum(lengths.view(-1, stride), dim=1).tolist()
+                        if lengths.numel() != 0
+                        else [0] * len(keys)
+                    )
+                )
         else:
             _length: List[int] = []
         length_per_key = _length
```
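`guard_size_oblivious` is the key ingredient in this hunk. A plain `lengths.numel() != 0` makes the tracer pick a branch by installing a guard on the (possibly unbacked) numel; `guard_size_oblivious` evaluates the same boolean under the size-oblivious assumption that unbacked sizes are never 0 or 1, so export follows the general `torch.sum` branch without specializing. In eager mode it simply returns the real boolean, which is why the `[0] * len(keys)` fallback still works outside export. A minimal sketch (function and argument names are illustrative):

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious


def sum_per_key(lengths: torch.Tensor, stride: int, num_keys: int) -> list:
    # Under export, assume the unbacked numel is not 0 or 1 and take the
    # general branch; in eager mode this is an ordinary bool check.
    if guard_size_oblivious(lengths.numel() != 0):
        return torch.sum(lengths.view(-1, stride), dim=1).tolist()
    return [0] * num_keys
```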
```diff
@@ -1298,6 +1311,12 @@ def __init__(
             _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
         if lengths is not None:
             _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
+            if (
+                not torch.jit.is_scripting()
+                and is_non_strict_exporting()
+                and len(keys) > 0
+            ):
+                torch._check(lengths.numel() % len(keys) == 0)
         self._lengths: Optional[torch.Tensor] = lengths
         self._offsets: Optional[torch.Tensor] = offsets
 
```
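The new `torch._check` is an ordinary runtime assert in eager mode, but during export it also records a divisibility fact about the unbacked `lengths.numel()` that the symbolic shape system can use later, e.g. to prove that a `view(-1, stride)` over the lengths splits evenly. A hedged sketch of that idea (the helper and its names are assumptions, not torchrec API):

```python
import torch


def reshape_by_keys(lengths: torch.Tensor, num_keys: int) -> torch.Tensor:
    # Recorded as a constraint on the symbolic numel (u0 % num_keys == 0),
    # which lets export reason about the -1 dimension in the view below.
    torch._check(lengths.numel() % num_keys == 0)
    return lengths.view(num_keys, -1)
```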

```diff
@@ -1427,7 +1446,14 @@ def concat(
         elif stride is None:
             stride = kjt.stride()
         else:
-            assert stride == kjt.stride(), "strides must be consistent for all KJTs"
+            if not torch.jit.is_scripting():
+                # torch._check fails under jit scripting: Unknown builtin op: aten::_check.
+                torch._check(
+                    stride == kjt.stride(),
+                    "strides must be consistent for all KJTs",
+                )
+            else:
+                assert stride == kjt.stride()
 
     return KeyedJaggedTensor(
         keys=keys,
```
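The scripting branch exists because `torch._check` is not a TorchScript builtin; scripting it fails with `Unknown builtin op: aten::_check`, hence the plain `assert` fallback. One caveat worth flagging for review: the `torch._check` docs describe the message argument as a lazily evaluated callable, so a sketch of this guard would pass a lambda rather than the bare string used in the diff:

```python
import torch


def check_stride_consistent(stride: int, other: int) -> None:
    if not torch.jit.is_scripting():
        # Also records the equality as a fact for the traced graph.
        torch._check(
            stride == other,
            lambda: "strides must be consistent for all KJTs",
        )
    else:
        # TorchScript fallback: aten::_check is not a builtin op.
        assert stride == other, "strides must be consistent for all KJTs"
```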