Skip to content

Commit

Permalink
Add missing evaluate_expr for slice_scatter, slight refactor (pytorch…
Browse files Browse the repository at this point in the history
…#105714)

The substantive change is adding slice_scatter to use evaluate_expr
(and I add a test for it).

While I'm at it, I do some cleanup: provide sizevars.evaluate_expr
directly, and rewrite all sites to use it consistently.

Fixes pytorch#105524

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#105714
Approved by: https://github.com/Skylion007
  • Loading branch information
ezyang authored and pytorchmergebot committed Jul 22, 2023
1 parent f5def50 commit 53a4b26
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
13 changes: 13 additions & 0 deletions test/inductor/test_torchinductor_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,19 @@ def pad_same(x, k, s, d=(1, 1), value=0):
ref = pad_same(x, (5, 5), (2, 2))
self.assertEqual(res, ref, atol=0, rtol=0)

def test_slice_scatter(self, device):
    # Regression test for pytorch#105524: slice_scatter with start/end
    # offsets derived from a dynamic size must guard via evaluate_expr.
    def fn(inp):
        n = inp.size(0)
        base = torch.ones(64, n, device=device)
        src = torch.ones(64, n // 2, device=device)
        # Scatter src into columns [n//2, 2*(n//2)) along dim 1.
        return torch.slice_scatter(base, src, 1, n // 2, 2 * (n // 2))

    arg = torch.randn(16, device=device)
    compiled = self.compile_fn(fn)
    expect = fn(arg)
    actual = compiled(arg)
    self.assertEqual(expect, actual)

def test_slice_index_changing_sign(self, device):
def fn(x, y):
y0, y1 = y.shape
Expand Down
15 changes: 7 additions & 8 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def squeeze(x, dim=None):

new_shape = []
for d, s in enumerate(x.get_size()):
if not (d in dims and V.graph.sizevars.shape_env.evaluate_expr(sympy.Eq(s, 1))):
if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))):
new_shape.append(s)

# squeeze does nothing if the size isn't 1
Expand Down Expand Up @@ -838,9 +838,9 @@ def slice_(x, dim=0, start=0, end=2**63, step=1):
assert isinstance(x, TensorBox)
dim = _validate_dim(x, dim, 0)
dim_size = x.get_size()[dim]
if V.graph.sizevars.shape_env.evaluate_expr(sympy.Lt(start + dim_size, 0)):
if V.graph.sizevars.evaluate_expr(sympy.Lt(start + dim_size, 0)):
start = 0
if V.graph.sizevars.shape_env.evaluate_expr(sympy.Lt(end + dim_size, 0)):
if V.graph.sizevars.evaluate_expr(sympy.Lt(end + dim_size, 0)):
end = 0
return TensorBox(ir.SliceView.create(x.data, dim, start, end, step))

Expand Down Expand Up @@ -960,8 +960,7 @@ def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
)

evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
offset_negative = evaluate_expr(sympy.Lt(offset, 0))
offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0))
if offset_negative:
diag_size = max(min(original_shape[dim1] + offset, original_shape[dim2]), 0)
else:
Expand Down Expand Up @@ -1888,7 +1887,7 @@ def select_scatter(x, src, dim: int, index: int):
assert x.get_dtype() == src.get_dtype()
x_loader = x.make_loader()
dim = _validate_dim(x, dim, 0)
if index < 0:
if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
index = index + x.get_size()[dim]
V.graph.sizevars.guard_leq(0, index)
V.graph.sizevars.guard_lt(index, x.get_size()[dim])
Expand Down Expand Up @@ -1919,9 +1918,9 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
x_loader = x.make_loader()
dim = _validate_dim(x, dim, 0)
dim_size = x.get_size()[dim]
if start is not None and start < 0:
if start is not None and V.graph.sizevars.evaluate_expr(sympy.Lt(start, 0)):
start = start + dim_size
if end is not None and end < 0:
if end is not None and V.graph.sizevars.evaluate_expr(sympy.Lt(end, 0)):
end = end + dim_size
if start is None:
start = 0
Expand Down
8 changes: 8 additions & 0 deletions torch/_inductor/sizevars.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,14 @@ def guard_lt(self, left: Expr, right: Expr) -> None:
# (NB: not necessarily an Expr) and return what the concrete result
# is, guarding on the expression being that result

# NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
# as this will ensure that you actually have a sympy'ified expression,
# and will prevent you from incorrectly writing evaluate_expr(a == b)
# which does the wrong thing if a or b is a sympy expression
def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
    """Evaluate a boolean sympy expression to a concrete bool, guarding on it."""
    accepted = (Expr, sympy.logic.boolalg.Boolean)
    assert isinstance(left, accepted), type(left)
    sympified = sympy.sympify(left)
    return self.shape_env.evaluate_expr(sympified)

def evaluate_min(self, left: Expr, right: Expr) -> Expr:
"""return the smaller of left and right, and guard on that choice"""
lv = self.size_hint(left)
Expand Down

0 comments on commit 53a4b26

Please sign in to comment.