[inductor] Optimize away zero-element loads (pytorch#107074)
Fixes pytorch#107066, closes pytorch#107008.

This replaces loads from zero-element `Loops` or `Buffer`s with `ops.constant` calls. This both avoids the issue of masked loads under Triton and means the buffer is no longer listed as a dependency for downstream users, which may improve performance generally.

Pull Request resolved: pytorch#107074
Approved by: https://github.com/davidberard98
1 parent aa36e16 · commit 8472c24
Showing 2 changed files with 170 additions and 5 deletions.
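The diff below adds a regression test for the custom-lowering path. For orientation, here is a minimal standalone sketch of the transformation the commit message describes; it is not inductor's actual code, and the names (lower_load, emit_load, dtype_zero) are hypothetical:

import math

def lower_load(buffer_shape, dtype_zero, emit_load):
    # Hypothetical helper, for illustration only: inductor's real lowering
    # replaces the load IR node with ops.constant, not a Python value.
    # buffer_shape: shape of the buffer being read, e.g. (0, 1)
    # dtype_zero: the zero of the buffer's dtype, e.g. 0.0 for float32
    # emit_load: callable that would emit the ordinary (possibly masked) load
    if math.prod(buffer_shape) == 0:
        # Zero elements: the load can never observe real data, so a plain
        # constant is safe and removes the buffer from the dependency graph.
        return dtype_zero
    return emit_load()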
@@ -0,0 +1,145 @@
# Owner(s): ["module: inductor"]

import unittest

import torch

from torch._inductor.ir import Pointwise
from torch._inductor.lowering import register_lowering
from torch._inductor.virtualized import ops

from torch.testing._internal.common_utils import TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

# These tests check issues for lowerings that aren't in the main pytorch repo
class TestCustomLowering(TorchTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
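        # "DEF" defines the operator schema in the test_inductor_ops
        # namespace; the "IMPL" libraries register CUDA and Meta kernels.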
        cls.test_inductor_ops = torch.library.Library("test_inductor_ops", "DEF")
        cls.impl_cuda = torch.library.Library("test_inductor_ops", "IMPL", "CUDA")
        cls.impl_meta = torch.library.Library("test_inductor_ops", "IMPL", "Meta")
        cls._register_jagged_to_padded_dense()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

    @classmethod
    def _register_jagged_to_padded_dense(cls):
        # Approximation of fbgemm.jagged_to_padded_dense_forward
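        # `input` is (total_rows, emb_dim) with all batches' rows concatenated;
        # `offsets` has batch_size + 1 entries marking each batch's row range.
        # The output is (batch_size, max_seq_len, emb_dim), filled with
        # pad_value beyond each batch's actual length.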
        cls.test_inductor_ops.define(
            "jagged_to_padded_dense(Tensor input, Tensor offsets, SymInt max_seq_len, Scalar pad_value) -> Tensor"
        )

        def j2pd_meta(inp, offsets, max_seq_len, pad_value):
            return torch.empty(
                (offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
                device=inp.device,
                dtype=inp.dtype,
            )

        def j2pd_cuda(inp, offsets, max_seq_len, pad_value):
            res = torch.full(
                (offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
                pad_value,
                device=inp.device,
                dtype=inp.dtype,
            )
            for b in range(offsets.shape[0] - 1):
                for r in range(offsets[b + 1] - offsets[b]):
                    res[b][r] = inp[offsets[b] + r]
            return res

        def j2pd_lowering(inp, offsets, max_seq_len, pad_value):
            offsets_loader = offsets.make_loader()
            inp_loader = inp.make_loader()
            jagged_len = inp.get_size()[0]
            offsets_dtype = offsets.get_dtype()

            def inner_fn(index):
                batch_idx, seq_idx, emb_idx = index

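                # ops.indirect_indexing converts the loaded offset value into
                # a symbolic index, with jagged_len + 1 as its upper bound.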
                begin_idx = ops.indirect_indexing(
                    offsets_loader([batch_idx]),
                    jagged_len + 1,
                )
                end_idx = offsets_loader([batch_idx + 1])
                jagged_idx = begin_idx + seq_idx

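                # Only load when jagged_idx is within this batch's range;
                # out-of-range positions yield pad_value instead of a load.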
                return ops.masked(
                    ops.lt(
                        ops.index_expr(jagged_idx, offsets_dtype),
                        end_idx,
                    ),
                    lambda: inp_loader([jagged_idx, emb_idx]),
                    pad_value,
                )

            return Pointwise.create(
                device=inp.get_device(),
                dtype=inp.get_dtype(),
                inner_fn=inner_fn,
                ranges=[offsets.get_size()[0] - 1, max_seq_len, inp.get_size()[1]],
            )

        register_lowering(
            torch.ops.test_inductor_ops.jagged_to_padded_dense, type_promotion_kind=None
        )(j2pd_lowering)

        cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta)
        cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda)

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    def test_jagged_to_padded_dense_sanity_cuda(self):
        def fn(inp, offsets, max_seq_len):
            return torch.ops.test_inductor_ops.jagged_to_padded_dense(
                inp, offsets, max_seq_len, 60.0
            )

        inp = torch.rand((9, 96), device="cuda")
        offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda")
        max_seq_len = 4

        res = fn(inp, offsets, max_seq_len)
        self.assertEqual(inp[0], res[0][0])
        self.assertEqual(inp[1], res[0][1])
        self.assertEqual(inp[2], res[1][0])
        self.assertEqual(inp[3], res[1][1])
        self.assertEqual(inp[5], res[2][0])
        self.assertEqual(inp[8], res[2][3])

        fn_opt = torch.compile(fn)

        self.assertEqual(
            fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
        )

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    def test_jagged_to_padded_dense_zero_size(self):
        # Previously, the masking was being completely stripped for the
        # masked load of the input value. That would lead to an IMA
        # because cuda was trying to read index 0 of a zero-size tensor.
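        # The bmm/view chain produces a zero-element (0, 1) tensor inside the
        # compiled graph, so the custom lowering's masked load reads from a
        # zero-size buffer.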
        def fn(inp, offsets, max_seq_len):
            inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1))
            return torch.ops.test_inductor_ops.jagged_to_padded_dense(
                inp, offsets, max_seq_len, 60.0
            )

        inp = torch.rand((1, 0, 96), device="cuda")
        offsets = torch.zeros(1025, device="cuda", dtype=torch.int32)
        max_seq_len = 20

        fn_opt = torch.compile(fn)

        self.assertEqual(
            fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")