diff --git a/test/inductor/test_custom_lowering.py b/test/inductor/test_custom_lowering.py new file mode 100644 index 0000000000000..6740108298fa9 --- /dev/null +++ b/test/inductor/test_custom_lowering.py @@ -0,0 +1,145 @@ +# Owner(s): ["module: inductor"] + +import unittest + +import torch + +from torch._inductor.ir import Pointwise +from torch._inductor.lowering import register_lowering +from torch._inductor.virtualized import ops + +from torch.testing._internal.common_utils import TestCase as TorchTestCase +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA + + +# These tests check issues for lowerings that aren't in the main pytorch repo +class TestCustomLowering(TorchTestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.test_inductor_ops = torch.library.Library("test_inductor_ops", "DEF") + cls.impl_cuda = torch.library.Library("test_inductor_ops", "IMPL", "CUDA") + cls.impl_meta = torch.library.Library("test_inductor_ops", "IMPL", "Meta") + cls._register_jagged_to_padded_dense() + + @classmethod + def tearDown(cls): + super().tearDownClass() + + @classmethod + def _register_jagged_to_padded_dense(cls): + # Approximation of fbgemm.jagged_to_padded_dense_forward + cls.test_inductor_ops.define( + "jagged_to_padded_dense(Tensor input, Tensor offsets, SymInt max_seq_len, Scalar pad_value) -> Tensor" + ) + + def j2pd_meta(inp, offsets, max_seq_len, pad_value): + return torch.empty( + (offsets.shape[0] - 1, max_seq_len, inp.shape[1]), + device=inp.device, + dtype=inp.dtype, + ) + + def j2pd_cuda(inp, offsets, max_seq_len, pad_value): + res = torch.full( + (offsets.shape[0] - 1, max_seq_len, inp.shape[1]), + pad_value, + device=inp.device, + dtype=inp.dtype, + ) + for b in range(offsets.shape[0] - 1): + for r in range(offsets[b + 1] - offsets[b]): + res[b][r] = inp[offsets[b] + r] + return res + + def j2pd_lowering(inp, offsets, max_seq_len, pad_value): + offsets_loader = offsets.make_loader() + inp_loader = inp.make_loader() + jagged_len = inp.get_size()[0] + offsets_dtype = offsets.get_dtype() + + def inner_fn(index): + batch_idx, seq_idx, emb_idx = index + + begin_idx = ops.indirect_indexing( + offsets_loader([batch_idx]), + jagged_len + 1, + ) + end_idx = offsets_loader([batch_idx + 1]) + jagged_idx = begin_idx + seq_idx + + return ops.masked( + ops.lt( + ops.index_expr(jagged_idx, offsets_dtype), + end_idx, + ), + lambda: inp_loader([jagged_idx, emb_idx]), + pad_value, + ) + + return Pointwise.create( + device=inp.get_device(), + dtype=inp.get_dtype(), + inner_fn=inner_fn, + ranges=[offsets.get_size()[0] - 1, max_seq_len, inp.get_size()[1]], + ) + + register_lowering( + torch.ops.test_inductor_ops.jagged_to_padded_dense, type_promotion_kind=None + )(j2pd_lowering) + + cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta) + cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda) + + @unittest.skipIf(not HAS_CUDA, "CUDA needed") + def test_jagged_to_padded_dense_sanity_cuda(self): + def fn(inp, offsets, max_seq_len): + return torch.ops.test_inductor_ops.jagged_to_padded_dense( + inp, offsets, max_seq_len, 60.0 + ) + + inp = torch.rand((9, 96), device="cuda") + offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda") + max_seq_len = 4 + + res = fn(inp, offsets, max_seq_len) + self.assertEqual(inp[0], res[0][0]) + self.assertEqual(inp[1], res[0][1]) + self.assertEqual(inp[2], res[1][0]) + self.assertEqual(inp[3], res[1][1]) + self.assertEqual(inp[5], res[2][0]) + self.assertEqual(inp[8], res[2][3]) + + fn_opt = torch.compile(fn) + + self.assertEqual( + fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len) + ) + + @unittest.skipIf(not HAS_CUDA, "CUDA needed") + def test_jagged_to_padded_dense_zero_size(self): + # Previously, the masking was being completely stripped for the + # masked load of the input value. That would lead to an IMA + # because cuda was trying to read index 0 of a zero-size tensor. + def fn(inp, offsets, max_seq_len): + inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1)) + return torch.ops.test_inductor_ops.jagged_to_padded_dense( + inp, offsets, max_seq_len, 60.0 + ) + + inp = torch.rand((1, 0, 96), device="cuda") + offsets = torch.zeros(1025, device="cuda", dtype=torch.int32) + max_seq_len = 20 + + fn_opt = torch.compile(fn) + + self.assertEqual( + fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len) + ) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + if HAS_CPU or HAS_CUDA: + run_tests(needs="filelock") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0117039256a05..dbced19ce2726 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -302,6 +302,9 @@ def get_read_names(self): def get_numel(self): return sympy_product(self.get_size()) + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) + def realize(self): """ If the IRNode refers to data which has not been materialized (e.g., @@ -386,9 +389,6 @@ def inner_fn_str(self): index = self._index(self.ranges) return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index) - def is_zero_elements(self): - return any(r == 0 for r in self.ranges) - def get_reads(self): with patch.object(FlexibleLayout, "allow_indexing", True): if self.get_reduction_type(): @@ -404,8 +404,19 @@ def get_reads(self): ).reads +def nop_loader_fn(idx, *, dtype): + if dtype.is_floating_point: + return ops.constant(float("nan"), dtype) + else: + return ops.constant(0, dtype) + + class Pointwise(Loops): def make_loader(self): + # Make zero-element loops into a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.dtype) + return self.inner_fn def get_reduction_size(self): @@ -415,7 +426,8 @@ def get_reduction_type(self): return None def store_output(self, output_name, indexer, vars): - return ops.store(output_name, indexer(vars), self.inner_fn(vars)) + loader = self.make_loader() + return ops.store(output_name, indexer(vars), loader(vars)) def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" @@ -443,10 +455,11 @@ def constant_to_device(self, device): ) def store_output(self, output_name, indexer, vars): + loader = self.make_loader() return ops.store( output_name, indexer(self.output_indexer(vars)), - self.inner_fn(vars), + loader(vars), mode=self.scatter_mode, ) @@ -2088,7 +2101,14 @@ def freeze_layout_with_same_order(self, stride): assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_same_order(stride) + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) + def make_loader(self): + # Loading from a zero-element buffer is a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.get_dtype()) + def loader(index): indexer = self.layout.make_indexer() return ops.load(self.name, indexer(index))