From da54f3c51950b14dd06687752178ef59359bd5ec Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Mon, 28 Aug 2023 13:29:13 -0700 Subject: [PATCH] reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415 --- c10/core/impl/TorchDispatchModeTLS.cpp | 151 +++++++++++-- c10/core/impl/TorchDispatchModeTLS.h | 37 +++- test/test_functionalization.py | 1 + test/test_python_dispatch.py | 32 ++- tools/pyi/gen_pyi.py | 6 +- torch/_C/__init__.pyi.in | 6 +- torch/_subclasses/fake_tensor.py | 63 ++++-- torch/csrc/autograd/init.cpp | 99 +++++++-- torch/csrc/utils/python_arg_parser.cpp | 284 +++++++++++++++++-------- torch/csrc/utils/python_dispatch.cpp | 5 + torch/csrc/utils/torch_dispatch_mode.h | 19 +- torch/fx/experimental/proxy_tensor.py | 85 +++++--- torch/testing/_internal/two_tensor.py | 57 +++++ torch/utils/_python_dispatch.py | 31 ++- torchgen/model.py | 5 + 15 files changed, 703 insertions(+), 178 deletions(-) create mode 100644 torch/testing/_internal/two_tensor.py diff --git a/c10/core/impl/TorchDispatchModeTLS.cpp b/c10/core/impl/TorchDispatchModeTLS.cpp index f97b55464e1776..191abfb36896d8 100644 --- a/c10/core/impl/TorchDispatchModeTLS.cpp +++ b/c10/core/impl/TorchDispatchModeTLS.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include @@ -10,8 +11,23 @@ namespace impl { thread_local TorchDispatchModeTLS torchDispatchModeState; -void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr mode) { - if (torchDispatchModeState.stack_.empty()) { +bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) { + if (!torchDispatchModeState.stack_.empty()) + return true; + if (!skip_infra_modes) { + for (const auto i : c10::irange( + static_cast(TorchDispatchModeKey::NUM_MODE_KEYS))) { + if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { + return true; + } + } + } + return false; +} + +void TorchDispatchModeTLS::push_non_infra_mode_onto_stack( + std::shared_ptr mode) { + if (!any_modes_set()) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true); c10::impl::tls_set_dispatch_key_included( DispatchKey::PythonTLSSnapshot, true); @@ -20,30 +36,122 @@ void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr mode) { } const std::shared_ptr TorchDispatchModeTLS::pop_stack() { - TORCH_CHECK( - !torchDispatchModeState.stack_.empty(), - "trying to pop from empty mode stack"); - std::shared_ptr out = torchDispatchModeState.stack_.back(); - torchDispatchModeState.stack_.pop_back(); - - if (torchDispatchModeState.stack_.empty()) { + std::shared_ptr out; + if (!torchDispatchModeState.stack_.empty()) { + out = torchDispatchModeState.stack_.back(); + torchDispatchModeState.stack_.pop_back(); + } else { + for (int64_t i = + static_cast(TorchDispatchModeKey::NUM_MODE_KEYS) - 1; + i >= 0; + --i) { + if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { + out = std::move(torchDispatchModeState.infra_modes_[i].value()); + torchDispatchModeState.infra_modes_[i] = c10::nullopt; + break; + } + } + } + TORCH_CHECK(out, "trying to pop from empty mode stack"); + if (!any_modes_set()) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); c10::impl::tls_set_dispatch_key_included( DispatchKey::PythonTLSSnapshot, false); } return out; } +const std::tuple, TorchDispatchModeKey> +TorchDispatchModeTLS::pop_highest_infra_mode() { + for (int64_t i = static_cast(TorchDispatchModeKey::NUM_MODE_KEYS) - 1; + i >= 0; + --i) { + if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { + auto out_mode = torchDispatchModeState.infra_modes_[i].value(); + torchDispatchModeState.infra_modes_[i] = c10::nullopt; + if (!any_modes_set()) { + c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); + c10::impl::tls_set_dispatch_key_included( + DispatchKey::PythonTLSSnapshot, false); + } + return std::make_tuple( + std::move(out_mode), static_cast(i)); + } + } + TORCH_CHECK( + false, "Called pop_highest_infra_mode, but no infra modes were active.") +} const std::shared_ptr& TorchDispatchModeTLS::get_stack_at( int64_t idx) { - TORCH_CHECK( - idx < static_cast(torchDispatchModeState.stack_.size()), - "Tried to get stack at idx that's too big"); - return torchDispatchModeState.stack_[idx]; + TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big"); + // Our "logical" stack includes both: + // - any user modes (the entire torchDispatchModeState.stack_) + // - any infra modes (members of torchDispatchModeState.infra_modes_ that are + // not None) + + // idx == 0 means the "bottom" of the stack, which starts with any infra + // modes (iterating from lowest-priority to highest-priority). + auto curr_idx = idx; + for (const auto i : + c10::irange(static_cast(TorchDispatchModeKey::NUM_MODE_KEYS))) { + if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { + if (curr_idx == 0) { + return torchDispatchModeState.infra_modes_[i].value(); + } + curr_idx -= 1; + } + } + // At this point, we're guaranteed that curr_idx < stack_.size() + return torchDispatchModeState.stack_[curr_idx]; } int64_t TorchDispatchModeTLS::stack_len() { - return static_cast(torchDispatchModeState.stack_.size()); + auto stack_len = static_cast(torchDispatchModeState.stack_.size()); + int64_t infra_modes_len = 0; + for (const auto i : + c10::irange(static_cast(TorchDispatchModeKey::NUM_MODE_KEYS))) { + if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { + infra_modes_len += 1; + } + } + return stack_len + infra_modes_len; +} + +const c10::optional> TorchDispatchModeTLS:: + get_mode(TorchDispatchModeKey mode_key) { + return torchDispatchModeState.infra_modes_[static_cast(mode_key)]; +} + +void TorchDispatchModeTLS::set_mode( + const std::shared_ptr& mode, + TorchDispatchModeKey mode_key) { + TORCH_CHECK( + torchDispatchModeState.infra_modes_[static_cast(mode_key)] == + c10::nullopt, + "trying to set the current ", + to_string(mode_key), + ", but one already exists"); + + if (!any_modes_set()) { + c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true); + c10::impl::tls_set_dispatch_key_included( + DispatchKey::PythonTLSSnapshot, true); + } + + torchDispatchModeState.infra_modes_[static_cast(mode_key)] = mode; +} + +const c10::optional> TorchDispatchModeTLS:: + unset_mode(TorchDispatchModeKey mode_key) { + auto out = torchDispatchModeState.infra_modes_[static_cast(mode_key)]; + torchDispatchModeState.infra_modes_[static_cast(mode_key)] = + c10::nullopt; + if (out.has_value() && !any_modes_set()) { + c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); + c10::impl::tls_set_dispatch_key_included( + DispatchKey::PythonTLSSnapshot, false); + } + return out; } const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() { @@ -52,7 +160,7 @@ const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() { void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) { torchDispatchModeState = std::move(state); - if (torchDispatchModeState.stack_.empty()) { + if (!any_modes_set()) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); c10::impl::tls_set_dispatch_key_included( DispatchKey::PythonTLSSnapshot, false); @@ -67,7 +175,18 @@ void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) { bool dispatch_mode_enabled() { return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) && - TorchDispatchModeTLS::stack_len() > 0; + TorchDispatchModeTLS::any_modes_set(); +} + +std::string to_string(TorchDispatchModeKey mode_key) { + switch (mode_key) { + case TorchDispatchModeKey::PROXY: + return "ProxyTorchDispatchMode"; + case TorchDispatchModeKey::FAKE: + return "FakeTensorMode"; + default: + return "UNKNOWN_MODE"; + } } } // namespace impl diff --git a/c10/core/impl/TorchDispatchModeTLS.h b/c10/core/impl/TorchDispatchModeTLS.h index 395f50d09ad737..44eb70a202434b 100644 --- a/c10/core/impl/TorchDispatchModeTLS.h +++ b/c10/core/impl/TorchDispatchModeTLS.h @@ -6,20 +6,55 @@ namespace c10 { namespace impl { +enum class TorchDispatchModeKey : int8_t { FAKE, PROXY, NUM_MODE_KEYS }; + struct C10_API TorchDispatchModeTLS { - static void push_onto_stack(std::shared_ptr mode); + // This API is NOT invariant safe. + // It must not take in an infra mode that uses TorchDispatchModeKey + // If you're pushing an infra mode onto the stack, we expect + // you to use set_mode + static void push_non_infra_mode_onto_stack( + std::shared_ptr mode); + // Pops the top mode of the stack, + // giving precedence to user modes before attempting to pop + // any infra modes static const std::shared_ptr pop_stack(); + // Returns the highest-priority infra mode on the stack, + // along with its mode key. + static const std::tuple, TorchDispatchModeKey> + pop_highest_infra_mode(); + static const std::shared_ptr& get_stack_at(int64_t idx); static int64_t stack_len(); + static const c10::optional> get_mode( + TorchDispatchModeKey mode_key); + static const c10::optional> unset_mode( + TorchDispatchModeKey mode_key); + static void set_mode( + const std::shared_ptr& mode, + TorchDispatchModeKey mode_key); + static const TorchDispatchModeTLS& get_state(); static void set_state(TorchDispatchModeTLS state); + static bool any_modes_set(bool skip_infra_modes = false); + private: std::vector> stack_; + // Users are allowed to push multiple ProxyTorchDispatchMode objects onto the + // stack + // However, we only allow a single FakeTensorMode onto the stack at a time + // (Pushing additional FakeTensorModes onto the stack is a no-op) + std::array< + c10::optional>, + static_cast(TorchDispatchModeKey::NUM_MODE_KEYS)> + infra_modes_; }; C10_API bool dispatch_mode_enabled(); +C10_API std::string to_string(TorchDispatchModeKey mode_key); + } // namespace impl } // namespace c10 diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 9a66b0819ed8e0..5dd19d3cd2f48a 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -1515,6 +1515,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): "test_view_clone_view_inplace", "test_view_inplace", ]) +@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well") class TestCrossRefFunctionalization(TestFunctionalization): crossref = True diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 8f40e3e79a0de1..22df7ae669b0e8 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -13,6 +13,7 @@ from torch.utils._mode_utils import no_dispatch, all_same_mode from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \ log_input, capture_logs, capture_logs_with_logging_tensor_mode +from torch.testing._internal.two_tensor import TwoTensor from torch.utils._pytree import tree_map, tree_map_only from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack from torch._custom_op.functional import register_functional_op @@ -562,7 +563,6 @@ def test_register_fallthrough(self): # default behavior should have been restored self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16) - class TestPythonDispatch(TestCase): def test_basic(self) -> None: with capture_logs() as logs: @@ -894,6 +894,36 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor) self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor) + def test_make_fx_with_subclass(self) -> None: + def f(x, y): + # Returns (TwoTensor, Tensor) + return x * y, y + y + x_a = torch.zeros(4) + x_b = torch.zeros(4) + y = torch.ones(4) + + # make_fx() is not responsible for unwrapping tensor subclass inputs, + # so we do it manually here. + # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling + # convention as f(*args). Unwrapping tensor subclass inputs can potentially change + # the number of input args to the graph, breaking that assumption + def f_to_trace(x_a, x_b, y): + x = TwoTensor(x_a, x_b) + out1, out2 = f(x, y) + out1_unwrapped_attrs, _ = out1.__tensor_flatten__() + return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2) + fx_g = make_fx(f_to_trace, tracing_mode='fake')(x_a, x_b, y) + self.assertExpectedInline(fx_g.code, """\ + + + +def forward(self, x_a_1, x_b_1, y_1): + mul = torch.ops.aten.mul.Tensor(x_a_1, y_1); x_a_1 = None + mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1); x_b_1 = None + add = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None + return (mul, mul_1, add) + """) + def test_make_wrapper_subclass_propagates_metadata(self) -> None: class WrapperTensor(torch.Tensor): elem: torch.Tensor diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index e725ded05da8bf..eb8f5da2e14f70 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -10,7 +10,7 @@ ) from torchgen.gen import parse_native_yaml, parse_tags_yaml -from torchgen.model import DispatchKey, Variant +from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant from torchgen.utils import FileManager from tools.autograd.gen_python_functions import ( @@ -1227,6 +1227,9 @@ def replace_special_case(hint: str) -> str: # Dispatch key hints # ~~~~~~~~~~~~~~~~~~ dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey] + torch_dispatch_mode_key_hints = [ + f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey + ] # Tags Enum type hints # ~~~~~~~~~~~~~~~~~~~~ @@ -1247,6 +1250,7 @@ def replace_special_case(hint: str) -> str: "legacy_storage_base_hints": legacy_storage_base_hints, "dtype_class_hints": dtype_class_hints, "dispatch_key_hints": dispatch_key_hints, + "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints, "all_directive": all_directive, "tag_attributes": tag_attributes, } diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index e7073296bb81e5..9987e7a52da8f6 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1181,7 +1181,8 @@ def _get_function_stack_at(idx: _int) -> Any: ... def _len_torch_function_stack() -> _int: ... def _set_torch_dispatch_mode(cls: Any) -> None: ... def _push_on_torch_dispatch_stack(cls: Any) -> None: ... -def _pop_torch_dispatch_stack() -> Any: ... +def _pop_torch_dispatch_stack(mode_key: Optional[torch._C._TorchDispatchModeKey] = None) -> Any: ... +def _get_dispatch_mode(mode_key: Optional[torch._C._TorchDispatchModeKey]) -> Any: ... def _get_dispatch_stack_at(idx: _int) -> Any: ... def _len_torch_dispatch_stack() -> _int: ... @@ -1432,6 +1433,9 @@ def _are_functorch_transforms_active() -> _bool: ... # Define in torch/csrc/autograd/init.cpp def _set_python_dispatcher(dispatcher: object) -> None: ... +class _TorchDispatchModeKey(Enum): + ${torch_dispatch_mode_key_hints} + # Defined in torch/csrc/utils/init.cpp class BenchmarkConfig: num_calling_threads: _int diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 428f4dc6773799..338fc6685fc678 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -13,6 +13,7 @@ import torch import torch._custom_op import torch._logging + from torch._guards import Source from torch._ops import OpOverload from torch._prims_common import ( @@ -34,7 +35,6 @@ from torch.overrides import TorchFunctionMode from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import ( - _get_current_dispatch_mode_stack, is_traceable_wrapper_subclass, TorchDispatchMode, ) @@ -156,6 +156,16 @@ def contains_tensor_types(type): ) +@contextlib.contextmanager +def unset_fake_temporarily(): + old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + try: + yield old + finally: + if old is not None: + torch._C._set_dispatch_mode(old) + + @functools.lru_cache(None) def _is_tensor_constructor(func: OpOverload): assert isinstance(func, OpOverload) @@ -970,6 +980,10 @@ class FakeTensor(torch.Tensor): _nonzero_memo: Optional[torch.SymInt] _nonzero_memo_vc: Optional[int] + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + _mode_key = torch._C._TorchDispatchModeKey.FAKE + @property def nonzero_memo(self): if self._nonzero_memo is None: @@ -1110,15 +1124,14 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # unluckily attempted to hit FakeTensor's dispatch first, # NotImplemented lets us keep chaining until we find the actual # subclass - cur_stack = _get_current_dispatch_mode_stack() - # NB: This must test for ANY fake tensor mode, because we want the - # fake tensor mode to take precedence - fake_modes_on_stack = [m for m in cur_stack if isinstance(m, FakeTensorMode)] - if fake_modes_on_stack: + maybe_cur_fake_mode = torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.FAKE + ) + if maybe_cur_fake_mode: not_implemented_log.debug( "FakeTensor mode already active: %s in %s", - fake_modes_on_stack, - cur_stack, + fake_mode, + maybe_cur_fake_mode, ) return NotImplemented @@ -1235,12 +1248,18 @@ def __init__( # True if we enter'ed and actually enabled fake tensor mode, # false if it was a no-op. Not thread safe but neither is # in_kernel_invocation - self.enter_stack: List[bool] = [] + # If another fake mode was already active when we enter, we also stash it here. + # That way when we exit, we know to re-enable the previous fake mode. + self.enter_stack: List[Tuple[bool, Optional[FakeTensorMode]]] = [] self.shape_env = shape_env self.stack = "".join(traceback.format_stack()) + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.FAKE + # Typically, there is only one fake tensor mode and you test for it by # doing an isinstance test. However, in some situations, there might be # TWO fake tensor modes. The canonical example of this is exporting @@ -1258,27 +1277,35 @@ def is_our_fake(self, t): @count def __torch_dispatch__(self, func, types, args=(), kwargs=None): - assert self not in _get_current_dispatch_mode_stack(), func + # FakeTensorMode should not be set when we're inside of it. + assert ( + torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None + ), func try: return self.dispatch(func, types, args, kwargs) except TypeError: log.exception("fake tensor raised TypeError") raise - # No-op if FakeTensorMode is already on the stack + # No-op if FakeTensorMode is already in use def __enter__(self): - if self not in _get_current_dispatch_mode_stack(): - self.enter_stack.append(True) + maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key) + if self is not maybe_prev_fake_mode: + self.enter_stack.append((True, maybe_prev_fake_mode)) return super().__enter__() else: - # no-op - self.enter_stack.append(False) - return self + # no-op (still need to re-set the fake mode though since we unset it) + torch._C._set_dispatch_mode(self) + self.enter_stack.append((False, None)) + return self def __exit__(self, a, b, c): - live = self.enter_stack.pop() + live, maybe_prev_fake_mode = self.enter_stack.pop() if live: - return super().__exit__(a, b, c) + out = super().__exit__(a, b, c) + # Re-enable the previous fake mode, if there was one. + if maybe_prev_fake_mode is not None: + torch._C._set_dispatch_mode(maybe_prev_fake_mode) def dispatch(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 746695eb714e7a..019fb21e5e6c47 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -948,9 +948,22 @@ static PyObject* push_on_torch_dispatch_stack( PyObject* arg) { HANDLE_TH_ERRORS if (arg != Py_None) { + using c10::impl::TorchDispatchModeKey; + // When we push a mode onto the mode stack, we need to + // check if it's an "infra" mode, by checking its _mode_key attribute. + c10::optional mode_key = c10::nullopt; + py::object maybe_mode_key_obj = + PyObject_FastGetAttrString(arg, "_mode_key"); + if (maybe_mode_key_obj) { + mode_key = py::cast(maybe_mode_key_obj); + c10::impl::TorchDispatchModeTLS::set_mode( + std::make_shared(arg, getPyInterpreter()), + mode_key.value()); + } else { + c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack( + std::make_shared(arg, getPyInterpreter())); + } Py_INCREF(arg); - c10::impl::TorchDispatchModeTLS::push_onto_stack( - std::make_shared(arg, getPyInterpreter())); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -958,10 +971,25 @@ static PyObject* push_on_torch_dispatch_stack( static PyObject* pop_torch_dispatch_stack( PyObject* _unused, - PyObject* _unused2) { + PyObject* maybe_mode_key) { HANDLE_TH_ERRORS - const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack(); - auto* r = mode->ptr(getPyInterpreter()); + c10::optional mode_key = c10::nullopt; + PyObject* r; + if (maybe_mode_key != Py_None) { + mode_key = py::cast(maybe_mode_key); + auto maybe_mode = + c10::impl::TorchDispatchModeTLS::unset_mode(mode_key.value()); + TORCH_CHECK( + maybe_mode.has_value(), + "Attempted to unset ", + c10::impl::to_string(mode_key.value()), + ", but there wasn't one active."); + auto mode = maybe_mode.value(); + r = mode->ptr(getPyInterpreter()); + } else { + auto mode = c10::impl::TorchDispatchModeTLS::pop_stack(); + r = mode->ptr(getPyInterpreter()); + } Py_INCREF(r); return r; END_HANDLE_TH_ERRORS @@ -985,9 +1013,55 @@ static PyObject* get_dispatch_stack_at( END_HANDLE_TH_ERRORS } -static PyObject* len_torch_dispatch_stack( - PyObject* _unused, - PyObject* _unused2) { +static PyObject* set_dispatch_mode(PyObject* _unused, PyObject* mode) { + HANDLE_TH_ERRORS + TORCH_CHECK(mode != Py_None); + + py::object maybe_mode_key_obj = PyObject_FastGetAttrString(mode, "_mode_key"); + TORCH_CHECK( + maybe_mode_key_obj, + "set_dispatch_mode() called with a mode that does not contain a _mode_key attribute!"); + auto mode_key = py::cast(maybe_mode_key_obj); + + Py_INCREF(mode); + c10::impl::TorchDispatchModeTLS::set_mode( + std::make_shared(mode, getPyInterpreter()), mode_key); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(arg != Py_None); + auto mode_key = py::cast(arg); + + auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(mode_key); + if (maybe_mode == c10::nullopt) { + Py_RETURN_NONE; + } + auto* r = maybe_mode.value()->ptr(getPyInterpreter()); + Py_INCREF(r); + return r; + END_HANDLE_TH_ERRORS +} + +static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(arg != Py_None); + auto mode_key = py::cast(arg); + + const auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key); + if (maybe_mode == c10::nullopt) { + Py_RETURN_NONE; + } + auto* r = maybe_mode.value()->ptr(getPyInterpreter()); + Py_INCREF(r); + return r; + END_HANDLE_TH_ERRORS +} + +static PyObject* len_torch_dispatch_stack(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS const auto len = c10::impl::TorchDispatchModeTLS::stack_len(); return utils::wrap(static_cast(len)); @@ -1099,10 +1173,7 @@ static PyMethodDef methods[] = { // NOLINT push_on_torch_dispatch_stack, METH_O, nullptr}, - {"_pop_torch_dispatch_stack", - pop_torch_dispatch_stack, - METH_NOARGS, - nullptr}, + {"_pop_torch_dispatch_stack", pop_torch_dispatch_stack, METH_O, nullptr}, {"_get_dispatch_stack_at", castPyCFunctionWithKeywords(get_dispatch_stack_at), METH_VARARGS | METH_KEYWORDS, @@ -1111,6 +1182,10 @@ static PyMethodDef methods[] = { // NOLINT len_torch_dispatch_stack, METH_NOARGS, nullptr}, + {"_set_dispatch_mode", set_dispatch_mode, METH_O, nullptr}, + {"_get_dispatch_mode", get_dispatch_mode, METH_O, nullptr}, + {"_unset_dispatch_mode", unset_dispatch_mode, METH_O, nullptr}, + {nullptr, nullptr, 0, nullptr}}; PyMethodDef* python_functions() { diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index ee64a1d8879f1e..f013f62bea2d65 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -258,6 +258,129 @@ static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) { return (PyObject*)Py_TYPE(obj_or_type); } +static py::object dispatch_on_subclass( + PyObject* args, + PyObject* kwargs, + at::ArrayRef overloaded_args, + py::tuple py_types, + PyObject* torch_api_function, + bool is_torch_function, + const char* torch_function_name_str, + c10::optional maybe_mode_key = + c10::nullopt) { + py::object ret; + for (auto& arg : overloaded_args) { + py::object torch_function = + PyObject_FastGetAttrString(arg, torch_function_name_str); + if (!torch_function) { + TORCH_INTERNAL_ASSERT(0); + } + if (torch_function.ptr() == torch::disabled_torch_dispatch_impl()) { + // During __torch_dispatch__, don't dispatch on args with a disabled + // torch_dispatch. This code runs before infra modes, so we need to make + // sure that infra modes can run first. (In theory, maybe we can rearrange + // things so that infra modes are *always* attempted first, and just + // return NotImplemented when there are any user subclasses. Maybe that + // would fix this problem?) + continue; + } + + // See https://github.com/pytorch/pytorch/issues/63767 + if (is_torch_function && + PyObject_FastGetAttrString(torch_function.ptr(), "__self__") + .is(py::handle(arg)) && + torch_function.ptr() != torch::disabled_torch_function_impl()) { + TORCH_WARN( + "Defining your `", + torch_function_name_str, + "` as a plain method is deprecated ", + "and will be an error in future, please define it as a classmethod."); + } + + ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( + torch_function.ptr(), + torch_api_function, + py_types.ptr(), + args, + kwargs, + NULL)); + if (ret.ptr() == nullptr) { + throw python_error(); + } + if (ret.ptr() != Py_NotImplemented) { + // Return the reference to the result. This also covers the case where + // ret is NULL and __torch_function__/__torch_dispatch raised an + // exception, which we throw below + break; + } + } + return ret; +} + +static std::tuple dispatch_on_mode( + PyObject* args, + PyObject* kwargs, + py::tuple py_types, + PyObject* torch_api_function, + bool is_torch_function, + const char* torch_function_name_str) { + // Disable mode on the inside; this makes for a more user-friendly + // experience if you try to, e.g., print your tensors. + at::optional tf_g; + at::optional td_g; + py::object mode_obj; + // NB: We only really need keep the mode_obj live if the function call + // fails for error reporting, but whatever, Python refcounts are cheap + if (is_torch_function) { + tf_g.emplace(); + mode_obj = py::reinterpret_borrow( + tf_g->get_cur_mode()->ptr(getPyInterpreter())); + } else { + td_g.emplace(); + mode_obj = py::reinterpret_borrow( + td_g->get_cur_mode()->ptr(getPyInterpreter())); + } + py::object torch_function = + PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str); + if (!torch_function) { + TORCH_INTERNAL_ASSERT(0); + } + TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr); + TORCH_INTERNAL_ASSERT(args != nullptr); + + TORCH_CHECK( + PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(mode_obj), + "Defining your mode's `", + torch_function_name_str, + "` as a classmethod is not supported, please make it a plain method"); + + // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below + // because the nullptr terminates the argument list ick ick ick. + py::object ret; + if (kwargs == nullptr) { + ret = py::reinterpret_steal(PyObject_CallMethod( + mode_obj.ptr(), + torch_function_name_str, + "OOO", + torch_api_function, + py_types.ptr(), + args)); + } else { + ret = py::reinterpret_steal(PyObject_CallMethod( + mode_obj.ptr(), + torch_function_name_str, + "OOOO", + torch_api_function, + py_types.ptr(), + args, + kwargs)); + } + if (ret.ptr() == nullptr) { + throw python_error(); + } + return std::make_tuple(ret, mode_obj); +} + // See Note: [Overloaded args] for what they hold auto handle_torch_function_no_python_arg_parser( at::ArrayRef overloaded_args, @@ -291,102 +414,89 @@ auto handle_torch_function_no_python_arg_parser( py::object ret; py::object mode_obj; + // Step 1: Try to dispatch based on the mode stack, *ignoring* infra + // torch_dispatch modes. const bool is_torch_function = torch_function_name == TorchFunctionName::TorchFunction; const auto is_mode_active = [&]() { - return is_torch_function ? at::impl::torch_function_mode_enabled() - : c10::impl::dispatch_mode_enabled(); + return is_torch_function + ? at::impl::torch_function_mode_enabled() + // Check if any *user* torch_dispatch modes are active (not including + // fake and proxy modes, which are special) + : c10::impl::dispatch_mode_enabled(); }; + // Note [__torch_dispatch__ dispatching order] + // The high-level idea motivating the dispatching + // order below is that: (1) modes get higher dispatch precedence over + // subclasses (2) "user" modes/subclasses get higher dispatch precedence over + // "infra" modes/subclasses. + // + // To give a complete example: let's say we are running torch.compile, with + // the following "user" modes and subclasses: + // mode_stack: [ModeA] + // user_args: [MyWrapperSubclassB(torchTensor)] + + // During tracing in AOTAutograd tracing, we use some additional infra modes + // and subclasses to perform tracing: + // FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode, + // FunctionalTensor, FakeTensor + // The modified mode stack and tracing arguments will look like this: + // mode_stack (user modes): [ModeA] + // mode_stack (infra modes): [ + // FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode + // ] + // tracing_args: [ + // MyWrapperSubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor))) + // ] + + // And the dispatching order that we want is as follows: + // (1) ModeA.__torch_dispatch__ (user modes highest) + // (2) MyWrapperSubclassB.__torch_dispatch__ (user subclasses next highest) + // (3) FunctionalTensorMode.__torch_dispatch__ (infra modes next highest) + // (4) ProxyTorchDispatchMode.__torch_dispatch__ (infra modes next highest) + // (5) FakeTensorMode.__torch_dispatch__ (infra modes next highest) + // (6) FakeTensor.__torch_fake_dispatch__ (infra subclasses next highest) + + // Why does do FunctionalTensor and FakeTensor even need to be special-cased + // in the ordering? + // In theory we could remove their __torch_dispatch__, but both of these + // subclasses override sizes/strides metadata calls with __torch_dispatch__, + // which would mean a mode would be **required** to access their metadata. if (is_mode_active()) { - // Disable mode on the inside; this makes for a more user-friendly - // experience if you try to, e.g., print your tensors. - at::optional tf_g; - at::optional td_g; - // NB: We only really need keep the mode_obj live if the function call - // fails for error reporting, but whatever, Python refcounts are cheap - if (is_torch_function) { - tf_g.emplace(); - mode_obj = py::reinterpret_borrow( - tf_g->get_cur_mode()->ptr(getPyInterpreter())); - } else { - td_g.emplace(); - mode_obj = py::reinterpret_borrow( - td_g->get_cur_mode()->ptr(getPyInterpreter())); - } - py::object torch_function = - PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str); - if (!torch_function) { - TORCH_INTERNAL_ASSERT(0); - } - TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr); - TORCH_INTERNAL_ASSERT(args != nullptr); - - TORCH_CHECK( - PyObject_FastGetAttrString(torch_function.ptr(), "__self__") - .is(mode_obj), - "Defining your mode's `", - torch_function_name_str, - "` as a classmethod is not supported, please make it a plain method"); - - // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below - // because the nullptr terminates the argument list ick ick ick. - if (kwargs == nullptr) { - ret = py::reinterpret_steal(PyObject_CallMethod( - mode_obj.ptr(), - torch_function_name_str, - "OOO", - torch_api_function, - py_types.ptr(), - args)); - } else { - ret = py::reinterpret_steal(PyObject_CallMethod( - mode_obj.ptr(), - torch_function_name_str, - "OOOO", - torch_api_function, - py_types.ptr(), - args, - kwargs)); - } - if (ret.ptr() == nullptr) { - throw python_error(); - } - } + // Step 1: Try to dispatch on any user TorchDispatchModes (including infra + // modes, which will always be at the bottom of the mode stack). + auto ret_ = dispatch_on_mode( + args, + kwargs, + py_types, + torch_api_function, + is_torch_function, + torch_function_name_str); + ret = std::get<0>(ret_); + mode_obj = std::get<1>(ret_); + } + + // Step 2: Try to dispatch based on any user subclasses, + // ignoring any subclasses that have a _mode_key field + // (corresponding to infra subclasses) + // Note: user subclasses should always run *before* infra modes like + // proxy/fake. This is handles by having proxy/fake modes return + // NotImplemented when they see a user subclass that they don't understand. if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) { - for (auto& arg : overloaded_args) { - py::object torch_function = - PyObject_FastGetAttrString(arg, torch_function_name_str); - if (!torch_function) { - TORCH_INTERNAL_ASSERT(0); - } - - // See https://github.com/pytorch/pytorch/issues/63767 - if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__") - .is(py::handle(arg)) && - torch_function.ptr() != torch::disabled_torch_function_impl()) { - TORCH_WARN( - "Defining your `", - torch_function_name_str, - "` as a plain method is deprecated ", - "and will be an error in future, please define it as a classmethod."); - } - - ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( - torch_function.ptr(), - torch_api_function, - py_types.ptr(), - args, - kwargs, - NULL)); - if (ret.ptr() != Py_NotImplemented) { - // Return the reference to the result. This also covers the case where - // ret is NULL and __torch_function__/__torch_dispatch raised an - // exception, which we throw below - break; - } + auto curr_ret = dispatch_on_subclass( + args, + kwargs, + overloaded_args, + py_types, + torch_api_function, + is_torch_function, + torch_function_name_str); + if (curr_ret.ptr() != nullptr) { + ret = curr_ret; } } + if (ret.ptr() == nullptr) { // if an exception occurred in a user's implementation of // __torch_function__, throw it diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 70629e59edd089..3216f1ff7246d6 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -744,6 +744,11 @@ void initDispatchBindings(PyObject* module) { return c10::SymInt( c10::SymNode(c10::make_intrusive(data))); }); + + using c10::impl::TorchDispatchModeKey; + py::enum_(m, "_TorchDispatchModeKey") + .value("PROXY", TorchDispatchModeKey::PROXY) + .value("FAKE", TorchDispatchModeKey::FAKE); } // TODO: dedupe with the kernel diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index 1b36569bdf66cd..78a7dbb8167369 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -8,11 +8,25 @@ namespace torch_dispatch_mode { struct StashTorchDispatchModeGuard { public: StashTorchDispatchModeGuard() { - saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack(); + if (c10::impl::TorchDispatchModeTLS::any_modes_set( + /*skip_infra_modes=*/true)) { + saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack(); + } else { + auto mode_and_key = + c10::impl::TorchDispatchModeTLS::pop_highest_infra_mode(); + saved_mode_ = std::move(std::get<0>(mode_and_key)); + saved_mode_key_ = std::get<1>(mode_and_key); + } } ~StashTorchDispatchModeGuard() { - c10::impl::TorchDispatchModeTLS::push_onto_stack(std::move(saved_mode_)); + if (saved_mode_key_ != c10::nullopt) { + c10::impl::TorchDispatchModeTLS::set_mode( + std::move(saved_mode_), saved_mode_key_.value()); + } else { + c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack( + std::move(saved_mode_)); + } } const std::shared_ptr& get_cur_mode() { @@ -21,6 +35,7 @@ struct StashTorchDispatchModeGuard { private: std::shared_ptr saved_mode_; + c10::optional saved_mode_key_; }; struct StashTorchDispatchStackGuard { diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index a928bdc8a7ed5e..e818ba0019287d 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -9,7 +9,7 @@ import torch import torch.utils._pytree as pytree from torch.fx import Tracer, GraphModule -from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch import torch.fx as fx from torch.fx.passes.shape_prop import _extract_tensor_metadata @@ -25,11 +25,10 @@ from torch.utils._python_dispatch import ( TorchDispatchMode, - _pop_mode_temporarily, - _get_current_dispatch_mode, + _pop_mode, + _push_mode, ) -from torch._subclasses import FakeTensor from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode from torch.fx import Proxy import torch.fx.traceback as fx_traceback @@ -115,7 +114,7 @@ def snapshot_fake(val): return val.detach() def extract_val(val): - if isinstance(val, FakeTensor): + if is_fake(val): return snapshot_fake(val) elif isinstance(val, py_sym_types): return val @@ -144,7 +143,7 @@ def extract_val(val): def set_meta(proxy, val): proxy.node.meta['val'] = extract_val(val) # Best effort tensor_meta setting; prefer using val! - if isinstance(val, FakeTensor): + if is_fake(val): proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) elif isinstance(val, torch.Tensor) and not val.is_sparse: proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) @@ -211,11 +210,7 @@ def get_constant(idx): def maybe_disable_fake_tensor_mode(): # TODO: figure out if this API generally makes sense and bake it into the # library - mb_fake_mode = _get_current_dispatch_mode() - if isinstance(mb_fake_mode, FakeTensorMode): - return _pop_mode_temporarily() - else: - return nullcontext() + return unset_fake_temporarily() @dataclass @@ -294,7 +289,7 @@ def can_handle_tensor(x): # If any of the Tensor inputs are "real" (not FakeTensor), we may # incorrectly burn in constants by allowing this access. Raise # an error in this case - if pytree.tree_all_only(torch.Tensor, lambda t: not isinstance(t, FakeTensor), (args, kwargs)): + if pytree.tree_all_only(torch.Tensor, lambda t: not is_fake(t), (args, kwargs)): raise RuntimeError( f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! " "It's likely that this is caused by data-dependent control flow or similar. " @@ -470,6 +465,25 @@ def dispatch_trace( return GraphModule(tracer.root, graph, name) +@contextlib.contextmanager +def _pop_proxy_mode_temporarily(dk): + # This is a shim around the existng per-dispatch-key-mode logic. + # I'll delete the per-dispatch-key-mode logic in a followup PR + if dk is not None: + # During pre_dispatch, pop off of the PreDispatch mode stack + old = _pop_mode(dk) + try: + yield old + finally: + _push_mode(old, dk) + else: + # During normal tracing, pop off of the dedicated proxy mode stack + old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + try: + yield old + finally: + torch._C._set_dispatch_mode(old) + def wrap_key(f, tensors, tracer, pre_dispatch: bool): flat_tensors, tensors_spec = pytree.tree_flatten(tensors) dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None @@ -478,7 +492,7 @@ def wrap_key(f, tensors, tracer, pre_dispatch: bool): def wrapped(*proxies): flat_proxies, proxies_spec = pytree.tree_flatten(proxies) assert len(flat_proxies) == len(flat_tensors) - with _pop_mode_temporarily(dk) as m: + with _pop_proxy_mode_temporarily(dk) as m: assert isinstance(m, ProxyTorchDispatchMode) track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) @@ -548,6 +562,13 @@ def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constan self.sym_mode = ProxySymDispatchMode(tracer) self.trace_state = {} self._managers = [] + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.PROXY + # Every time we enter a mode, we maintain a stack telling us what the previous + # ProxyTorchDispatchMode state was (if there was any). + # This lets us properly reset the state on exit. + self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = [] @count def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -559,17 +580,27 @@ def __enter__(self): m = self.sym_mode.enable(True) self._managers.append(m) m.__enter__() + # Stash and store the previous proxy mode (there may or may not be one) + maybe_prev_proxy_mode = torch._C._unset_dispatch_mode(self._mode_key) + self.enter_stack.append(maybe_prev_proxy_mode) return super().__enter__() def __exit__(self, exc_type, exc_value, traceback): m = self._managers.pop() # ...exit us first, then sym mode b = super().__exit__(exc_type, exc_value, traceback) + + # Re-enable the previous proxy mode, if there was one. + mb_previous_proxy_mode = self.enter_stack.pop() + if mb_previous_proxy_mode is not None: + torch._C._set_dispatch_mode(mb_previous_proxy_mode) + if not b: return m.__exit__(exc_type, exc_value, traceback) else: return m.__exit__(None, None, None) + def inner_torch_dispatch(self, func, types, args=(), kwargs=None): if not self.enable_tracing: return func(*args, **kwargs) @@ -805,7 +836,7 @@ def wrap_fake(x): # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is # thus irrelevant to any external functional trace. with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, proxy_function_mode, \ - sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True): + sym_mode, proxy_mode, disable_autocast_cache(): t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs)) # TODO: kind of a bad way to do it, should maybe figure out a better way @@ -821,28 +852,24 @@ def get_torch_dispatch_modes(): def get_innermost_proxy_mode(): - for m in reversed(torch.utils._python_dispatch._get_current_dispatch_mode_stack()): - if isinstance(m, ProxyTorchDispatchMode): - return m - return None + return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) @contextlib.contextmanager def disable_proxy_modes_tracing(enable_current=False): - modes = get_torch_dispatch_modes() - proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)] - if enable_current: - proxy_tensor_modes = proxy_tensor_modes[:-1] - olds = [(m.enable_tracing, m.sym_mode.enable_tracing) for m in proxy_tensor_modes] - for proxy_mode in proxy_tensor_modes: - proxy_mode.enable_tracing = False - proxy_mode.sym_mode.enable_tracing = False + # enable_current=True is now a no-op, since only one proxy mode + # can live on the stack at a time. + # We should kill this API in a future PR. + maybe_old = None + if not enable_current: + # Only one proxy_mode can be "active" at a time. + # So we simply remove our active mode. + maybe_old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) try: yield finally: - for proxy_mode, (old, old_sym) in zip(proxy_tensor_modes, olds): - proxy_mode.enable_tracing = old - proxy_mode.sym_mode.enable_tracing = old_sym + if maybe_old is not None: + torch._C._set_dispatch_mode(maybe_old) def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"): diff --git a/torch/testing/_internal/two_tensor.py b/torch/testing/_internal/two_tensor.py new file mode 100644 index 00000000000000..6add6b33d6f86c --- /dev/null +++ b/torch/testing/_internal/two_tensor.py @@ -0,0 +1,57 @@ +import torch + + +# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors. +class TwoTensor(torch.Tensor): + @staticmethod + def __new__(cls, a, b): + assert ( + a.device == b.device + and a.layout == b.layout + and a.requires_grad == b.requires_grad + and a.dtype == b.dtype + ) + # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape + shape = a.shape + kwargs = {} + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, a, b): + self.a = a + self.b = b + + def __repr__(self): + a_repr = repr(self.a) + b_repr = repr(self.b) + return f"TwoTensor({a_repr}, {b_repr})" + + def __tensor_flatten__(self): + return ["a", "b"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta): + assert meta is None + a, b = inner_tensors["a"], inner_tensors["b"] + return TwoTensor(a, b) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + assert any(isinstance(x, TwoTensor) for x in args) + assert any(isinstance(x, TwoTensor) for x in args) + args_a = [x.a if isinstance(x, TwoTensor) else x for x in args] + args_b = [x.b if isinstance(x, TwoTensor) else x for x in args] + out_a = func(*args_a, **kwargs) + out_b = func(*args_b, **kwargs) + assert type(out_a) == type(out_b) + if isinstance(out_a, torch.Tensor): + return TwoTensor(out_a, out_b) + # for aten ops that return non-tensors, just assume that + # our two inner tensors return the same value + assert out_a == out_b + return out_a diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 029c614aa2fc28..8052e07b4f03b8 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,5 +1,5 @@ import contextlib -from typing import Optional +from typing import Optional, Union import warnings import torch @@ -56,7 +56,12 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - _pop_mode(self.__dict__.get("_dispatch_key", None)) + mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None) + if mb_dk_or_mode_key is None: + # Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch) + # We should probably revisit this. + mb_dk_or_mode_key = self.__dict__.get("_mode_key", None) + _pop_mode(mb_dk_or_mode_key) @classmethod def push(cls, *args, **kwargs): @@ -66,7 +71,10 @@ def push(cls, *args, **kwargs): def _get_current_dispatch_mode(): stack_len = _len_torch_dispatch_stack() - return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None + # Return a user mode on the stack if there are any + if stack_len > 0: + return _get_dispatch_stack_at(stack_len - 1) + return None def _get_current_dispatch_mode_stack(): @@ -87,12 +95,14 @@ def _push_mode(mode, k: Optional[DispatchKey] = None): _push_on_torch_dispatch_stack(mode) -def _pop_mode(k: Optional[DispatchKey] = None): - if k is not None: - from torch._ops import pop_mode_for_key - return pop_mode_for_key(k) - else: - return _pop_torch_dispatch_stack() +def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None): + if k is None or isinstance(k, torch._C._TorchDispatchModeKey): + return _pop_torch_dispatch_stack(k) + from torch._ops import pop_mode_for_key + # per-dispatch-key-mode-stack do not currently handle "always running infra modes last". + # In practice this doesn't matter, since ProxyTorchDispatchMode is the only mode + # that we push onto these per-dispatch-key-mode-stacks. + return pop_mode_for_key(k) @contextlib.contextmanager @@ -108,6 +118,8 @@ def _pop_mode_temporarily(k: Optional[DispatchKey] = None): def _disable_current_modes(): mode_len = _len_torch_dispatch_stack() old_modes = [_pop_mode() for _ in range(mode_len)] + + # Manually disable proxy and fake modes, if any are active try: yield old_modes finally: @@ -121,7 +133,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = {} return func(*args, **kwargs) - def is_traceable_wrapper_subclass(t): """ Returns whether or not a tensor subclass that implements __torch_dispatch__ diff --git a/torchgen/model.py b/torchgen/model.py index 9f47ad4051a01b..543806d14e4b3c 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -205,6 +205,11 @@ def parse(value: str) -> "DispatchKey": raise AssertionError(f"unknown dispatch key {value}") +class _TorchDispatchModeKey(Enum): + FAKE = auto() + PROXY = auto() + + def codegen_per_backend_entries() -> str: r = [] for fk in FUNCTIONALITY_KEYS: