diff --git a/aten/src/ATen/native/Constraints.cpp b/aten/src/ATen/native/Constraints.cpp
index 3885cb9fadf796..f6a5b5edc41714 100644
--- a/aten/src/ATen/native/Constraints.cpp
+++ b/aten/src/ATen/native/Constraints.cpp
@@ -1,3 +1,4 @@
+#include <limits>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include
 #include
@@ -15,21 +16,65 @@
 #include
 #include
 #include
+#include
+#include
 #endif

 namespace at {
 namespace native {

-void sym_constrain_range_cpu(
+void sym_constrain_range(
     const Scalar& size,
     c10::optional<int64_t> min,
-    c10::optional<int64_t> max) {}
+    c10::optional<int64_t> max) {

-Tensor _functional_sym_constrain_range_cpu(
+  int64_t min_val = min.has_value() ? min.value() : std::numeric_limits<int64_t>::min();
+  int64_t max_val = max.has_value() ? max.value() : std::numeric_limits<int64_t>::max();
+  int64_t size_as_int = size.toInt();
+
+  TORCH_CHECK(
+      max_val >= min_val,
+      "Max must be greater than or equal to min. Got min=",
+      min_val,
+      " max=",
+      max_val
+  );
+
+  TORCH_CHECK(
+      min_val <= size_as_int && size_as_int <= max_val,
+      "Invalid value range for ",
+      size_as_int,
+      " between [",
+      min_val,
+      ", ",
+      max_val,
+      "]."
+  );
+}
+
+Tensor _functional_sym_constrain_range(
     const Scalar& size,
     c10::optional<int64_t> min,
     c10::optional<int64_t> max,
     const Tensor& dep_token) {
+  sym_constrain_range(size, min, max);
+  return dep_token.clone();
+}
+
+void sym_constrain_range_for_size(const Scalar& size, c10::optional<int64_t> min, c10::optional<int64_t> max) {
+  int64_t min_val = min.has_value() ? min.value() : 0;
+  if (max.has_value() && max.value() <= 2) {
+    TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value());
+  }
+  sym_constrain_range(size, min_val, max);
+}
+
+Tensor _functional_sym_constrain_range_for_size(
+    const Scalar& size,
+    c10::optional<int64_t> min,
+    c10::optional<int64_t> max,
+    const Tensor& dep_token) {
+  sym_constrain_range_for_size(size, min, max);
   return dep_token.clone();
 }
diff --git a/aten/src/ATen/native/cuda/Constraints.cu b/aten/src/ATen/native/cuda/Constraints.cu
deleted file mode 100644
index 2f4bd4b67f43a7..00000000000000
--- a/aten/src/ATen/native/cuda/Constraints.cu
+++ /dev/null
@@ -1,15 +0,0 @@
-#define TORCH_ASSERT_NO_OPERATORS
-
-#include
-#include
-
-namespace at {
-namespace native {
-
-void sym_constrain_range_cuda(
-    const Scalar& size,
-    c10::optional<int64_t> min = c10::nullopt,
-    c10::optional<int64_t> max = c10::nullopt) {}
-
-} // namespace native
-} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9cd97a1eead8b7..b2090d0135a706 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -181,14 +181,21 @@
 - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()

-- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
+- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
   dispatch:
-    CPU: sym_constrain_range_cpu
-    CUDA: sym_constrain_range_cuda
+    CompositeExplicitAutograd: sym_constrain_range
+
+- func: sym_constrain_range_for_size(Scalar size, *, int? min, int? max) -> ()
+  dispatch:
+    CompositeExplicitAutograd: sym_constrain_range_for_size

 - func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
   dispatch:
-    CPU: _functional_sym_constrain_range_cpu
+    CompositeExplicitAutograd: _functional_sym_constrain_range
+
+- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _functional_sym_constrain_range_for_size

 - func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
   dispatch:
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 2d707cffd6a368..7d1fe1d1b67395 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -2273,7 +2273,7 @@ def my_dyn_fn(a, b, c):
     def test_export_preserve_constraints_as_metadata_scalar(self):
         def f(x, y):
             b = x.item()
-            constrain_as_size(b, min=2, max=5)
+            constrain_as_size(b)
             return torch.empty((b, y.shape[0]))

         x = torch.tensor([3])
@@ -2322,7 +2322,7 @@ def test_exported_graph_serialization(self):
         def f(x, y):
             b = x.item()
-            constrain_as_size(b, min=2, max=5)
+            constrain_as_size(b)
             return torch.empty((b, y.shape[0]))

         x = torch.tensor([3])
@@ -2344,11 +2344,11 @@ def f(x, y):
     def test_export_with_inline_constraints(self):
         def f(x):
             a = x.item()
-            constrain_as_size(a, 4, 7)
+            constrain_as_value(a, 4, 7)
             return torch.empty((a, 4))

         with self.assertRaisesRegex(
-            torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]"
+            RuntimeError, r"Invalid value range for 20 between \[4, 7\]."
         ) as cm:
             torch._export.export(f, (torch.tensor([20]),))

@@ -2368,7 +2368,7 @@ def f(x):
     def test_export_with_inline_constraints_complex(self):
         def f(x):
             a = x.item()
-            constrain_as_size(a, 4, 7)
+            constrain_as_value(a, 4, 7)
             empty = torch.empty((a, 4))
             return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 128866309b8c38..ed3bcc91e83543 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -322,6 +322,7 @@ aten::_foreach_zero.out
 aten::_foreach_zero_
 aten::_functional_assert_async.msg
 aten::_functional_sym_constrain_range
+aten::_functional_sym_constrain_range_for_size
 aten::_fused_adam
 aten::_fused_adam.out
 aten::_fused_adam_
diff --git a/test/export/test_export.py b/test/export/test_export.py
index f68fc3783eb276..a76cadcd562229 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -5,8 +5,8 @@
 import torch
 import torch._dynamo as torchdynamo
 from torch._export import export, dynamic_dim, DEFAULT_EXPORT_DYNAMO_CONFIG
-from torch._export.utils import register_dataclass_as_pytree_node
 from torch._export.constraints import constrain_as_size, constrain_as_value
+from torch._export.utils import register_dataclass_as_pytree_node
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec
@@ -21,7 +21,7 @@ def test_export_inline_constraints(self):
         def f(x):
             b = x.item()
-            constrain_as_size(b, min=2, max=5)
+            constrain_as_size(b)
             return torch.full((b, 1), 1)

         inp = (torch.tensor([3]),)
@@ -37,23 +37,6 @@ def f(x):
         self.assertTrue(torchdynamo.utils.same(ref, res))

     def test_export_constraints_error(self):
-        def invalid_size(x):
-            b = x.item()
-            constrain_as_size(b, min=0, max=5)
-            return torch.full((b, 1), 1)
-
-        inp = (torch.tensor([3]),)
-        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
-            export(invalid_size, inp)
-
-        def invalid_input_conflict_with_inline_constraints(x):
-            b = x.item()
-            constrain_as_size(b, min=2, max=5)
-            return torch.full((b, 1), 1)
-
-        inp = (torch.tensor([6]),)
-        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
-            export(invalid_input_conflict_with_inline_constraints, inp)

         def invalid_input_conflict_with_input_constraints(x):
             return x + 1

@@ -69,16 +52,15 @@ def invalid_input_conflict_with_input_constraints(x):
             constraints=inp_constraints,
         )

