From 20c5add13319934128201848fb0a7fbe07e71524 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Mon, 14 Aug 2023 11:10:02 -0700 Subject: [PATCH] [export] Refactor `constrain_as_value` and `constrain_as_size` (#106591) Some notable changes: 1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2. 2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591 Approved by: https://github.com/gmagogsfm, https://github.com/ezyang --- aten/src/ATen/native/Constraints.cpp | 51 ++++- aten/src/ATen/native/cuda/Constraints.cu | 15 -- aten/src/ATen/native/native_functions.yaml | 15 +- test/dynamo/test_export.py | 10 +- ...asDecompTest.test_has_decomposition.expect | 1 + test/export/test_export.py | 203 ++++++++++++++---- test/export/test_passes.py | 4 +- .../check_forward_backward_compatibility.py | 1 + test/test_proxy_tensor.py | 8 +- torch/_dynamo/skipfiles.py | 5 + torch/_export/__init__.py | 2 +- torch/_export/constraints.py | 78 ++++--- torch/_export/serde/serialize.py | 7 +- torch/_inductor/decomposition.py | 5 + torch/_meta_registrations.py | 22 +- torch/fx/experimental/symbolic_shapes.py | 104 ++++++--- torch/fx/node.py | 1 + torch/overrides.py | 1 + torchgen/native_function_generation.py | 1 + 19 files changed, 389 insertions(+), 145 deletions(-) delete mode 100644 aten/src/ATen/native/cuda/Constraints.cu diff --git a/aten/src/ATen/native/Constraints.cpp b/aten/src/ATen/native/Constraints.cpp index 3885cb9fadf796..f6a5b5edc41714 100644 --- a/aten/src/ATen/native/Constraints.cpp +++ b/aten/src/ATen/native/Constraints.cpp @@ -1,3 +1,4 @@ +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -15,21 +16,65 @@ #include #include #include +#include +#include #endif namespace at { namespace native { -void sym_constrain_range_cpu( +void sym_constrain_range( const Scalar& size, c10::optional min, - c10::optional max) {} + c10::optional max) { -Tensor _functional_sym_constrain_range_cpu( + int64_t min_val = min.has_value() ? min.value() : std::numeric_limits::min(); + int64_t max_val = max.has_value() ? max.value() : std::numeric_limits::max(); + int64_t size_as_int = size.toInt(); + + TORCH_CHECK( + max_val >= min_val, + "Max must be greater than or equal to min. Got min=", + min_val, + " max=", + max_val + ); + + TORCH_CHECK( + min_val <= size_as_int && size_as_int <= max_val, + "Invalid value range for ", + size_as_int, + " between [", + min_val, + ", ", + max_val, + "]." + ); +} + +Tensor _functional_sym_constrain_range( const Scalar& size, c10::optional min, c10::optional max, const Tensor& dep_token) { + sym_constrain_range(size, min, max); + return dep_token.clone(); +} + +void sym_constrain_range_for_size(const Scalar& size, c10::optional min, c10::optional max) { + int64_t min_val = min.has_value() ? min.value() : 0; + if (max.has_value() && max.value() <= 2) { + TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value()); + } + sym_constrain_range(size, min_val, max); +} + +Tensor _functional_sym_constrain_range_for_size( + const Scalar& size, + c10::optional min, + c10::optional max, + const Tensor& dep_token) { + sym_constrain_range_for_size(size, min, max); return dep_token.clone(); } diff --git a/aten/src/ATen/native/cuda/Constraints.cu b/aten/src/ATen/native/cuda/Constraints.cu deleted file mode 100644 index 2f4bd4b67f43a7..00000000000000 --- a/aten/src/ATen/native/cuda/Constraints.cu +++ /dev/null @@ -1,15 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS - -#include -#include - -namespace at { -namespace native { - -void sym_constrain_range_cuda( - const Scalar& size, - c10::optional min = c10::nullopt, - c10::optional max = c10::nullopt) {} - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9cd97a1eead8b7..b2090d0135a706 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -181,14 +181,21 @@ - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> () -- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> () +- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> () dispatch: - CPU: sym_constrain_range_cpu - CUDA: sym_constrain_range_cuda + CompositeExplicitAutograd: sym_constrain_range + +- func: sym_constrain_range_for_size(Scalar size, *, int? min, int? max) -> () + dispatch: + CompositeExplicitAutograd: sym_constrain_range_for_size - func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor dispatch: - CPU: _functional_sym_constrain_range_cpu + CompositeExplicitAutograd: _functional_sym_constrain_range + +- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor + dispatch: + CompositeExplicitAutograd: _functional_sym_constrain_range_for_size - func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor dispatch: diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 2d707cffd6a368..7d1fe1d1b67395 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2273,7 +2273,7 @@ def my_dyn_fn(a, b, c): def test_export_preserve_constraints_as_metadata_scalar(self): def f(x, y): b = x.item() - constrain_as_size(b, min=2, max=5) + constrain_as_size(b) return torch.empty((b, y.shape[0])) x = torch.tensor([3]) @@ -2322,7 +2322,7 @@ def test_exported_graph_serialization(self): def f(x, y): b = x.item() - constrain_as_size(b, min=2, max=5) + constrain_as_size(b) return torch.empty((b, y.shape[0])) x = torch.tensor([3]) @@ -2344,11 +2344,11 @@ def f(x, y): def test_export_with_inline_constraints(self): def f(x): a = x.item() - constrain_as_size(a, 4, 7) + constrain_as_value(a, 4, 7) return torch.empty((a, 4)) with self.assertRaisesRegex( - torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]" + RuntimeError, r"Invalid value range for 20 between \[4, 7\]." ) as cm: torch._export.export(f, (torch.tensor([20]),)) @@ -2368,7 +2368,7 @@ def f(x): def test_export_with_inline_constraints_complex(self): def f(x): a = x.item() - constrain_as_size(a, 4, 7) + constrain_as_value(a, 4, 7) empty = torch.empty((a, 4)) return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0) diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 128866309b8c38..ed3bcc91e83543 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -322,6 +322,7 @@ aten::_foreach_zero.out aten::_foreach_zero_ aten::_functional_assert_async.msg aten::_functional_sym_constrain_range +aten::_functional_sym_constrain_range_for_size aten::_fused_adam aten::_fused_adam.out aten::_fused_adam_ diff --git a/test/export/test_export.py b/test/export/test_export.py index f68fc3783eb276..a76cadcd562229 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -5,8 +5,8 @@ import torch import torch._dynamo as torchdynamo from torch._export import export, dynamic_dim, DEFAULT_EXPORT_DYNAMO_CONFIG -from torch._export.utils import register_dataclass_as_pytree_node from torch._export.constraints import constrain_as_size, constrain_as_value +from torch._export.utils import register_dataclass_as_pytree_node from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import run_tests, TestCase from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec @@ -21,7 +21,7 @@ def test_export_inline_constraints(self): def f(x): b = x.item() - constrain_as_size(b, min=2, max=5) + constrain_as_size(b) return torch.full((b, 1), 1) inp = (torch.tensor([3]),) @@ -37,23 +37,6 @@ def f(x): self.assertTrue(torchdynamo.utils.same(ref, res)) def test_export_constraints_error(self): - def invalid_size(x): - b = x.item() - constrain_as_size(b, min=0, max=5) - return torch.full((b, 1), 1) - - inp = (torch.tensor([3]),) - with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"): - export(invalid_size, inp) - - def invalid_input_conflict_with_inline_constraints(x): - b = x.item() - constrain_as_size(b, min=2, max=5) - return torch.full((b, 1), 1) - - inp = (torch.tensor([6]),) - with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"): - export(invalid_input_conflict_with_inline_constraints, inp) def invalid_input_conflict_with_input_constraints(x): return x + 1 @@ -69,16 +52,15 @@ def invalid_input_conflict_with_input_constraints(x): constraints=inp_constraints, ) - def conflicting_constraints(x): b = x.item() - constrain_as_size(b, min=2, max=3) - constrain_as_size(b, min=4, max=5) + constrain_as_size(b) + constrain_as_value(b, min=4, max=5) return torch.full((b, 1), 1) inp = (torch.tensor([3]),) - with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid ranges"): + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 3 between \[4, 5\]"): export(conflicting_constraints, inp) def test_export_assume_static_by_default(self): @@ -222,25 +204,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"): em(x) - def test_export_constrain_static(self): - def f(x, y): - b = x.item() - constrain_as_size(b, min=2, max=5) - c = y.dim() - constrain_as_value(c, min=1, max=3) - z = y[0:c] - return torch.empty((b, y.shape[0])), z - - x = torch.tensor([3]) - y = torch.randn([8, 8, 6]) - example_inputs = (x, y) - constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10] - with self.assertRaisesRegex( - torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " + - "on a value which we evaluated to have a static value of 3. " - ): - export(f, example_inputs, {}, constraints) - def test_not_correct_dim(self): def f(x): return x.cos() @@ -588,5 +551,161 @@ def fn(x): # Intentionally not wrapping `inp` in a tuple to trigger the error _ = export(fn, inp) + def test_constrain_value_with_no_default(self): + def fn(x, y): + n = x.max().item() + constrain_as_value(n) + return y + n + + ep = export(fn, (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3)))) + test_inp = (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3))) + self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp))) + + def test_constrain_value_with_symfloat(self): + def fn(x, y): + n = x.max().item() + constrain_as_value(n) + return y + n + + with self.assertRaisesRegex(torch._dynamo.exc.TorchRuntimeError, "Constraining SymFloat or Symbool is nyi"): + _ = export(fn, (torch.rand(2, 2), torch.rand(2, 3))) + + def test_constrain_size_in_eager(self): + def fn(x, y): + n = x.max().item() + constrain_as_size(n) + return y + n + + ep = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))) + test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))) + self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp))) + + def test_constrain_size_with_constrain_value(self): + def fn(x, y): + n = x.max().item() + constrain_as_value(n, 2, 10) + constrain_as_size(n) + return y + n + + # Since we are using constrain_as_value, we expect to raise error when user + # passes in invalid tracing input + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."): + _ = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."): + _ = fn(torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))) + + ep = export(fn, (torch.randint(3, 4, (2, 2)), torch.randint(3, 5, (2, 3)))) + with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint"): + test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))) + _ = ep(*test_inp) + + def test_constrain_size_with_various_cases(self): + + def case_1(x, y): + n = x.item() + constrain_as_size(n, min=0) + return y.sum() + torch.ones(n, 5).sum() + + def case_2(x, y): + n = x.item() + constrain_as_size(n, min=0, max=6) + return y.sum() + torch.ones(n, 5).sum() + + def case_3(x, y): + n = x.item() + constrain_as_size(n, min=0, max=1) + return y.sum() + torch.ones(n, 5).sum() + + def case_4(x, y): + n = x.item() + constrain_as_size(n, min=2) + return y.sum() + torch.ones(n, 5).sum() + + def case_5(x, y): + n = x.item() + constrain_as_size(n, min=1) + return y.sum() + torch.ones(n, 5).sum() + + ep = export(case_1, (torch.tensor(1), torch.ones(4, 5))) + with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, inf\]."): + _ = ep(torch.tensor(-1), torch.randn(4, 5)) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for -1 between"): + _ = case_1(torch.tensor(-1), torch.randn(4, 5)) + + self.assertTrue( + torch.allclose( + ep(torch.tensor(1), torch.ones(4, 5)), + case_1(torch.tensor(1), torch.ones(4, 5)), + ) + ) + + ep = export(case_2, (torch.tensor(5), torch.randn(4, 5))) + with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, 6\]."): + _ = ep(torch.tensor(7), torch.randn(4, 5)) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 7 between"): + _ = case_2(torch.tensor(7), torch.randn(4, 5)) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between \[0, 6\]."): + _ = export(case_2, (torch.tensor(9), torch.randn(4, 5))) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between"): + _ = case_2(torch.tensor(9), torch.randn(4, 5)) + + self.assertTrue( + torch.allclose( + ep(torch.tensor(5), torch.ones(4, 5)), + case_2(torch.tensor(5), torch.ones(4, 5)), + ) + ) + + with self.assertRaisesRegex( + torch._dynamo.exc.TorchRuntimeError, + "Maximum value to constrain_as_size must be greater than 2, but was 1" + ): + _ = export(case_3, (torch.tensor(1), torch.randn(4, 5))) + + with self.assertRaisesRegex(RuntimeError, "Max value to constrain_range_for_size must be greater than 2. got: 1"): + _ = case_3(torch.tensor(1), torch.randn(4, 5)) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."): + _ = export(case_4, (torch.tensor(1), torch.randn(4, 5))) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."): + _ = case_4(torch.tensor(1), torch.randn(4, 5)) + + ep = export(case_4, (torch.tensor(5), torch.randn(4, 5))) + with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[2, inf\]."): + _ = ep(torch.tensor(1), torch.randn(4, 5)) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1"): + _ = case_4(torch.tensor(1), torch.randn(4, 5)) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0 between \[1, 9223372036854775807\]."): + _ = export(case_5, (torch.tensor(0), torch.randn(4, 5))) + + self.assertTrue( + torch.allclose( + ep(torch.tensor(5), torch.ones(4, 5)), + case_4(torch.tensor(5), torch.ones(4, 5)), + ) + ) + + ep = export(case_5, (torch.tensor(5), torch.randn(4, 5))) + with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[1, inf\]."): + _ = ep(torch.tensor(0), torch.randn(4, 5)) + + with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0"): + _ = case_5(torch.tensor(0), torch.randn(4, 5)) + + self.assertTrue( + torch.allclose( + ep(torch.tensor(5), torch.ones(4, 5)), + case_5(torch.tensor(5), torch.ones(4, 5)), + ) + ) + if __name__ == '__main__': run_tests() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 938d4b6d50840e..4c26dd2240182a 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -13,7 +13,7 @@ from torch.testing import FileCheck from torch._dynamo.eval_frame import is_dynamo_supported from torch._export import export, dynamic_dim -from torch._export.constraints import constrain_as_value, constrain_as_size +from torch._export.constraints import constrain_as_value from torch._export.passes import ( ReplaceViewOpsWithViewCopyOpsPass, ) @@ -346,7 +346,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def test_functionalize_inline_contraints(self) -> None: def f(x): a = x.item() - constrain_as_size(a, 4, 7) + constrain_as_value(a, 4, 7) return torch.empty((a, 4)) ep = torch._export.export(f, (torch.tensor([7]),)) diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 420fb9c7705f0d..d7e2768780e4df 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -314,6 +314,7 @@ ("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)), ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)), ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)), + ("aten::sym_constrain_range", datetime.date(2023, 12, 31)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 9559c711ecd9ec..2c22a25c3ae50f 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -10,9 +10,10 @@ from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode from torch._decomp import decomposition_table +from torch._export.constraints import constrain_as_size, constrain_as_value from torch.fx.experimental.symbolic_shapes import ( sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets, - constrain_range, guard_int, GuardOnDataDependentSymNode + guard_int, GuardOnDataDependentSymNode ) from torch.testing._internal.custom_op_db import custom_op_db from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db @@ -1041,7 +1042,7 @@ def forward(self, a_1): def test_item_to_constructor(self): def f(a): r = a.item() - constrain_range(r, min=2) + constrain_as_size(r) return torch.empty(r) r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip() @@ -1049,6 +1050,7 @@ def f(a): r, """\ def forward(self, a_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None + sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = None, max = None) empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None return empty""" # noqa: B950 ) @@ -1127,7 +1129,7 @@ def f(x, mask, params, buffers): for s in p.shape: guard_int(s) x = x[mask] - constrain_range(x.shape[0], min=1) + constrain_as_value(x.shape[0], min=1) for p in params.values(): p.grad = None return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum() diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py index aadea2fcf2a766..1cf5ef3c33ac55 100644 --- a/torch/_dynamo/skipfiles.py +++ b/torch/_dynamo/skipfiles.py @@ -145,6 +145,11 @@ def _module_dir(m: types.ModuleType): _module_dir(torch) + "ao/quantization/pt2e/utils.py", } +FILENAME_ALLOWLIST |= { + _module_dir(torch) + "_export/constraints.py", +} + +# TODO (zhxchen17) Make exportdb importable here. FILENAME_ALLOWLIST |= set( glob.glob(_module_dir(torch) + "_export/db/examples/*.py"), ) | { diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 54c653d0455248..23bc0725d2010c 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -363,7 +363,7 @@ def convert_to_fake(x): # so we serialize them here instead of inside dynamo gm.meta["inline_constraints"] = { k: v - for k, v in fake_mode.shape_env.var_to_range.items() + for k, v in fake_mode.shape_env.runtime_var_to_range.items() if re.match(r"^[if]\d+$", str(k)) } diff --git a/torch/_export/constraints.py b/torch/_export/constraints.py index 6f8d382cccca69..aa46cafcf2ef70 100644 --- a/torch/_export/constraints.py +++ b/torch/_export/constraints.py @@ -1,48 +1,60 @@ -from typing import Optional, Callable, Union +from typing import Optional import torch -from torch import SymInt, SymFloat -from torch._dynamo import allow_in_graph -from torch.fx.experimental.symbolic_shapes import constrain_range_int -from torch.utils._sympy.value_ranges import ValueRangeError - -# `Scalar` type used in native_functions.ymal will be translated to `Union[Number, _complex]` -# could cause type error during since `SymInt` or `SymFloat` will be used. -# Here manually specify the type explicitly. -sym_constrain_range: Callable[ - [Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]], - None, -] = torch.sym_constrain_range # type: ignore[assignment] # TODO: we want to hide this min/max stuff under some abstraction similar to # DynamicDim -@allow_in_graph def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None): """ - Add min/max constraint on the intermediate symbol at tracing time + Add min/max constraint on the intermediate symbol at tracing time. If called in eager mode, + it will still check if the input value is within the specified range. """ - - if not isinstance(symbol, SymInt): - constrain_range_int(symbol, min=min, max=max) - else: - sym_constrain_range(symbol, min, max) - - return symbol + torch.sym_constrain_range(symbol, min=min, max=max) # TODO: we want to hide this min/max stuff under some abstraction similar to # DynamicDim -@allow_in_graph -def constrain_as_size(symbol, min: int = 2, max: Optional[int] = None): +def constrain_as_size(symbol, min: Optional[int] = None, max: Optional[int] = None): """ - Add min/max constraint on the intermediate symbol which will be used as a size + This indicates that a given int is size-like, and can be used in any context where a size is expected. + You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist() + which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve + GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts. + + This function has unusual semantics which distinguish it from constrain_as_value. + Specifically, at compile-time, we will unsoundly assume that the resulting int is always >= 2. + As a result, max value you pass in should always be greater than 2. + This makes it easier to use the unbacked int in size contexts, as we will often attempt to guard on a size being zero/one + (e.g., when computing the contiguity of a tensor, or testing if broadcasting can occur), + which will not work on unbacked SymInts. Assuming that the int is >= 2 allows us to + report False to these tests. Although this is technically unsound, + in practice we observe that if your program works for all sizes >= 2, + it probably works for zero and one too. The reason specifically assume size is >= 2 is because + lot of PyTorch code is specialized for 0 and 1 which could result in not general graphs. + At runtime, we only assert that the user provided min/max values are respected. + + To demonstrate in a scenario, suppose you do + ``` + # Case 1 + # This will assume symbol is between [2, inf) at compile time, but [0, inf) at runtime + constrain_as_size(symbol, min=0) + + # Case 2 + # This will assume symbol is between [2, N] at compile time, but [0, N] at runtime + constrain_as_size(symbol, min=0, max=N) + + # Case 3 + # This is not valid case as max is <= 2 + constrain_as_size(symbol, min=0, max=1) + + # Case 4 + # This will assume symbol is between [2, inf) at compile time, AND [2, inf) at runtime + constrain_as_size(symbol, min=2) + + # Case 5 + # This will assume symbol is between [2, inf) at compile time, but [1, inf) at runtime + constrain_as_size(symbol, min=1) + ``` """ - - # TODO: we should investigate turning off 0/1 specialization for unbacked - # SymInts - if min < 2: - raise ValueRangeError( - "Unable to set min size to be <= 2 because we specialize on 0/1 sizes." - ) - return constrain_as_value(symbol, min, max) + torch.sym_constrain_range_for_size(symbol, min=min, max=max) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 6ebb1a2a598f03..4838ac3175335d 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -806,7 +806,12 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: if vr := self.symbol_name_to_range.get(val.expr_str): symbolic_shapes._constrain_symbol_range( - self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type] + self.shape_env, + sym, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + runtime_min=vr.lower, # type: ignore[arg-type] + runtime_max=vr.upper # type: ignore[arg-type] ) return self.shape_env.create_symintnode(sym, hint=val.hint) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index b5e58377d8a307..e6b362de49e3af 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -83,6 +83,11 @@ def functional_assert_async_msg_decomp(tensor, msg): return +@register_decomposition([aten.sym_constrain_range_for_size.default]) +def sym_constrain_range_for_size(symbol, *, min=None, max=None): + return + + @register_decomposition([aten.clamp]) @pw_cast_for_opmath def clamp(x, min=None, max=None): diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 3772eb8de27033..241bbcbf53329f 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -4,7 +4,7 @@ import torch import torch._prims_common as utils -from torch import Tensor +from torch import SymBool, SymFloat, Tensor from torch._decomp import ( _add_op_to_registry, _convert_out_params, @@ -30,7 +30,10 @@ out_wrapper, ) from torch._refs import _broadcast_shapes, _maybe_broadcast -from torch.fx.experimental.symbolic_shapes import constrain_range +from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + constrain_range, +) from torch.utils._pytree import tree_map @@ -424,6 +427,8 @@ def make_dep_token( @register_meta(aten.sym_constrain_range.default) def sym_constrain_range(size, min=None, max=None): + if isinstance(size, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat or Symbool is nyi") constrain_range(size, min=min, max=max) @@ -433,6 +438,19 @@ def functional_sym_constrain_range(size, min=None, max=None, dep_token=None): return dep_token +@register_meta(aten.sym_constrain_range_for_size.default) +def sym_constrain_range_for_size(size, min=None, max=None): + if isinstance(size, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat or Symbool is nyi") + _constrain_range_for_size(size, min=min, max=max) + + +@register_meta(aten._functional_sym_constrain_range_for_size.default) +def functional_sym_constrain_range_for_size(size, min, max, dep_token): + aten.sym_constrain_range_for_size(size, min=min, max=max) + return dep_token + + @register_meta(aten._functional_assert_async.msg) def functional_assert_async_meta(val, assert_msg, dep_token): return dep_token diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 14a1c247822f6e..1323a10eca5560 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -32,7 +32,7 @@ SymFloat, SymInt, ) -from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode +from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError @@ -304,13 +304,57 @@ def guard_scalar(a): else: raise AssertionError(f"unrecognized scalar {a}") -def _constrain_symbol_range(shape_env, s: sympy.Symbol, min: int, max: int): +def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int, runtime_min: int, runtime_max: int): if r := shape_env.var_to_range.get(s, None): shape_env.var_to_range[s] = ValueRanges( - builtins.max(r.lower, min), builtins.min(r.upper, max) + builtins.max(r.lower, compiler_min), builtins.min(r.upper, compiler_max) ) else: - shape_env.var_to_range[s] = ValueRanges(min, max) + shape_env.var_to_range[s] = ValueRanges(compiler_min, compiler_max) + + if r := shape_env.runtime_var_to_range.get(s, None): + shape_env.runtime_var_to_range[s] = ValueRanges( + builtins.max(r.lower, runtime_min), builtins.min(r.upper, runtime_max) + ) + else: + shape_env.runtime_var_to_range[s] = ValueRanges(runtime_min, runtime_max) + +def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None): + """ + This function is NOT INTENDED to be used by itself. + """ + + if isinstance(a, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat/SymBool is nyi") + + assert isinstance(a, SymInt) + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + + if min is None: + min = 0 + if max is None: + max = sympy.oo + + if max <= 2: + raise ValueError(f"Maximum value to constrain_as_size must be greater than 2, but was {max}") + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + compiler_min = 2 if min < 2 else min + + _constrain_symbol_range( + a.node.shape_env, + a.node.expr, + compiler_min=compiler_min, + compiler_max=max, + runtime_min=min, + runtime_max=max + ) + # inclusive both ways def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): @@ -350,8 +394,16 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): min = -sympy.oo if max is None: max = sympy.oo - if not isinstance(a, SymInt): - constrain_range_int(a, min=min, max=max) + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + if isinstance(a, int): + if not (min <= a <= max): + raise ValueError(f"Invalid value {a} for range [{min}:{max}]") return if isinstance(a.node.expr, sympy.Integer): @@ -364,35 +416,15 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): # semantics that this is an "unchecked" assert (but it this actually # something useful? Might be better to restrict only for unbacked # SymInt). - _constrain_symbol_range(a.node.shape_env, a.node.expr, min, max) - -def constrain_range_int(a, *, min, max): - """ - Constrain range on concrete int value. - This can happens for the following scenarios: - - Eager mode execution and real int value is provided. - - During tracing the traced symbol is resolved as a static integer (see - PR #101655 for more details). - """ - if min is None: - min = -sympy.oo - if max is None: - max = sympy.oo - - assert not isinstance(a, SymInt) - if not (min <= a <= max): - raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]") + _constrain_symbol_range( + a.node.shape_env, + a.node.expr, + compiler_min=min, + compiler_max=max, + runtime_min=min, + runtime_max=max + ) - if ( - (fake_mode := detect_fake_mode()) is not None and - getattr(fake_mode, "shape_env", None) is not None - ): - # If we are tracing with a fake mode then add this integer to the - # shape_env's var_to_range - sym_integer = sympy.Integer(a) - shape_env = fake_mode.shape_env - _constrain_symbol_range(shape_env, sym_integer, min, max) - shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack() def constrain_unify(a, b): """ @@ -1938,6 +1970,10 @@ def __init__( # range may contain ints which may not actually appear in # practice self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {} + # Maps symbolic ints to their min/max range for runtime checks. + # This is because we assume a graph generated with N=2 is general enough + # for N < 2. Therefore, it will be too strict to assert N=2 at runtime. + self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {} self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {} self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {} # Maps symbolic ints to the guards that refine their lower/upper diff --git a/torch/fx/node.py b/torch/fx/node.py index 023e5761b60c0c..09a3b2c2b7e60d 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -36,6 +36,7 @@ _ops.aten._assert_async.msg, _ops.aten.copy_.default, _ops.aten.sym_constrain_range.default, + _ops.aten.sym_constrain_range_for_size.default, _ops.profiler._record_function_enter, _ops.profiler._record_function_enter_new, _ops.profiler._record_function_exit} diff --git a/torch/overrides.py b/torch/overrides.py index 7eee2f4f184a44..cdaa5ad4ce4a3b 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -186,6 +186,7 @@ def get_ignored_functions() -> Set[Callable]: torch.sym_min, torch.sym_not, torch.sym_constrain_range, + torch.sym_constrain_range_for_size, torch.tril_indices, torch.triu_indices, torch.vander, diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 9a0a75c023b899..afd04b5fb016d7 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -74,6 +74,7 @@ "record_stream", # no return "sparse_dim", # returns an int "sym_constrain_range", # no return + "sym_constrain_range_for_size", # no return "_nested_tensor_storage_offsets", # returns a vector of ints "_chunk_grad_outputs_efficient_attention", # returns a bool "_fused_sdp_choice", # returns an int