Skip to content

Commit

Permalink
[inductor] Optimize away zero-element loads (pytorch#107074)
Browse files Browse the repository at this point in the history
Fixes pytorch#107066, closes pytorch#107008

This replaces loads to zero-element `Loops` or `Buffer`s with `ops.constant`
calls. This both avoids the issue of masked loads under triton, and also means
the buffer is not listed as a dependency for downstream users which may improve
performance generally.

Pull Request resolved: pytorch#107074
Approved by: https://github.com/davidberard98
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Aug 12, 2023
1 parent aa36e16 commit 8472c24
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 5 deletions.
145 changes: 145 additions & 0 deletions test/inductor/test_custom_lowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Owner(s): ["module: inductor"]

import unittest

import torch

from torch._inductor.ir import Pointwise
from torch._inductor.lowering import register_lowering
from torch._inductor.virtualized import ops

from torch.testing._internal.common_utils import TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


# These tests check issues for lowerings that aren't in the main pytorch repo
class TestCustomLowering(TorchTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.test_inductor_ops = torch.library.Library("test_inductor_ops", "DEF")
cls.impl_cuda = torch.library.Library("test_inductor_ops", "IMPL", "CUDA")
cls.impl_meta = torch.library.Library("test_inductor_ops", "IMPL", "Meta")
cls._register_jagged_to_padded_dense()

@classmethod
def tearDown(cls):
super().tearDownClass()

@classmethod
def _register_jagged_to_padded_dense(cls):
# Approximation of fbgemm.jagged_to_padded_dense_forward
cls.test_inductor_ops.define(
"jagged_to_padded_dense(Tensor input, Tensor offsets, SymInt max_seq_len, Scalar pad_value) -> Tensor"
)

def j2pd_meta(inp, offsets, max_seq_len, pad_value):
return torch.empty(
(offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
device=inp.device,
dtype=inp.dtype,
)

def j2pd_cuda(inp, offsets, max_seq_len, pad_value):
res = torch.full(
(offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
pad_value,
device=inp.device,
dtype=inp.dtype,
)
for b in range(offsets.shape[0] - 1):
for r in range(offsets[b + 1] - offsets[b]):
res[b][r] = inp[offsets[b] + r]
return res

def j2pd_lowering(inp, offsets, max_seq_len, pad_value):
offsets_loader = offsets.make_loader()
inp_loader = inp.make_loader()
jagged_len = inp.get_size()[0]
offsets_dtype = offsets.get_dtype()

def inner_fn(index):
batch_idx, seq_idx, emb_idx = index

begin_idx = ops.indirect_indexing(
offsets_loader([batch_idx]),
jagged_len + 1,
)
end_idx = offsets_loader([batch_idx + 1])
jagged_idx = begin_idx + seq_idx

return ops.masked(
ops.lt(
ops.index_expr(jagged_idx, offsets_dtype),
end_idx,
),
lambda: inp_loader([jagged_idx, emb_idx]),
pad_value,
)

return Pointwise.create(
device=inp.get_device(),
dtype=inp.get_dtype(),
inner_fn=inner_fn,
ranges=[offsets.get_size()[0] - 1, max_seq_len, inp.get_size()[1]],
)

register_lowering(
torch.ops.test_inductor_ops.jagged_to_padded_dense, type_promotion_kind=None
)(j2pd_lowering)

cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta)
cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda)

@unittest.skipIf(not HAS_CUDA, "CUDA needed")
def test_jagged_to_padded_dense_sanity_cuda(self):
def fn(inp, offsets, max_seq_len):
return torch.ops.test_inductor_ops.jagged_to_padded_dense(
inp, offsets, max_seq_len, 60.0
)

inp = torch.rand((9, 96), device="cuda")
offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda")
max_seq_len = 4

res = fn(inp, offsets, max_seq_len)
self.assertEqual(inp[0], res[0][0])
self.assertEqual(inp[1], res[0][1])
self.assertEqual(inp[2], res[1][0])
self.assertEqual(inp[3], res[1][1])
self.assertEqual(inp[5], res[2][0])
self.assertEqual(inp[8], res[2][3])

fn_opt = torch.compile(fn)

self.assertEqual(
fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
)

@unittest.skipIf(not HAS_CUDA, "CUDA needed")
def test_jagged_to_padded_dense_zero_size(self):
# Previously, the masking was being completely stripped for the
# masked load of the input value. That would lead to an IMA
# because cuda was trying to read index 0 of a zero-size tensor.
def fn(inp, offsets, max_seq_len):
inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1))
return torch.ops.test_inductor_ops.jagged_to_padded_dense(
inp, offsets, max_seq_len, 60.0
)

inp = torch.rand((1, 0, 96), device="cuda")
offsets = torch.zeros(1025, device="cuda", dtype=torch.int32)
max_seq_len = 20

fn_opt = torch.compile(fn)

self.assertEqual(
fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

if HAS_CPU or HAS_CUDA:
run_tests(needs="filelock")
30 changes: 25 additions & 5 deletions torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ def get_read_names(self):
def get_numel(self):
return sympy_product(self.get_size())

def is_zero_elements(self):
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))

def realize(self):
"""
If the IRNode refers to data which has not been materialized (e.g.,
Expand Down Expand Up @@ -386,9 +389,6 @@ def inner_fn_str(self):
index = self._index(self.ranges)
return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index)

def is_zero_elements(self):
return any(r == 0 for r in self.ranges)

def get_reads(self):
with patch.object(FlexibleLayout, "allow_indexing", True):
if self.get_reduction_type():
Expand All @@ -404,8 +404,19 @@ def get_reads(self):
).reads


def nop_loader_fn(idx, *, dtype):
if dtype.is_floating_point:
return ops.constant(float("nan"), dtype)
else:
return ops.constant(0, dtype)


class Pointwise(Loops):
def make_loader(self):
# Make zero-element loops into a no-op
if self.is_zero_elements():
return partial(nop_loader_fn, dtype=self.dtype)

return self.inner_fn

def get_reduction_size(self):
Expand All @@ -415,7 +426,8 @@ def get_reduction_type(self):
return None

def store_output(self, output_name, indexer, vars):
return ops.store(output_name, indexer(vars), self.inner_fn(vars))
loader = self.make_loader()
return ops.store(output_name, indexer(vars), loader(vars))

def constant_to_device(self, device):
"""Move this to a given device. Requires that all reads are to constants."""
Expand Down Expand Up @@ -443,10 +455,11 @@ def constant_to_device(self, device):
)

def store_output(self, output_name, indexer, vars):
loader = self.make_loader()
return ops.store(
output_name,
indexer(self.output_indexer(vars)),
self.inner_fn(vars),
loader(vars),
mode=self.scatter_mode,
)

Expand Down Expand Up @@ -2088,7 +2101,14 @@ def freeze_layout_with_same_order(self, stride):
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_same_order(stride)

def is_zero_elements(self):
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))

def make_loader(self):
# Loading from a zero-element buffer is a no-op
if self.is_zero_elements():
return partial(nop_loader_fn, dtype=self.get_dtype())

def loader(index):
indexer = self.layout.make_indexer()
return ops.load(self.name, indexer(index))
Expand Down

0 comments on commit 8472c24

Please sign in to comment.