-
         def conflicting_constraints(x):
             b = x.item()
-            constrain_as_size(b, min=2, max=3)
-            constrain_as_size(b, min=4, max=5)
+            constrain_as_size(b)
+            constrain_as_value(b, min=4, max=5)
             return torch.full((b, 1), 1)

         inp = (torch.tensor([3]),)
-        with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid ranges"):
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 3 between \[4, 5\]"):
             export(conflicting_constraints, inp)

     def test_export_assume_static_by_default(self):
@@ -222,25 +204,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
             em(x)

-    def test_export_constrain_static(self):
-        def f(x, y):
-            b = x.item()
-            constrain_as_size(b, min=2, max=5)
-            c = y.dim()
-            constrain_as_value(c, min=1, max=3)
-            z = y[0:c]
-            return torch.empty((b, y.shape[0])), z
-
-        x = torch.tensor([3])
-        y = torch.randn([8, 8, 6])
-        example_inputs = (x, y)
-        constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10]
-        with self.assertRaisesRegex(
-            torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " +
-            "on a value which we evaluated to have a static value of 3. "
-        ):
-            export(f, example_inputs, {}, constraints)
-
     def test_not_correct_dim(self):
         def f(x):
             return x.cos()

@@ -588,5 +551,161 @@ def fn(x):
         # Intentionally not wrapping `inp` in a tuple to trigger the error
         _ = export(fn, inp)

+    def test_constrain_value_with_no_default(self):
+        def fn(x, y):
+            n = x.max().item()
+            constrain_as_value(n)
+            return y + n
+
+        ep = export(fn, (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3))))
+        test_inp = (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3)))
+        self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
+
+    def test_constrain_value_with_symfloat(self):
+        def fn(x, y):
+            n = x.max().item()
+            constrain_as_value(n)
+            return y + n
+
+        with self.assertRaisesRegex(torch._dynamo.exc.TorchRuntimeError, "Constraining SymFloat or Symbool is nyi"):
+            _ = export(fn, (torch.rand(2, 2), torch.rand(2, 3)))
+
+    def test_constrain_size_in_eager(self):
+        def fn(x, y):
+            n = x.max().item()
+            constrain_as_size(n)
+            return y + n
+
+        ep = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))))
+        test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
+        self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
+
+    def test_constrain_size_with_constrain_value(self):
+        def fn(x, y):
+            n = x.max().item()
+            constrain_as_value(n, 2, 10)
+            constrain_as_size(n)
+            return y + n
+
+        # Since we are using constrain_as_value, we expect an error to be raised
+        # when the user passes in an invalid tracing input
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."):
+            _ = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."):
+            _ = fn(torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
+
+        ep = export(fn, (torch.randint(3, 4, (2, 2)), torch.randint(3, 5, (2, 3))))
+        with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint"):
+            test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
+            _ = ep(*test_inp)
+
+    def test_constrain_size_with_various_cases(self):
+
+        def case_1(x, y):
+            n = x.item()
+            constrain_as_size(n, min=0)
+            return y.sum() + torch.ones(n, 5).sum()
+
+        def case_2(x, y):
+            n = x.item()
+            constrain_as_size(n, min=0, max=6)
+            return y.sum() + torch.ones(n, 5).sum()
+
+        def case_3(x, y):
+            n = x.item()
+            constrain_as_size(n, min=0, max=1)
+            return y.sum() + torch.ones(n, 5).sum()
+
+        def case_4(x, y):
+            n = x.item()
+            constrain_as_size(n, min=2)
+            return y.sum() + torch.ones(n, 5).sum()
+
+        def case_5(x, y):
+            n = x.item()
+            constrain_as_size(n, min=1)
+            return y.sum() + torch.ones(n, 5).sum()
+
+        ep = export(case_1, (torch.tensor(1), torch.ones(4, 5)))
+        with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, inf\]."):
+            _ = ep(torch.tensor(-1), torch.randn(4, 5))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for -1 between"):
+            _ = case_1(torch.tensor(-1), torch.randn(4, 5))
+
+        self.assertTrue(
+            torch.allclose(
+                ep(torch.tensor(1), torch.ones(4, 5)),
+                case_1(torch.tensor(1), torch.ones(4, 5)),
+            )
+        )
+
+        ep = export(case_2, (torch.tensor(5), torch.randn(4, 5)))
+        with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, 6\]."):
+            _ = ep(torch.tensor(7), torch.randn(4, 5))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 7 between"):
+            _ = case_2(torch.tensor(7), torch.randn(4, 5))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between \[0, 6\]."):
+            _ = export(case_2, (torch.tensor(9), torch.randn(4, 5)))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between"):
+            _ = case_2(torch.tensor(9), torch.randn(4, 5))
+
+        self.assertTrue(
+            torch.allclose(
+                ep(torch.tensor(5), torch.ones(4, 5)),
+                case_2(torch.tensor(5), torch.ones(4, 5)),
+            )
+        )
+
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.TorchRuntimeError,
+            "Maximum value to constrain_as_size must be greater than 2, but was 1"
+        ):
+            _ = export(case_3, (torch.tensor(1), torch.randn(4, 5)))
+
+        with self.assertRaisesRegex(RuntimeError, "Max value to constrain_range_for_size must be greater than 2. got: 1"):
+            _ = case_3(torch.tensor(1), torch.randn(4, 5))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."):
+            _ = export(case_4, (torch.tensor(1), torch.randn(4, 5)))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."):
+            _ = case_4(torch.tensor(1), torch.randn(4, 5))
+
+        ep = export(case_4, (torch.tensor(5), torch.randn(4, 5)))
+        with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[2, inf\]."):
+            _ = ep(torch.tensor(1), torch.randn(4, 5))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1"):
+            _ = case_4(torch.tensor(1), torch.randn(4, 5))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0 between \[1, 9223372036854775807\]."):
+            _ = export(case_5, (torch.tensor(0), torch.randn(4, 5)))
+
+        self.assertTrue(
+            torch.allclose(
+                ep(torch.tensor(5), torch.ones(4, 5)),
+                case_4(torch.tensor(5), torch.ones(4, 5)),
+            )
+        )
+
+        ep = export(case_5, (torch.tensor(5), torch.randn(4, 5)))
+        with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[1, inf\]."):
+            _ = ep(torch.tensor(0), torch.randn(4, 5))
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0"):
+            _ = case_5(torch.tensor(0), torch.randn(4, 5))
+
+        self.assertTrue(
+            torch.allclose(
+                ep(torch.tensor(5), torch.ones(4, 5)),
+                case_5(torch.tensor(5), torch.ones(4, 5)),
+            )
+        )
+
 if __name__ == '__main__':
     run_tests()
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index 938d4b6d50840e..4c26dd2240182a 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -13,7 +13,7 @@
 from torch.testing import FileCheck
 from torch._dynamo.eval_frame import is_dynamo_supported
 from torch._export import export, dynamic_dim
-from torch._export.constraints import constrain_as_value, constrain_as_size
+from torch._export.constraints import constrain_as_value
 from torch._export.passes import (
     ReplaceViewOpsWithViewCopyOpsPass,
 )
@@ -346,7 +346,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     def test_functionalize_inline_contraints(self) -> None:
         def f(x):
             a = x.item()
-            constrain_as_size(a, 4, 7)
+            constrain_as_value(a, 4, 7)
             return torch.empty((a, 4))

         ep = torch._export.export(f, (torch.tensor([7]),))
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 420fb9c7705f0d..d7e2768780e4df 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -314,6 +314,7 @@
     ("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)),
     ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
     ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
+    ("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
 ]

 ALLOW_LIST_COMPILED = [
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 9559c711ecd9ec..2c22a25c3ae50f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -10,9 +10,10 @@
 from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
 from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
 from torch._decomp import decomposition_table
+from torch._export.constraints import constrain_as_size, constrain_as_value
 from torch.fx.experimental.symbolic_shapes import (
     sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
-    constrain_range, guard_int, GuardOnDataDependentSymNode
+    guard_int, GuardOnDataDependentSymNode
 )
 from torch.testing._internal.custom_op_db import custom_op_db
 from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
@@ -1041,7 +1042,7 @@ def forward(self, a_1):
     def test_item_to_constructor(self):
         def f(a):
             r = a.item()
-            constrain_range(r, min=2)
+            constrain_as_size(r)
             return torch.empty(r)

         r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
@@ -1049,6 +1050,7 @@ def f(a):
             r, """\
 def forward(self, a_1):
     _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
+    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = None, max = None)
     empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
     return empty"""  # noqa: B950
         )
@@ -1127,7 +1129,7 @@ def f(x, mask, params, buffers):
         for s in p.shape:
             guard_int(s)
         x = x[mask]
-        constrain_range(x.shape[0], min=1)
+        constrain_as_value(x.shape[0], min=1)
         for p in params.values():
             p.grad = None
         return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index aadea2fcf2a766..1cf5ef3c33ac55 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -145,6 +145,11 @@ def _module_dir(m: types.ModuleType):
     _module_dir(torch) + "ao/quantization/pt2e/utils.py",
 }

+FILENAME_ALLOWLIST |= {
+    _module_dir(torch) + "_export/constraints.py",
+}
+
+# TODO (zhxchen17) Make exportdb importable here.
 FILENAME_ALLOWLIST |= set(
     glob.glob(_module_dir(torch) + "_export/db/examples/*.py"),
 ) | {
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 54c653d0455248..23bc0725d2010c 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -363,7 +363,7 @@ def convert_to_fake(x):
     # so we serialize them here instead of inside dynamo
     gm.meta["inline_constraints"] = {
         k: v
-        for k, v in fake_mode.shape_env.var_to_range.items()
+        for k, v in fake_mode.shape_env.runtime_var_to_range.items()
         if re.match(r"^[if]\d+$", str(k))
     }
diff --git a/torch/_export/constraints.py b/torch/_export/constraints.py
index 6f8d382cccca69..aa46cafcf2ef70 100644
--- a/torch/_export/constraints.py
+++ b/torch/_export/constraints.py
@@ -1,48 +1,60 @@
-from typing import Optional, Callable, Union
+from typing import Optional

 import torch
-from torch import SymInt, SymFloat
-from torch._dynamo import allow_in_graph
-from torch.fx.experimental.symbolic_shapes import constrain_range_int
-from torch.utils._sympy.value_ranges import ValueRangeError
-
-# `Scalar` type used in native_functions.ymal will be translated to `Union[Number, _complex]`
-# could cause type error during since `SymInt` or `SymFloat` will be used.
-# Here manually specify the type explicitly.
-sym_constrain_range: Callable[
-    [Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]],
-    None,
-] = torch.sym_constrain_range  # type: ignore[assignment]


 # TODO: we want to hide this min/max stuff under some abstraction similar to
 # DynamicDim
-@allow_in_graph
 def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
     """
-    Add min/max constraint on the intermediate symbol at tracing time
+    Add min/max constraint on the intermediate symbol at tracing time. If called in eager mode,
+    it will still check that the input value is within the specified range.
     """
-
-    if not isinstance(symbol, SymInt):
-        constrain_range_int(symbol, min=min, max=max)
-    else:
-        sym_constrain_range(symbol, min, max)
-
-    return symbol
+    torch.sym_constrain_range(symbol, min=min, max=max)


 # TODO: we want to hide this min/max stuff under some abstraction similar to
 # DynamicDim
-@allow_in_graph
-def constrain_as_size(symbol, min: int = 2, max: Optional[int] = None):
+def constrain_as_size(symbol, min: Optional[int] = None, max: Optional[int] = None):
     """
-    Add min/max constraint on the intermediate symbol which will be used as a size
+    This indicates that a given int is size-like, and can be used in any context where a
+    size is expected. You will typically use this when reading out integers from Tensors,
+    e.g., max.item() or lengths.tolist(), which then need to be used as tensor constructors.
+    Providing these assertions to PyTorch can help resolve GuardOnDataDependentSymNode
+    errors upon export, since we cannot guard on unbacked SymInts.
+
+    This function has unusual semantics which distinguish it from constrain_as_value.
+    Specifically, at compile time, we will unsoundly assume that the resulting int is
+    always >= 2. As a result, the max value you pass in must always be greater than 2.
+    This makes it easier to use the unbacked int in size contexts, as we will often attempt
+    to guard on a size being zero or one (e.g., when computing the contiguity of a tensor,
+    or testing if broadcasting can occur), which will not work on unbacked SymInts.
+    Assuming that the int is >= 2 allows us to answer False to these guards. Although this
+    is technically unsound, in practice we observe that if your program works for all
+    sizes >= 2, it probably works for zero and one too. The reason we specifically assume
+    size >= 2 is that a lot of PyTorch code is specialized for sizes 0 and 1, which could
+    otherwise result in graphs that are not general.
+    At runtime, we only assert that the user-provided min/max values are respected.
+
+    To demonstrate these semantics, suppose you do
+    ```
+    # Case 1
+    # This will assume symbol is between [2, inf) at compile time, but [0, inf) at runtime
+    constrain_as_size(symbol, min=0)
+
+    # Case 2
+    # This will assume symbol is between [2, N] at compile time, but [0, N] at runtime
+    constrain_as_size(symbol, min=0, max=N)
+
+    # Case 3
+    # This is not a valid case, since max is <= 2
+    constrain_as_size(symbol, min=0, max=1)
+
+    # Case 4
+    # This will assume symbol is between [2, inf) at compile time, AND [2, inf) at runtime
+    constrain_as_size(symbol, min=2)
+
+    # Case 5
+    # This will assume symbol is between [2, inf) at compile time, but [1, inf) at runtime
+    constrain_as_size(symbol, min=1)
+    ```
     """
-
-    # TODO: we should investigate turning off 0/1 specialization for unbacked
-    # SymInts
-    if min < 2:
-        raise ValueRangeError(
-            "Unable to set min size to be <= 2 because we specialize on 0/1 sizes."
-        )
-    return constrain_as_value(symbol, min, max)
+    torch.sym_constrain_range_for_size(symbol, min=min, max=max)
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index 6ebb1a2a598f03..4838ac3175335d 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -806,7 +806,12 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:

             if vr := self.symbol_name_to_range.get(val.expr_str):
                 symbolic_shapes._constrain_symbol_range(
-                    self.shape_env, sym, vr.lower, vr.upper  # type: ignore[arg-type]
+                    self.shape_env,
+                    sym,
+                    compiler_min=vr.lower,  # type: ignore[arg-type]
+                    compiler_max=vr.upper,  # type: ignore[arg-type]
+                    runtime_min=vr.lower,  # type: ignore[arg-type]
+                    runtime_max=vr.upper  # type: ignore[arg-type]
                 )

             return self.shape_env.create_symintnode(sym, hint=val.hint)
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index b5e58377d8a307..e6b362de49e3af 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -83,6 +83,11 @@ def functional_assert_async_msg_decomp(tensor, msg):
     return


+@register_decomposition([aten.sym_constrain_range_for_size.default])
+def sym_constrain_range_for_size(symbol, *, min=None, max=None):
+    return
+
+
 @register_decomposition([aten.clamp])
 @pw_cast_for_opmath
 def clamp(x, min=None, max=None):
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 3772eb8de27033..241bbcbf53329f 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -4,7 +4,7 @@
 import torch
 import torch._prims_common as utils
-from torch import Tensor
+from torch import SymBool, SymFloat, Tensor
 from torch._decomp import (
     _add_op_to_registry,
     _convert_out_params,
@@ -30,7 +30,10 @@
     out_wrapper,
 )
 from torch._refs import _broadcast_shapes, _maybe_broadcast
-from torch.fx.experimental.symbolic_shapes import constrain_range
+from torch.fx.experimental.symbolic_shapes import (
+    _constrain_range_for_size,
+    constrain_range,
+)
 from torch.utils._pytree import tree_map

@@ -424,6 +427,8 @@ def make_dep_token(

 @register_meta(aten.sym_constrain_range.default)
 def sym_constrain_range(size, min=None, max=None):
+    if isinstance(size, (SymFloat, SymBool)):
+        raise ValueError("Constraining SymFloat or Symbool is nyi")
     constrain_range(size, min=min, max=max)


@@ -433,6 +438,19 @@ def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
     return dep_token


+@register_meta(aten.sym_constrain_range_for_size.default)
+def sym_constrain_range_for_size(size, min=None, max=None):
+    if isinstance(size, (SymFloat, SymBool)):
+        raise ValueError("Constraining SymFloat or Symbool is nyi")
+    _constrain_range_for_size(size, min=min, max=max)
+
+
+@register_meta(aten._functional_sym_constrain_range_for_size.default)
+def functional_sym_constrain_range_for_size(size, min, max, dep_token):
+    aten.sym_constrain_range_for_size(size, min=min, max=max)
+    return dep_token
+
+
 @register_meta(aten._functional_assert_async.msg)
 def functional_assert_async_meta(val, assert_msg, dep_token):
     return dep_token
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 14a1c247822f6e..1323a10eca5560 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -32,7 +32,7 @@
     SymFloat,
     SymInt,
 )
-from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode
+from torch._guards import ShapeGuard, Source, TracingContext
 from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift
 from torch.utils._sympy.solve import try_solve
 from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
@@ -304,13 +304,57 @@ def guard_scalar(a):
     else:
         raise AssertionError(f"unrecognized scalar {a}")

-def _constrain_symbol_range(shape_env, s: sympy.Symbol, min: int, max: int):
+def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int, runtime_min: int, runtime_max: int):
     if r := shape_env.var_to_range.get(s, None):
         shape_env.var_to_range[s] = ValueRanges(
-            builtins.max(r.lower, min), builtins.min(r.upper, max)
+            builtins.max(r.lower, compiler_min), builtins.min(r.upper, compiler_max)
         )
     else:
