diff --git a/c10/core/impl/TorchDispatchModeTLS.cpp b/c10/core/impl/TorchDispatchModeTLS.cpp
index f97b55464e1776..191abfb36896d8 100644
--- a/c10/core/impl/TorchDispatchModeTLS.cpp
+++ b/c10/core/impl/TorchDispatchModeTLS.cpp
@@ -2,6 +2,7 @@
 #include <c10/core/SafePyObject.h>
 #include <c10/core/impl/LocalDispatchKeySet.h>
 #include <c10/core/impl/TorchDispatchModeTLS.h>
+#include <c10/util/irange.h>
 
 #include <utility>
 
@@ -10,8 +11,23 @@ namespace impl {
 
 thread_local TorchDispatchModeTLS torchDispatchModeState;
 
-void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
-  if (torchDispatchModeState.stack_.empty()) {
+bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
+  if (!torchDispatchModeState.stack_.empty())
+    return true;
+  if (!skip_infra_modes) {
+    for (const auto i : c10::irange(
+             static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
+      if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
+    std::shared_ptr<SafePyObject> mode) {
+  if (!any_modes_set()) {
     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
     c10::impl::tls_set_dispatch_key_included(
         DispatchKey::PythonTLSSnapshot, true);
@@ -20,30 +36,122 @@ void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
 }
 
 const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
-  TORCH_CHECK(
-      !torchDispatchModeState.stack_.empty(),
-      "trying to pop from empty mode stack");
-  std::shared_ptr<SafePyObject> out = torchDispatchModeState.stack_.back();
-  torchDispatchModeState.stack_.pop_back();
-
-  if (torchDispatchModeState.stack_.empty()) {
+  std::shared_ptr<SafePyObject> out;
+  if (!torchDispatchModeState.stack_.empty()) {
+    out = torchDispatchModeState.stack_.back();
+    torchDispatchModeState.stack_.pop_back();
+  } else {
+    for (int64_t i =
+             static_cast<int64_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
+         i >= 0;
+         --i) {
+      if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+        out = std::move(torchDispatchModeState.infra_modes_[i].value());
+        torchDispatchModeState.infra_modes_[i] = c10::nullopt;
+        break;
+      }
+    }
+  }
+  TORCH_CHECK(out, "trying to pop from empty mode stack");
+  if (!any_modes_set()) {
     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
     c10::impl::tls_set_dispatch_key_included(
         DispatchKey::PythonTLSSnapshot, false);
   }
   return out;
 }
 
+const std::tuple<std::shared_ptr<SafePyObject>, TorchDispatchModeKey>
+TorchDispatchModeTLS::pop_highest_infra_mode() {
+  for (int64_t i = static_cast<int64_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
+       i >= 0;
+       --i) {
+    if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+      auto out_mode = torchDispatchModeState.infra_modes_[i].value();
+      torchDispatchModeState.infra_modes_[i] = c10::nullopt;
+      if (!any_modes_set()) {
+        c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
+        c10::impl::tls_set_dispatch_key_included(
+            DispatchKey::PythonTLSSnapshot, false);
+      }
+      return std::make_tuple(
+          std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
+    }
+  }
+  TORCH_CHECK(
+      false, "Called pop_highest_infra_mode, but no infra modes were active.");
+}
 
 const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
     int64_t idx) {
-  TORCH_CHECK(
-      idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
-      "Tried to get stack at idx that's too big");
-  return torchDispatchModeState.stack_[idx];
+  TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
+  // Our "logical" stack includes both:
+  // - any user modes (the entire torchDispatchModeState.stack_)
+  // - any infra modes (members of torchDispatchModeState.infra_modes_ that are
+  //   not None)
+
+  // idx == 0 means the "bottom" of the stack, which starts with any infra
+  // modes (iterating from lowest-priority to highest-priority).
+  auto curr_idx = idx;
+  for (const auto i :
+       c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
+    if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+      if (curr_idx == 0) {
+        return torchDispatchModeState.infra_modes_[i].value();
+      }
+      curr_idx -= 1;
+    }
+  }
+  // At this point, we're guaranteed that curr_idx < stack_.size()
+  return torchDispatchModeState.stack_[curr_idx];
 }
 
 int64_t TorchDispatchModeTLS::stack_len() {
-  return static_cast<int64_t>(torchDispatchModeState.stack_.size());
+  auto stack_len = static_cast<int64_t>(torchDispatchModeState.stack_.size());
+  int64_t infra_modes_len = 0;
+  for (const auto i :
+       c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
+    if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
+      infra_modes_len += 1;
+    }
+  }
+  return stack_len + infra_modes_len;
+}
+
+const c10::optional<std::shared_ptr<SafePyObject>> TorchDispatchModeTLS::
+    get_mode(TorchDispatchModeKey mode_key) {
+  return torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
+}
+
+void TorchDispatchModeTLS::set_mode(
+    const std::shared_ptr<SafePyObject>& mode,
+    TorchDispatchModeKey mode_key) {
+  TORCH_CHECK(
+      torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] ==
+          c10::nullopt,
+      "trying to set the current ",
+      to_string(mode_key),
+      ", but one already exists");
+
+  if (!any_modes_set()) {
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
+    c10::impl::tls_set_dispatch_key_included(
+        DispatchKey::PythonTLSSnapshot, true);
+  }
+
+  torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] = mode;
+}
+
+const c10::optional<std::shared_ptr<SafePyObject>> TorchDispatchModeTLS::
+    unset_mode(TorchDispatchModeKey mode_key) {
+  auto out = torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
+  torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] =
+      c10::nullopt;
+  if (out.has_value() && !any_modes_set()) {
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
+    c10::impl::tls_set_dispatch_key_included(
+        DispatchKey::PythonTLSSnapshot, false);
+  }
+  return out;
 }
 
 const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
@@ -52,7 +160,7 @@ const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
 
 void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
   torchDispatchModeState = std::move(state);
-  if (torchDispatchModeState.stack_.empty()) {
+  if (!any_modes_set()) {
     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
     c10::impl::tls_set_dispatch_key_included(
         DispatchKey::PythonTLSSnapshot, false);
@@ -67,7 +175,18 @@ void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
 
 bool dispatch_mode_enabled() {
   return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
-      TorchDispatchModeTLS::stack_len() > 0;
+      TorchDispatchModeTLS::any_modes_set();
+}
+
+std::string to_string(TorchDispatchModeKey mode_key) {
+  switch (mode_key) {
+    case TorchDispatchModeKey::PROXY:
+      return "ProxyTorchDispatchMode";
+    case TorchDispatchModeKey::FAKE:
+      return "FakeTensorMode";
+    default:
+      return "UNKNOWN_MODE";
+  }
 }
 
 } // namespace impl
diff --git a/c10/core/impl/TorchDispatchModeTLS.h b/c10/core/impl/TorchDispatchModeTLS.h
index 395f50d09ad737..44eb70a202434b 100644
--- a/c10/core/impl/TorchDispatchModeTLS.h
+++ b/c10/core/impl/TorchDispatchModeTLS.h
@@ -6,20 +6,55 @@
 namespace c10 {
 namespace impl {
 
+enum class TorchDispatchModeKey : int8_t { FAKE, PROXY, NUM_MODE_KEYS };
+
 struct C10_API TorchDispatchModeTLS {
-  static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
+  // This API is NOT invariant safe.
+  // It must not take in an infra mode that uses a TorchDispatchModeKey.
+  // If you're pushing an infra mode onto the stack, we expect
+  // you to use set_mode() instead.
+  static void push_non_infra_mode_onto_stack(
+      std::shared_ptr<SafePyObject> mode);
+  // Pops the top mode of the stack,
+  // giving precedence to user modes before attempting to pop
+  // any infra modes
   static const std::shared_ptr<SafePyObject> pop_stack();
+  // Returns the highest-priority infra mode on the stack,
+  // along with its mode key.
+  static const std::tuple<std::shared_ptr<SafePyObject>, TorchDispatchModeKey>
+  pop_highest_infra_mode();
+
   static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
   static int64_t stack_len();
 
+  static const c10::optional<std::shared_ptr<SafePyObject>> get_mode(
+      TorchDispatchModeKey mode_key);
+  static const c10::optional<std::shared_ptr<SafePyObject>> unset_mode(
+      TorchDispatchModeKey mode_key);
+  static void set_mode(
+      const std::shared_ptr<SafePyObject>& mode,
+      TorchDispatchModeKey mode_key);
+
   static const TorchDispatchModeTLS& get_state();
   static void set_state(TorchDispatchModeTLS state);
 
+  static bool any_modes_set(bool skip_infra_modes = false);
+
  private:
   std::vector<std::shared_ptr<SafePyObject>> stack_;
+  // Users are allowed to push multiple ProxyTorchDispatchMode objects onto the
+  // stack.
+  // However, we only allow a single FakeTensorMode onto the stack at a time
+  // (pushing additional FakeTensorModes onto the stack is a no-op).
+  std::array<
+      c10::optional<std::shared_ptr<SafePyObject>>,
+      static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS)>
+      infra_modes_;
 };
 
 C10_API bool dispatch_mode_enabled();
 
+C10_API std::string to_string(TorchDispatchModeKey mode_key);
+
 } // namespace impl
 } // namespace c10
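To make the "logical stack" above concrete from the Python side: user modes live in `stack_`, infra modes live in `infra_modes_`, and the two are exposed as a single stack with infra modes at the bottom. A minimal sketch of the expected post-PR behavior, using the private bindings added later in this patch (`UserMode` is an illustrative name, not part of the diff):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

class UserMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

with FakeTensorMode(), UserMode():
    # stack_len() counts user modes *and* occupied infra slots
    assert torch._C._len_torch_dispatch_stack() == 2
    # idx == 0 is the bottom of the logical stack: infra modes come first
    assert isinstance(torch._C._get_dispatch_stack_at(0), FakeTensorMode)
    assert isinstance(torch._C._get_dispatch_stack_at(1), UserMode)
```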
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index 9a66b0819ed8e0..5dd19d3cd2f48a 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -1515,6 +1515,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
     "test_view_clone_view_inplace",
     "test_view_inplace",
 ])
+@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesn't work well")
 class TestCrossRefFunctionalization(TestFunctionalization):
     crossref = True
 
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 8f40e3e79a0de1..22df7ae669b0e8 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -13,6 +13,7 @@
 from torch.utils._mode_utils import no_dispatch, all_same_mode
 from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
     log_input, capture_logs, capture_logs_with_logging_tensor_mode
+from torch.testing._internal.two_tensor import TwoTensor
 from torch.utils._pytree import tree_map, tree_map_only
 from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack
 from torch._custom_op.functional import register_functional_op
@@ -562,7 +563,6 @@ def test_register_fallthrough(self):
         # default behavior should have been restored
         self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)
 
-
 class TestPythonDispatch(TestCase):
     def test_basic(self) -> None:
         with capture_logs() as logs:
@@ -894,6 +894,36 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
         self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)
 
+    def test_make_fx_with_subclass(self) -> None:
+        def f(x, y):
+            # Returns (TwoTensor, Tensor)
+            return x * y, y + y
+        x_a = torch.zeros(4)
+        x_b = torch.zeros(4)
+        y = torch.ones(4)
+
+        # make_fx() is not responsible for unwrapping tensor subclass inputs,
+        # so we do it manually here.
+        # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
+        # convention as f(*args). Unwrapping tensor subclass inputs can potentially change
+        # the number of input args to the graph, breaking that assumption.
+        def f_to_trace(x_a, x_b, y):
+            x = TwoTensor(x_a, x_b)
+            out1, out2 = f(x, y)
+            out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
+            return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)
+        fx_g = make_fx(f_to_trace, tracing_mode='fake')(x_a, x_b, y)
+        self.assertExpectedInline(fx_g.code, """\
+
+
+def forward(self, x_a_1, x_b_1, y_1):
+    mul = torch.ops.aten.mul.Tensor(x_a_1, y_1);  x_a_1 = None
+    mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1);  x_b_1 = None
+    add = torch.ops.aten.add.Tensor(y_1, y_1);  y_1 = None
+    return (mul, mul_1, add)
+    """)
+
     def test_make_wrapper_subclass_propagates_metadata(self) -> None:
         class WrapperTensor(torch.Tensor):
             elem: torch.Tensor
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index e725ded05da8bf..eb8f5da2e14f70 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -10,7 +10,7 @@
 )
 
 from torchgen.gen import parse_native_yaml, parse_tags_yaml
-from torchgen.model import DispatchKey, Variant
+from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
 from torchgen.utils import FileManager
 
 from tools.autograd.gen_python_functions import (
@@ -1227,6 +1227,9 @@ def replace_special_case(hint: str) -> str:
     # Dispatch key hints
     # ~~~~~~~~~~~~~~~~~~
     dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey]
+    torch_dispatch_mode_key_hints = [
+        f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey
+    ]
 
     # Tags Enum type hints
     # ~~~~~~~~~~~~~~~~~~~~
@@ -1247,6 +1250,7 @@ def replace_special_case(hint: str) -> str:
         "legacy_storage_base_hints": legacy_storage_base_hints,
         "dtype_class_hints": dtype_class_hints,
         "dispatch_key_hints": dispatch_key_hints,
+        "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints,
         "all_directive": all_directive,
         "tag_attributes": tag_attributes,
     }
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e7073296bb81e5..9987e7a52da8f6 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1181,7 +1181,8 @@ def _get_function_stack_at(idx: _int) -> Any: ...
 def _len_torch_function_stack() -> _int: ...
 def _set_torch_dispatch_mode(cls: Any) -> None: ...
 def _push_on_torch_dispatch_stack(cls: Any) -> None: ...
-def _pop_torch_dispatch_stack() -> Any: ...
+def _pop_torch_dispatch_stack(mode_key: Optional[torch._C._TorchDispatchModeKey] = None) -> Any: ...
+def _get_dispatch_mode(mode_key: Optional[torch._C._TorchDispatchModeKey]) -> Any: ...
 def _get_dispatch_stack_at(idx: _int) -> Any: ...
 def _len_torch_dispatch_stack() -> _int: ...
 
@@ -1432,6 +1433,9 @@ def _are_functorch_transforms_active() -> _bool: ...
 # Define in torch/csrc/autograd/init.cpp
 def _set_python_dispatcher(dispatcher: object) -> None: ...
 
+class _TorchDispatchModeKey(Enum):
+    ${torch_dispatch_mode_key_hints}
+
 # Defined in torch/csrc/utils/init.cpp
 class BenchmarkConfig:
     num_calling_threads: _int
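The `_TorchDispatchModeKey` stub above corresponds to the pybind enum and the `_set/_get/_unset_dispatch_mode` bindings added in torch/csrc/autograd/init.cpp below. A rough sketch of how the private API is expected to behave under the post-PR semantics (not part of the patch):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

k = torch._C._TorchDispatchModeKey.FAKE
assert torch._C._get_dispatch_mode(k) is None  # no fake mode active

m = FakeTensorMode()
with m:
    # Entering the mode fills the FAKE infra slot
    assert torch._C._get_dispatch_mode(k) is m
    # _unset_dispatch_mode removes and returns the slot's mode...
    prev = torch._C._unset_dispatch_mode(k)
    assert prev is m and torch._C._get_dispatch_mode(k) is None
    # ...and _set_dispatch_mode puts it back
    torch._C._set_dispatch_mode(prev)
assert torch._C._get_dispatch_mode(k) is None
```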
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 428f4dc6773799..338fc6685fc678 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -13,6 +13,7 @@
 import torch
 import torch._custom_op
 import torch._logging
+
 from torch._guards import Source
 from torch._ops import OpOverload
 from torch._prims_common import (
@@ -34,7 +35,6 @@
 from torch.overrides import TorchFunctionMode
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import (
-    _get_current_dispatch_mode_stack,
     is_traceable_wrapper_subclass,
     TorchDispatchMode,
 )
@@ -156,6 +156,16 @@ def contains_tensor_types(type):
 )
 
 
+@contextlib.contextmanager
+def unset_fake_temporarily():
+    old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
+    try:
+        yield old
+    finally:
+        if old is not None:
+            torch._C._set_dispatch_mode(old)
+
+
 @functools.lru_cache(None)
 def _is_tensor_constructor(func: OpOverload):
     assert isinstance(func, OpOverload)
@@ -970,6 +980,10 @@ class FakeTensor(torch.Tensor):
     _nonzero_memo: Optional[torch.SymInt]
     _nonzero_memo_vc: Optional[int]
 
+    # Indicates to our torch_dispatch dispatching infra that
+    # this is an "infra" mode with lower dispatching precedence.
+    _mode_key = torch._C._TorchDispatchModeKey.FAKE
+
     @property
     def nonzero_memo(self):
         if self._nonzero_memo is None:
@@ -1110,15 +1124,14 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             # unluckily attempted to hit FakeTensor's dispatch first,
             # NotImplemented lets us keep chaining until we find the actual
             # subclass
-            cur_stack = _get_current_dispatch_mode_stack()
-            # NB: This must test for ANY fake tensor mode, because we want the
-            # fake tensor mode to take precedence
-            fake_modes_on_stack = [m for m in cur_stack if isinstance(m, FakeTensorMode)]
-            if fake_modes_on_stack:
+            maybe_cur_fake_mode = torch._C._get_dispatch_mode(
+                torch._C._TorchDispatchModeKey.FAKE
+            )
+            if maybe_cur_fake_mode:
                 not_implemented_log.debug(
                     "FakeTensor mode already active: %s in %s",
-                    fake_modes_on_stack,
-                    cur_stack,
+                    fake_mode,
+                    maybe_cur_fake_mode,
                 )
                 return NotImplemented
 
@@ -1235,12 +1248,18 @@ def __init__(
         # True if we enter'ed and actually enabled fake tensor mode,
         # false if it was a no-op.  Not thread safe but neither is
         # in_kernel_invocation
-        self.enter_stack: List[bool] = []
+        # If another fake mode was already active when we enter, we also stash it here.
+        # That way when we exit, we know to re-enable the previous fake mode.
+        self.enter_stack: List[Tuple[bool, Optional[FakeTensorMode]]] = []
 
         self.shape_env = shape_env
 
         self.stack = "".join(traceback.format_stack())
 
+        # Indicates to our torch_dispatch dispatching infra that
+        # this is an "infra" mode with lower dispatching precedence.
+        self._mode_key = torch._C._TorchDispatchModeKey.FAKE
+
         # Typically, there is only one fake tensor mode and you test for it by
         # doing an isinstance test.  However, in some situations, there might be
         # TWO fake tensor modes.  The canonical example of this is exporting
@@ -1258,27 +1277,35 @@ def is_our_fake(self, t):
 
     @count
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        assert self not in _get_current_dispatch_mode_stack(), func
+        # FakeTensorMode should not be set when we're inside of it.
+        assert (
+            torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
+        ), func
         try:
             return self.dispatch(func, types, args, kwargs)
         except TypeError:
             log.exception("fake tensor raised TypeError")
             raise
 
-    # No-op if FakeTensorMode is already on the stack
+    # No-op if FakeTensorMode is already in use
     def __enter__(self):
-        if self not in _get_current_dispatch_mode_stack():
-            self.enter_stack.append(True)
+        maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
+        if self is not maybe_prev_fake_mode:
+            self.enter_stack.append((True, maybe_prev_fake_mode))
             return super().__enter__()
         else:
-            # no-op
-            self.enter_stack.append(False)
-            return self
+            # no-op (still need to re-set the fake mode though, since we unset it)
+            torch._C._set_dispatch_mode(self)
+            self.enter_stack.append((False, None))
+            return self
 
     def __exit__(self, a, b, c):
-        live = self.enter_stack.pop()
+        live, maybe_prev_fake_mode = self.enter_stack.pop()
         if live:
-            return super().__exit__(a, b, c)
+            out = super().__exit__(a, b, c)
+            # Re-enable the previous fake mode, if there was one.
+            if maybe_prev_fake_mode is not None:
+                torch._C._set_dispatch_mode(maybe_prev_fake_mode)
+            return out
 
     def dispatch(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
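`unset_fake_temporarily` above replaces the old pop-the-stack idiom for pausing fake tensor propagation. A small usage sketch, assuming the post-PR semantics (not part of the patch):

```python
import torch
from torch._subclasses.fake_tensor import (
    FakeTensor,
    FakeTensorMode,
    unset_fake_temporarily,
)

with FakeTensorMode():
    t = torch.empty(2, 2)
    assert isinstance(t, FakeTensor)          # ops produce fake tensors
    with unset_fake_temporarily():
        r = torch.ones(2)
        assert not isinstance(r, FakeTensor)  # fake mode is paused here
    assert isinstance(torch.empty(2), FakeTensor)  # and restored on exit
```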
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 746695eb714e7a..019fb21e5e6c47 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -948,9 +948,22 @@ static PyObject* push_on_torch_dispatch_stack(
     PyObject* arg) {
   HANDLE_TH_ERRORS
   if (arg != Py_None) {
+    using c10::impl::TorchDispatchModeKey;
+    // When we push a mode onto the mode stack, we need to
+    // check if it's an "infra" mode, by checking its _mode_key attribute.
+    c10::optional<TorchDispatchModeKey> mode_key = c10::nullopt;
+    py::object maybe_mode_key_obj =
+        PyObject_FastGetAttrString(arg, "_mode_key");
+    if (maybe_mode_key_obj) {
+      mode_key = py::cast<TorchDispatchModeKey>(maybe_mode_key_obj);
+      c10::impl::TorchDispatchModeTLS::set_mode(
+          std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()),
+          mode_key.value());
+    } else {
+      c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
+          std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
+    }
     Py_INCREF(arg);
-    c10::impl::TorchDispatchModeTLS::push_onto_stack(
-        std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
   }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
@@ -958,10 +971,25 @@ static PyObject* push_on_torch_dispatch_stack(
 
 static PyObject* pop_torch_dispatch_stack(
     PyObject* _unused,
-    PyObject* _unused2) {
+    PyObject* maybe_mode_key) {
   HANDLE_TH_ERRORS
-  const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack();
-  auto* r = mode->ptr(getPyInterpreter());
+  c10::optional<c10::impl::TorchDispatchModeKey> mode_key = c10::nullopt;
+  PyObject* r;
+  if (maybe_mode_key != Py_None) {
+    mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key);
+    auto maybe_mode =
+        c10::impl::TorchDispatchModeTLS::unset_mode(mode_key.value());
+    TORCH_CHECK(
+        maybe_mode.has_value(),
+        "Attempted to unset ",
+        c10::impl::to_string(mode_key.value()),
+        ", but there wasn't one active.");
+    auto mode = maybe_mode.value();
+    r = mode->ptr(getPyInterpreter());
+  } else {
+    auto mode = c10::impl::TorchDispatchModeTLS::pop_stack();
+    r = mode->ptr(getPyInterpreter());
+  }
   Py_INCREF(r);
   return r;
   END_HANDLE_TH_ERRORS
@@ -985,9 +1013,55 @@ static PyObject* get_dispatch_stack_at(
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject* len_torch_dispatch_stack(
-    PyObject* _unused,
-    PyObject* _unused2) {
+static PyObject* set_dispatch_mode(PyObject* _unused, PyObject* mode) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(mode != Py_None);
+
+  py::object maybe_mode_key_obj = PyObject_FastGetAttrString(mode, "_mode_key");
+  TORCH_CHECK(
+      maybe_mode_key_obj,
+      "set_dispatch_mode() called with a mode that does not contain a _mode_key attribute!");
+  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key_obj);
+
+  Py_INCREF(mode);
+  c10::impl::TorchDispatchModeTLS::set_mode(
+      std::make_shared<c10::SafePyObject>(mode, getPyInterpreter()), mode_key);
+
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(arg != Py_None);
+  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(arg);
+
+  auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(mode_key);
+  if (maybe_mode == c10::nullopt) {
+    Py_RETURN_NONE;
+  }
+  auto* r = maybe_mode.value()->ptr(getPyInterpreter());
+  Py_INCREF(r);
+  return r;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(arg != Py_None);
+  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(arg);
+
+  const auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key);
+  if (maybe_mode == c10::nullopt) {
+    Py_RETURN_NONE;
+  }
+  auto* r = maybe_mode.value()->ptr(getPyInterpreter());
+  Py_INCREF(r);
+  return r;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* len_torch_dispatch_stack(PyObject* _unused, PyObject* args) {
   HANDLE_TH_ERRORS
   const auto len = c10::impl::TorchDispatchModeTLS::stack_len();
   return utils::wrap(static_cast<int64_t>(len));
@@ -1099,10 +1173,7 @@ static PyMethodDef methods[] = { // NOLINT
      push_on_torch_dispatch_stack,
      METH_O,
      nullptr},
-    {"_pop_torch_dispatch_stack",
-     pop_torch_dispatch_stack,
-     METH_NOARGS,
-     nullptr},
+    {"_pop_torch_dispatch_stack", pop_torch_dispatch_stack, METH_O, nullptr},
     {"_get_dispatch_stack_at",
      castPyCFunctionWithKeywords(get_dispatch_stack_at),
      METH_VARARGS | METH_KEYWORDS,
@@ -1111,6 +1182,10 @@ static PyMethodDef methods[] = { // NOLINT
      len_torch_dispatch_stack,
      METH_NOARGS,
      nullptr},
+    {"_set_dispatch_mode", set_dispatch_mode, METH_O, nullptr},
+    {"_get_dispatch_mode", get_dispatch_mode, METH_O, nullptr},
+    {"_unset_dispatch_mode", unset_dispatch_mode, METH_O, nullptr},
+
     {nullptr, nullptr, 0, nullptr}};
 
 PyMethodDef* python_functions() {
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index ee64a1d8879f1e..f013f62bea2d65 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -258,6 +258,129 @@ static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
   return (PyObject*)Py_TYPE(obj_or_type);
 }
 
+static py::object dispatch_on_subclass(
+    PyObject* args,
+    PyObject* kwargs,
+    at::ArrayRef<PyObject*> overloaded_args,
+    py::tuple py_types,
+    PyObject* torch_api_function,
+    bool is_torch_function,
+    const char* torch_function_name_str,
+    c10::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key =
+        c10::nullopt) {
+  py::object ret;
+  for (auto& arg : overloaded_args) {
+    py::object torch_function =
+        PyObject_FastGetAttrString(arg, torch_function_name_str);
+    if (!torch_function) {
+      TORCH_INTERNAL_ASSERT(0);
+    }
+    if (torch_function.ptr() == torch::disabled_torch_dispatch_impl()) {
+      // During __torch_dispatch__, don't dispatch on args with a disabled
+      // torch_dispatch. This code runs before infra modes, so we need to make
+      // sure that infra modes can run first. (In theory, maybe we can
+      // rearrange things so that infra modes are *always* attempted first,
+      // and just return NotImplemented when there are any user subclasses.
+      // Maybe that would fix this problem?)
+      continue;
+    }
+
+    // See https://github.com/pytorch/pytorch/issues/63767
+    if (is_torch_function &&
+        PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
+            .is(py::handle(arg)) &&
+        torch_function.ptr() != torch::disabled_torch_function_impl()) {
+      TORCH_WARN(
+          "Defining your `",
+          torch_function_name_str,
+          "` as a plain method is deprecated ",
+          "and will be an error in future, please define it as a classmethod.");
+    }
+
+    ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
+        torch_function.ptr(),
+        torch_api_function,
+        py_types.ptr(),
+        args,
+        kwargs,
+        NULL));
+    if (ret.ptr() == nullptr) {
+      throw python_error();
+    }
+    if (ret.ptr() != Py_NotImplemented) {
+      // Return the reference to the result. This also covers the case where
+      // ret is NULL and __torch_function__/__torch_dispatch__ raised an
+      // exception, which we throw below
+      break;
+    }
+  }
+  return ret;
+}
+
+static std::tuple<py::object, py::object> dispatch_on_mode(
+    PyObject* args,
+    PyObject* kwargs,
+    py::tuple py_types,
+    PyObject* torch_api_function,
+    bool is_torch_function,
+    const char* torch_function_name_str) {
+  // Disable mode on the inside; this makes for a more user-friendly
+  // experience if you try to, e.g., print your tensors.
+  at::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
+  at::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
+  py::object mode_obj;
+  // NB: We only really need to keep the mode_obj live if the function call
+  // fails for error reporting, but whatever, Python refcounts are cheap
+  if (is_torch_function) {
+    tf_g.emplace();
+    mode_obj = py::reinterpret_borrow<py::object>(
+        tf_g->get_cur_mode()->ptr(getPyInterpreter()));
+  } else {
+    td_g.emplace();
+    mode_obj = py::reinterpret_borrow<py::object>(
+        td_g->get_cur_mode()->ptr(getPyInterpreter()));
+  }
+  py::object torch_function =
+      PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
+  if (!torch_function) {
+    TORCH_INTERNAL_ASSERT(0);
+  }
+  TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
+  TORCH_INTERNAL_ASSERT(args != nullptr);
+
+  TORCH_CHECK(
+      PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(mode_obj),
+      "Defining your mode's `",
+      torch_function_name_str,
+      "` as a classmethod is not supported, please make it a plain method");
+
+  // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below
+  // because the nullptr terminates the argument list ick ick ick.
+  py::object ret;
+  if (kwargs == nullptr) {
+    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
+        mode_obj.ptr(),
+        torch_function_name_str,
+        "OOO",
+        torch_api_function,
+        py_types.ptr(),
+        args));
+  } else {
+    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
+        mode_obj.ptr(),
+        torch_function_name_str,
+        "OOOO",
+        torch_api_function,
+        py_types.ptr(),
+        args,
+        kwargs));
+  }
+  if (ret.ptr() == nullptr) {
+    throw python_error();
+  }
+  return std::make_tuple(ret, mode_obj);
+}
+
 // See Note: [Overloaded args] for what they hold
 auto handle_torch_function_no_python_arg_parser(
     at::ArrayRef<PyObject*> overloaded_args,
@@ -291,102 +414,89 @@ auto handle_torch_function_no_python_arg_parser(
   py::object ret;
   py::object mode_obj;
 
+  // Step 1: Try to dispatch based on the mode stack, *ignoring* infra
+  // torch_dispatch modes.
   const bool is_torch_function =
       torch_function_name == TorchFunctionName::TorchFunction;
   const auto is_mode_active = [&]() {
-    return is_torch_function ? at::impl::torch_function_mode_enabled()
-                             : c10::impl::dispatch_mode_enabled();
+    return is_torch_function
+        ? at::impl::torch_function_mode_enabled()
+        // Check if any *user* torch_dispatch modes are active (not including
+        // fake and proxy modes, which are special)
+        : c10::impl::dispatch_mode_enabled();
   };
+  // Note [__torch_dispatch__ dispatching order]
+  // The high-level idea motivating the dispatching order below is that:
+  // (1) modes get higher dispatch precedence than subclasses, and
+  // (2) "user" modes/subclasses get higher dispatch precedence than
+  //     "infra" modes/subclasses.
+  //
+  // To give a complete example: let's say we are running torch.compile, with
+  // the following "user" modes and subclasses:
+  //   mode_stack: [ModeA]
+  //   user_args: [MyWrapperSubclassB(torch.Tensor)]
+
+  // During tracing in AOTAutograd, we use some additional infra modes
+  // and subclasses to perform tracing:
+  //   FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode,
+  //   FunctionalTensor, FakeTensor
+  // The modified mode stack and tracing arguments will look like this:
+  //   mode_stack (user modes): [ModeA]
+  //   mode_stack (infra modes): [
+  //     FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode
+  //   ]
+  //   tracing_args: [
+  //     MyWrapperSubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor)))
+  //   ]

+  // And the dispatching order that we want is as follows:
+  // (1) ModeA.__torch_dispatch__ (user modes highest)
+  // (2) MyWrapperSubclassB.__torch_dispatch__ (user subclasses next highest)
+  // (3) FunctionalTensorMode.__torch_dispatch__ (infra modes next highest)
+  // (4) ProxyTorchDispatchMode.__torch_dispatch__ (infra modes next highest)
+  // (5) FakeTensorMode.__torch_dispatch__ (infra modes next highest)
+  // (6) FakeTensor.__torch_dispatch__ (infra subclasses next highest)

+  // Why do FunctionalTensor and FakeTensor even need to be special-cased
+  // in the ordering?
+  // In theory we could remove their __torch_dispatch__, but both of these
+  // subclasses override sizes/strides metadata calls with __torch_dispatch__,
+  // which would mean a mode would be **required** to access their metadata.
   if (is_mode_active()) {
-    // Disable mode on the inside; this makes for a more user-friendly
-    // experience if you try to, e.g., print your tensors.
-    at::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
-    at::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
-    // NB: We only really need keep the mode_obj live if the function call
-    // fails for error reporting, but whatever, Python refcounts are cheap
-    if (is_torch_function) {
-      tf_g.emplace();
-      mode_obj = py::reinterpret_borrow<py::object>(
-          tf_g->get_cur_mode()->ptr(getPyInterpreter()));
-    } else {
-      td_g.emplace();
-      mode_obj = py::reinterpret_borrow<py::object>(
-          td_g->get_cur_mode()->ptr(getPyInterpreter()));
-    }
-    py::object torch_function =
-        PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
-    if (!torch_function) {
-      TORCH_INTERNAL_ASSERT(0);
-    }
-    TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
-    TORCH_INTERNAL_ASSERT(args != nullptr);
-
-    TORCH_CHECK(
-        PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
-            .is(mode_obj),
-        "Defining your mode's `",
-        torch_function_name_str,
-        "` as a classmethod is not supported, please make it a plain method");
-
-    // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below
-    // because the nullptr terminates the argument list ick ick ick.
-    if (kwargs == nullptr) {
-      ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
-          mode_obj.ptr(),
-          torch_function_name_str,
-          "OOO",
-          torch_api_function,
-          py_types.ptr(),
-          args));
-    } else {
-      ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
-          mode_obj.ptr(),
-          torch_function_name_str,
-          "OOOO",
-          torch_api_function,
-          py_types.ptr(),
-          args,
-          kwargs));
-    }
-    if (ret.ptr() == nullptr) {
-      throw python_error();
-    }
-  }
+    // Step 1: Try to dispatch on any user TorchDispatchModes (including infra
+    // modes, which will always be at the bottom of the mode stack).
+    auto ret_ = dispatch_on_mode(
+        args,
+        kwargs,
+        py_types,
+        torch_api_function,
+        is_torch_function,
+        torch_function_name_str);
+    ret = std::get<0>(ret_);
+    mode_obj = std::get<1>(ret_);
+  }
+
+  // Step 2: Try to dispatch based on any user subclasses,
+  // ignoring any subclasses that have a _mode_key field
+  // (corresponding to infra subclasses)
+  // Note: user subclasses should always run *before* infra modes like
+  // proxy/fake. This is handled by having proxy/fake modes return
+  // NotImplemented when they see a user subclass that they don't understand.
   if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
-    for (auto& arg : overloaded_args) {
-      py::object torch_function =
-          PyObject_FastGetAttrString(arg, torch_function_name_str);
-      if (!torch_function) {
-        TORCH_INTERNAL_ASSERT(0);
-      }
-
-      // See https://github.com/pytorch/pytorch/issues/63767
-      if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
-              .is(py::handle(arg)) &&
-          torch_function.ptr() != torch::disabled_torch_function_impl()) {
-        TORCH_WARN(
-            "Defining your `",
-            torch_function_name_str,
-            "` as a plain method is deprecated ",
-            "and will be an error in future, please define it as a classmethod.");
-      }
-
-      ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
-          torch_function.ptr(),
-          torch_api_function,
-          py_types.ptr(),
-          args,
-          kwargs,
-          NULL));
-      if (ret.ptr() != Py_NotImplemented) {
-        // Return the reference to the result. This also covers the case where
-        // ret is NULL and __torch_function__/__torch_dispatch raised an
-        // exception, which we throw below
-        break;
-      }
-    }
+    auto curr_ret = dispatch_on_subclass(
+        args,
+        kwargs,
+        overloaded_args,
+        py_types,
+        torch_api_function,
+        is_torch_function,
+        torch_function_name_str);
+    if (curr_ret.ptr() != nullptr) {
+      ret = curr_ret;
+    }
   }
+
   if (ret.ptr() == nullptr) {
     // if an exception occurred in a user's implementation of
     // __torch_function__, throw it
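A Python-level sketch of Note [__torch_dispatch__ dispatching order] (`LoggingMode` is an illustrative name, not part of the diff): even though the infra `FakeTensorMode` is entered innermost, the user mode is consulted first, while the infra mode still ultimately handles the op:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

order = []

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        order.append("user mode")
        return func(*args, **(kwargs or {}))

with LoggingMode(), FakeTensorMode():
    t = torch.empty(2)

assert order == ["user mode"]     # the user mode ran first...
assert isinstance(t, FakeTensor)  # ...but the infra mode handled the op
```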
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 70629e59edd089..3216f1ff7246d6 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -744,6 +744,11 @@ void initDispatchBindings(PyObject* module) {
         return c10::SymInt(
             c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(data)));
       });
+
+  using c10::impl::TorchDispatchModeKey;
+  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
+      .value("PROXY", TorchDispatchModeKey::PROXY)
+      .value("FAKE", TorchDispatchModeKey::FAKE);
 }
 
 // TODO: dedupe with the kernel
diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h
index 1b36569bdf66cd..78a7dbb8167369 100644
--- a/torch/csrc/utils/torch_dispatch_mode.h
+++ b/torch/csrc/utils/torch_dispatch_mode.h
@@ -8,11 +8,25 @@ namespace torch_dispatch_mode {
 struct StashTorchDispatchModeGuard {
  public:
   StashTorchDispatchModeGuard() {
-    saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack();
+    if (c10::impl::TorchDispatchModeTLS::any_modes_set(
+            /*skip_infra_modes=*/true)) {
+      saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack();
+    } else {
+      auto mode_and_key =
+          c10::impl::TorchDispatchModeTLS::pop_highest_infra_mode();
+      saved_mode_ = std::move(std::get<0>(mode_and_key));
+      saved_mode_key_ = std::get<1>(mode_and_key);
+    }
   }
 
   ~StashTorchDispatchModeGuard() {
-    c10::impl::TorchDispatchModeTLS::push_onto_stack(std::move(saved_mode_));
+    if (saved_mode_key_ != c10::nullopt) {
+      c10::impl::TorchDispatchModeTLS::set_mode(
+          std::move(saved_mode_), saved_mode_key_.value());
+    } else {
+      c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
+          std::move(saved_mode_));
+    }
   }
 
   const std::shared_ptr<c10::SafePyObject>& get_cur_mode() {
@@ -21,6 +35,7 @@ struct StashTorchDispatchModeGuard {
 
  private:
   std::shared_ptr<c10::SafePyObject> saved_mode_;
+  c10::optional<c10::impl::TorchDispatchModeKey> saved_mode_key_;
 };
 
 struct StashTorchDispatchStackGuard {
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index a928bdc8a7ed5e..e818ba0019287d 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -9,7 +9,7 @@
 import torch
 import torch.utils._pytree as pytree
 from torch.fx import Tracer, GraphModule
-from torch._subclasses.fake_tensor import FakeTensorMode
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
 from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
 import torch.fx as fx
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
@@ -25,11 +25,10 @@
 
 from torch.utils._python_dispatch import (
     TorchDispatchMode,
-    _pop_mode_temporarily,
-    _get_current_dispatch_mode,
+    _pop_mode,
+    _push_mode,
 )
 
-from torch._subclasses import FakeTensor
 from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
 from torch.fx import Proxy
 import torch.fx.traceback as fx_traceback
@@ -115,7 +114,7 @@ def snapshot_fake(val):
     return val.detach()
 
 def extract_val(val):
-    if isinstance(val, FakeTensor):
+    if is_fake(val):
        return snapshot_fake(val)
    elif isinstance(val, py_sym_types):
        return val
@@ -144,7 +143,7 @@ def extract_val(val):
 def set_meta(proxy, val):
     proxy.node.meta['val'] = extract_val(val)
     # Best effort tensor_meta setting; prefer using val!
-    if isinstance(val, FakeTensor):
+    if is_fake(val):
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
     elif isinstance(val, torch.Tensor) and not val.is_sparse:
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
@@ -211,11 +210,7 @@ def get_constant(idx):
 def maybe_disable_fake_tensor_mode():
     # TODO: figure out if this API generally makes sense and bake it into the
     # library
-    mb_fake_mode = _get_current_dispatch_mode()
-    if isinstance(mb_fake_mode, FakeTensorMode):
-        return _pop_mode_temporarily()
-    else:
-        return nullcontext()
+    return unset_fake_temporarily()
 
 
 @dataclass
@@ -294,7 +289,7 @@ def can_handle_tensor(x):
         # If any of the Tensor inputs are "real" (not FakeTensor), we may
         # incorrectly burn in constants by allowing this access.  Raise
         # an error in this case
-        if pytree.tree_all_only(torch.Tensor, lambda t: not isinstance(t, FakeTensor), (args, kwargs)):
+        if pytree.tree_all_only(torch.Tensor, lambda t: not is_fake(t), (args, kwargs)):
             raise RuntimeError(
                 f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
                 "It's likely that this is caused by data-dependent control flow or similar. "
@@ -470,6 +465,25 @@ def dispatch_trace(
     return GraphModule(tracer.root, graph, name)
 
 
+@contextlib.contextmanager
+def _pop_proxy_mode_temporarily(dk):
+    # This is a shim around the existing per-dispatch-key-mode logic.
+    # I'll delete the per-dispatch-key-mode logic in a followup PR.
+    if dk is not None:
+        # During pre_dispatch, pop off of the PreDispatch mode stack
+        old = _pop_mode(dk)
+        try:
+            yield old
+        finally:
+            _push_mode(old, dk)
+    else:
+        # During normal tracing, pop off of the dedicated proxy mode stack
+        old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+        try:
+            yield old
+        finally:
+            torch._C._set_dispatch_mode(old)
+
 def wrap_key(f, tensors, tracer, pre_dispatch: bool):
     flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
     dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
@@ -478,7 +492,7 @@ def wrap_key(f, tensors, tracer, pre_dispatch: bool):
     def wrapped(*proxies):
         flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
         assert len(flat_proxies) == len(flat_tensors)
-        with _pop_mode_temporarily(dk) as m:
+        with _pop_proxy_mode_temporarily(dk) as m:
             assert isinstance(m, ProxyTorchDispatchMode)
             track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
 
@@ -548,6 +562,13 @@ def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constan
         self.sym_mode = ProxySymDispatchMode(tracer)
         self.trace_state = {}
         self._managers = []
+        # Indicates to our torch_dispatch dispatching infra that
+        # this is an "infra" mode with lower dispatching precedence.
+        self._mode_key = torch._C._TorchDispatchModeKey.PROXY
+        # Every time we enter a mode, we maintain a stack telling us what the previous
+        # ProxyTorchDispatchMode state was (if there was any).
+        # This lets us properly reset the state on exit.
+        self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = []
 
     @count
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@@ -559,17 +580,27 @@ def __enter__(self):
         m = self.sym_mode.enable(True)
         self._managers.append(m)
         m.__enter__()
+        # Stash and store the previous proxy mode (there may or may not be one)
+        maybe_prev_proxy_mode = torch._C._unset_dispatch_mode(self._mode_key)
+        self.enter_stack.append(maybe_prev_proxy_mode)
         return super().__enter__()
 
     def __exit__(self, exc_type, exc_value, traceback):
         m = self._managers.pop()
         # ...exit us first, then sym mode
         b = super().__exit__(exc_type, exc_value, traceback)
+
+        # Re-enable the previous proxy mode, if there was one.
+        mb_previous_proxy_mode = self.enter_stack.pop()
+        if mb_previous_proxy_mode is not None:
+            torch._C._set_dispatch_mode(mb_previous_proxy_mode)
+
         if not b:
             return m.__exit__(exc_type, exc_value, traceback)
         else:
             return m.__exit__(None, None, None)
 
+
     def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
         if not self.enable_tracing:
             return func(*args, **kwargs)
@@ -805,7 +836,7 @@ def wrap_fake(x):
     # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
     # thus irrelevant to any external functional trace.
     with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, proxy_function_mode, \
-         sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
+         sym_mode, proxy_mode, disable_autocast_cache():
         t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
 
     # TODO: kind of a bad way to do it, should maybe figure out a better way
@@ -821,28 +852,24 @@ def get_torch_dispatch_modes():
 
 
 def get_innermost_proxy_mode():
-    for m in reversed(torch.utils._python_dispatch._get_current_dispatch_mode_stack()):
-        if isinstance(m, ProxyTorchDispatchMode):
-            return m
-    return None
+    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
 
 
 @contextlib.contextmanager
 def disable_proxy_modes_tracing(enable_current=False):
-    modes = get_torch_dispatch_modes()
-    proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
-    if enable_current:
-        proxy_tensor_modes = proxy_tensor_modes[:-1]
-    olds = [(m.enable_tracing, m.sym_mode.enable_tracing) for m in proxy_tensor_modes]
-    for proxy_mode in proxy_tensor_modes:
-        proxy_mode.enable_tracing = False
-        proxy_mode.sym_mode.enable_tracing = False
+    # enable_current=True is now a no-op, since only one proxy mode
+    # can live on the stack at a time.
+    # We should kill this API in a future PR.
+    maybe_old = None
+    if not enable_current:
+        # Only one proxy_mode can be "active" at a time.
+        # So we simply remove our active mode.
+        maybe_old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
     try:
         yield
     finally:
-        for proxy_mode, (old, old_sym) in zip(proxy_tensor_modes, olds):
-            proxy_mode.enable_tracing = old
-            proxy_mode.sym_mode.enable_tracing = old_sym
+        if maybe_old is not None:
+            torch._C._set_dispatch_mode(maybe_old)
 
 
 def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"):
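With these changes there is at most one active `ProxyTorchDispatchMode`, held in the PROXY slot, so `get_innermost_proxy_mode` and `disable_proxy_modes_tracing` reduce to a slot read and a slot unset. A rough sketch of the expected behavior (not part of the patch):

```python
import torch
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    get_innermost_proxy_mode,
    make_fx,
)

assert get_innermost_proxy_mode() is None  # no tracing active here

def f(x):
    # Inside make_fx, the single proxy mode occupies the PROXY slot
    assert get_innermost_proxy_mode() is not None
    with disable_proxy_modes_tracing():
        assert get_innermost_proxy_mode() is None  # temporarily unset
    assert get_innermost_proxy_mode() is not None  # restored on exit
    return x + 1

gm = make_fx(f)(torch.zeros(3))
```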
diff --git a/torch/testing/_internal/two_tensor.py b/torch/testing/_internal/two_tensor.py
new file mode 100644
index 00000000000000..6add6b33d6f86c
--- /dev/null
+++ b/torch/testing/_internal/two_tensor.py
@@ -0,0 +1,56 @@
+import torch
+
+
+# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
+class TwoTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, a, b):
+        assert (
+            a.device == b.device
+            and a.layout == b.layout
+            and a.requires_grad == b.requires_grad
+            and a.dtype == b.dtype
+        )
+        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
+        shape = a.shape
+        kwargs = {}
+        kwargs["device"] = a.device
+        kwargs["layout"] = a.layout
+        kwargs["requires_grad"] = a.requires_grad
+        kwargs["dtype"] = a.dtype
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, a, b):
+        self.a = a
+        self.b = b
+
+    def __repr__(self):
+        a_repr = repr(self.a)
+        b_repr = repr(self.b)
+        return f"TwoTensor({a_repr}, {b_repr})"
+
+    def __tensor_flatten__(self):
+        return ["a", "b"], None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta):
+        assert meta is None
+        a, b = inner_tensors["a"], inner_tensors["b"]
+        return TwoTensor(a, b)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        assert any(isinstance(x, TwoTensor) for x in args)
+        args_a = [x.a if isinstance(x, TwoTensor) else x for x in args]
+        args_b = [x.b if isinstance(x, TwoTensor) else x for x in args]
+        out_a = func(*args_a, **kwargs)
+        out_b = func(*args_b, **kwargs)
+        assert type(out_a) == type(out_b)
+        if isinstance(out_a, torch.Tensor):
+            return TwoTensor(out_a, out_b)
+        # for aten ops that return non-tensors, just assume that
+        # our two inner tensors return the same value
+        assert out_a == out_b
+        return out_a
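Example usage of the new `TwoTensor` helper (illustrative, not part of the patch): every op runs on both inner tensors and the results are re-wrapped:

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = TwoTensor(torch.ones(2), torch.ones(2) * 2)
out = t + t
assert isinstance(out, TwoTensor)
assert torch.equal(out.a, torch.full((2,), 2.0))  # ones + ones
assert torch.equal(out.b, torch.full((2,), 4.0))  # twos + twos
```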
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 029c614aa2fc28..8052e07b4f03b8 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -1,5 +1,5 @@
 import contextlib
-from typing import Optional
+from typing import Optional, Union
 import warnings
 
 import torch
@@ -56,7 +56,12 @@ def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        _pop_mode(self.__dict__.get("_dispatch_key", None))
+        mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None)
+        if mb_dk_or_mode_key is None:
+            # Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch)
+            # We should probably revisit this.
+            mb_dk_or_mode_key = self.__dict__.get("_mode_key", None)
+        _pop_mode(mb_dk_or_mode_key)
 
     @classmethod
     def push(cls, *args, **kwargs):
@@ -66,7 +71,10 @@ def push(cls, *args, **kwargs):
 
 def _get_current_dispatch_mode():
     stack_len = _len_torch_dispatch_stack()
-    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None
+    # Return a user mode on the stack if there are any
+    if stack_len > 0:
+        return _get_dispatch_stack_at(stack_len - 1)
+    return None
 
 
 def _get_current_dispatch_mode_stack():
@@ -87,12 +95,14 @@ def _push_mode(mode, k: Optional[DispatchKey] = None):
     _push_on_torch_dispatch_stack(mode)
 
 
-def _pop_mode(k: Optional[DispatchKey] = None):
-    if k is not None:
-        from torch._ops import pop_mode_for_key
-        return pop_mode_for_key(k)
-    else:
-        return _pop_torch_dispatch_stack()
+def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None):
+    if k is None or isinstance(k, torch._C._TorchDispatchModeKey):
+        return _pop_torch_dispatch_stack(k)
+    from torch._ops import pop_mode_for_key
+    # per-dispatch-key-mode-stacks do not currently handle "always running infra modes last".
+    # In practice this doesn't matter, since ProxyTorchDispatchMode is the only mode
+    # that we push onto these per-dispatch-key-mode-stacks.
+    return pop_mode_for_key(k)
 
 
 @contextlib.contextmanager
@@ -108,6 +118,8 @@ def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
 def _disable_current_modes():
     mode_len = _len_torch_dispatch_stack()
     old_modes = [_pop_mode() for _ in range(mode_len)]
+
+    # Manually disable proxy and fake modes, if any are active
     try:
         yield old_modes
     finally:
@@ -121,7 +133,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
             kwargs = {}
         return func(*args, **kwargs)
 
-
 def is_traceable_wrapper_subclass(t):
     """
     Returns whether or not a tensor subclass that implements __torch_dispatch__
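And a sketch of `_disable_current_modes`, which stashes the whole mode stack for the duration of the context (`CountingMode` is an illustrative name, not part of the diff):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode, _disable_current_modes

class CountingMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.count = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.count += 1
        return func(*args, **(kwargs or {}))

m = CountingMode()
with m:
    torch.ones(2)                    # seen by the mode
    with _disable_current_modes():
        torch.ones(2)                # mode stack is stashed here
    torch.ones(2)                    # mode is restored afterwards
assert m.count == 2
```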
diff --git a/torchgen/model.py b/torchgen/model.py
index 9f47ad4051a01b..543806d14e4b3c 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -205,6 +205,11 @@ def parse(value: str) -> "DispatchKey":
         raise AssertionError(f"unknown dispatch key {value}")
 
 
+class _TorchDispatchModeKey(Enum):
+    FAKE = auto()
+    PROXY = auto()
+
+
 def codegen_per_backend_entries() -> str:
     r = []
     for fk in FUNCTIONALITY_KEYS: