[export] Refactor constrain_as_value and constrain_as_size (pytorch#106591)

Some notable changes:
1. `constrain_as_size` now allows the min value to be less than 2, since the compiler unconditionally assumes min >= 2 for size-like symbols anyway. Instead, we add an additional check to ensure that the max value is always greater than 2.
2. Previously, the runtime assert on an unbacked symint's value range was always [2, max]. This logic now asserts on [0, max] unless the user explicitly specifies a min (see the sketch below).
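A minimal sketch of the new user-facing behavior, mirroring the updated tests further down (assumption: `constrain_as_size` is importable from `torch._export.constraints`; the exact import path may differ across versions):

import torch
from torch._export.constraints import constrain_as_size

def f(x, y):
    b = x.item()
    # No explicit bounds needed: the runtime assert now covers [0, max]
    # (or [min, max] if min is given), while the compiler still assumes
    # b >= 2 for size-like symbols.
    constrain_as_size(b)
    return torch.empty((b, y.shape[0]))

exported = torch._export.export(f, (torch.tensor([3]), torch.randn(4)))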

Pull Request resolved: pytorch#106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
tugsbayasgalan authored and pytorchmergebot committed Aug 15, 2023
1 parent d6c120d commit 20c5add
Showing 19 changed files with 389 additions and 145 deletions.
51 changes: 48 additions & 3 deletions aten/src/ATen/native/Constraints.cpp
@@ -1,3 +1,4 @@
+#include <limits>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/core/Tensor.h>
 #include <c10/core/Device.h>
@@ -15,21 +16,65 @@
 #include <ATen/ops/_make_dep_token_native.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/sym_constrain_range_native.h>
+#include <ATen/ops/sym_constrain_range_for_size_native.h>
+#include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
 #endif

 namespace at {
 namespace native {

-void sym_constrain_range_cpu(
+void sym_constrain_range(
     const Scalar& size,
     c10::optional<int64_t> min,
-    c10::optional<int64_t> max) {}
+    c10::optional<int64_t> max) {
+
+  int64_t min_val = min.has_value() ? min.value() : std::numeric_limits<int64_t>::min();
+  int64_t max_val = max.has_value() ? max.value() : std::numeric_limits<int64_t>::max();
+  int64_t size_as_int = size.toInt();
+
+  TORCH_CHECK(
+      max_val >= min_val,
+      "Max must be greater than or equal to min. Got min=",
+      min_val,
+      " max=",
+      max_val
+  );
+
+  TORCH_CHECK(
+      min_val <= size_as_int && size_as_int <= max_val,
+      "Invalid value range for ",
+      size_as_int,
+      " between [",
+      min_val,
+      ", ",
+      max_val,
+      "]."
+  );
+}

-Tensor _functional_sym_constrain_range_cpu(
+Tensor _functional_sym_constrain_range(
     const Scalar& size,
     c10::optional<int64_t> min,
     c10::optional<int64_t> max,
     const Tensor& dep_token) {
+  sym_constrain_range(size, min, max);
   return dep_token.clone();
 }

+void sym_constrain_range_for_size(const Scalar& size, c10::optional<int64_t> min, c10::optional<int64_t> max) {
+  int64_t min_val = min.has_value() ? min.value() : 0;
+  if (max.has_value() && max.value() <= 2) {
+    TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value());
+  }
+  sym_constrain_range(size, min_val, max);
+}
+
+Tensor _functional_sym_constrain_range_for_size(
+    const Scalar& size,
+    c10::optional<int64_t> min,
+    c10::optional<int64_t> max,
+    const Tensor& dep_token) {
+  sym_constrain_range_for_size(size, min, max);
+  return dep_token.clone();
+}

15 changes: 0 additions & 15 deletions aten/src/ATen/native/cuda/Constraints.cu

This file was deleted.

15 changes: 11 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -181,14 +181,21 @@

 - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()

-- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
+- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
   dispatch:
-    CPU: sym_constrain_range_cpu
-    CUDA: sym_constrain_range_cuda
+    CompositeExplicitAutograd: sym_constrain_range
+
+- func: sym_constrain_range_for_size(Scalar size, *, int? min, int? max) -> ()
+  dispatch:
+    CompositeExplicitAutograd: sym_constrain_range_for_size

 - func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
   dispatch:
-    CPU: _functional_sym_constrain_range_cpu
+    CompositeExplicitAutograd: _functional_sym_constrain_range
+
+- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _functional_sym_constrain_range_for_size

 - func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
   dispatch:
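Since these schemas are registered in native_functions.yaml, the ops should be reachable from Python through the usual dispatcher bindings; a hedged usage sketch (assuming the standard torch.ops.aten exposure for yaml entries):

import torch

# Passes: 5 lies within [0, 10].
torch.ops.aten.sym_constrain_range(5, min=0, max=10)

# For the size variant, min/max are keyword-only with no defaults;
# an explicit max must be greater than 2.
torch.ops.aten.sym_constrain_range_for_size(5, min=None, max=10)

# Out of range: raises "Invalid value range for 20 between [4, 7]."
# torch.ops.aten.sym_constrain_range(20, min=4, max=7)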
10 changes: 5 additions & 5 deletions test/dynamo/test_export.py
@@ -2273,7 +2273,7 @@ def my_dyn_fn(a, b, c):
     def test_export_preserve_constraints_as_metadata_scalar(self):
         def f(x, y):
             b = x.item()
-            constrain_as_size(b, min=2, max=5)
+            constrain_as_size(b)
             return torch.empty((b, y.shape[0]))

         x = torch.tensor([3])
@@ -2322,7 +2322,7 @@ def test_exported_graph_serialization(self):

         def f(x, y):
             b = x.item()
-            constrain_as_size(b, min=2, max=5)
+            constrain_as_size(b)
             return torch.empty((b, y.shape[0]))

         x = torch.tensor([3])
@@ -2344,11 +2344,11 @@ def f(x, y):
     def test_export_with_inline_constraints(self):
         def f(x):
             a = x.item()
-            constrain_as_size(a, 4, 7)
+            constrain_as_value(a, 4, 7)
             return torch.empty((a, 4))

         with self.assertRaisesRegex(
-            torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]"
+            RuntimeError, r"Invalid value range for 20 between \[4, 7\]."
         ) as cm:
             torch._export.export(f, (torch.tensor([20]),))

@@ -2368,7 +2368,7 @@ def f(x):
     def test_export_with_inline_constraints_complex(self):
         def f(x):
             a = x.item()
-            constrain_as_size(a, 4, 7)
+            constrain_as_value(a, 4, 7)
             empty = torch.empty((a, 4))

             return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
1 change: 1 addition & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -322,6 +322,7 @@ aten::_foreach_zero.out
 aten::_foreach_zero_
 aten::_functional_assert_async.msg
 aten::_functional_sym_constrain_range
+aten::_functional_sym_constrain_range_for_size
 aten::_fused_adam
 aten::_fused_adam.out
 aten::_fused_adam_