-        shape_env.var_to_range[s] = ValueRanges(min, max)
+        shape_env.var_to_range[s] = ValueRanges(compiler_min, compiler_max)
+
+    if r := shape_env.runtime_var_to_range.get(s, None):
+        shape_env.runtime_var_to_range[s] = ValueRanges(
+            builtins.max(r.lower, runtime_min), builtins.min(r.upper, runtime_max)
+        )
+    else:
+        shape_env.runtime_var_to_range[s] = ValueRanges(runtime_min, runtime_max)
+
+
+def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
+    """
+    This function is NOT INTENDED to be used by itself.
+    """
+
+    if isinstance(a, (SymFloat, SymBool)):
+        raise ValueError("Constraining SymFloat/SymBool is nyi")
+
+    assert isinstance(a, SymInt)
+    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+
+    if min is None:
+        min = 0
+    if max is None:
+        max = sympy.oo
+
+    if max <= 2:
+        raise ValueError(f"Maximum value to constrain_as_size must be greater than 2, but was {max}")
+
+    if max < min:
+        raise ValueError(
+            "Maximum value to constrain_as_size can't be less than the specified min value, "
+            f"received min={min} and max={max}"
+        )
+
+    compiler_min = 2 if min < 2 else min
+
+    _constrain_symbol_range(
+        a.node.shape_env,
+        a.node.expr,
+        compiler_min=compiler_min,
+        compiler_max=max,
+        runtime_min=min,
+        runtime_max=max
+    )
+

 # inclusive both ways
 def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
@@ -350,8 +394,16 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
         min = -sympy.oo
     if max is None:
         max = sympy.oo
-    if not isinstance(a, SymInt):
-        constrain_range_int(a, min=min, max=max)
+
+    if max < min:
+        raise ValueError(
+            "Maximum value to constrain_range can't be less than the specified min value, "
+            f"received min={min} and max={max}"
+        )
+
+    if isinstance(a, int):
+        if not (min <= a <= max):
+            raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
         return

     if isinstance(a.node.expr, sympy.Integer):
@@ -364,35 +416,15 @@
     # semantics that this is an "unchecked" assert (but it this actually
     # something useful?  Might be better to restrict only for unbacked
     # SymInt).
-    _constrain_symbol_range(a.node.shape_env, a.node.expr, min, max)
-
-def constrain_range_int(a, *, min, max):
-    """
-    Constrain range on concrete int value.
-    This can happens for the following scenarios:
-        - Eager mode execution and real int value is provided.
-        - During tracing the traced symbol is resolved as a static integer (see
-          PR #101655 for more details).
- """ - if min is None: - min = -sympy.oo - if max is None: - max = sympy.oo - - assert not isinstance(a, SymInt) - if not (min <= a <= max): - raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]") + _constrain_symbol_range( + a.node.shape_env, + a.node.expr, + compiler_min=min, + compiler_max=max, + runtime_min=min, + runtime_max=max + ) - if ( - (fake_mode := detect_fake_mode()) is not None and - getattr(fake_mode, "shape_env", None) is not None - ): - # If we are tracing with a fake mode then add this integer to the - # shape_env's var_to_range - sym_integer = sympy.Integer(a) - shape_env = fake_mode.shape_env - _constrain_symbol_range(shape_env, sym_integer, min, max) - shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack() def constrain_unify(a, b): """ @@ -1938,6 +1970,10 @@ def __init__( # range may contain ints which may not actually appear in # practice self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {} + # Maps symbolic ints to their min/max range for runtime checks. + # This is because we assume a graph generated with N=2 is general enough + # for N < 2. Therefore, it will be too strict to assert N=2 at runtime. + self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {} self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {} self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {} # Maps symbolic ints to the guards that refine their lower/upper diff --git a/torch/fx/node.py b/torch/fx/node.py index 023e5761b60c0c..09a3b2c2b7e60d 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -36,6 +36,7 @@ _ops.aten._assert_async.msg, _ops.aten.copy_.default, _ops.aten.sym_constrain_range.default, + _ops.aten.sym_constrain_range_for_size.default, _ops.profiler._record_function_enter, _ops.profiler._record_function_enter_new, _ops.profiler._record_function_exit} diff --git a/torch/overrides.py b/torch/overrides.py index 7eee2f4f184a44..cdaa5ad4ce4a3b 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -186,6 +186,7 @@ def get_ignored_functions() -> Set[Callable]: torch.sym_min, torch.sym_not, torch.sym_constrain_range, + torch.sym_constrain_range_for_size, torch.tril_indices, torch.triu_indices, torch.vander, diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 9a0a75c023b899..afd04b5fb016d7 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -74,6 +74,7 @@ "record_stream", # no return "sparse_dim", # returns an int "sym_constrain_range", # no return + "sym_constrain_range_for_size", # no return "_nested_tensor_storage_offsets", # returns a vector of ints "_chunk_grad_outputs_efficient_attention", # returns a bool "_fused_sdp_choice", # returns an int