reorder proxy / fake modes so they always run last (pytorch#104482)
**Update:** I refactored the original PR. See the original description below; here I'll describe the updates:

(1) TLS changes in `TorchDispatchModeTLS.h/cpp`.

I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack).

`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes.

`TorchDispatchModeTLS::any_modes_set()` (which `dispatch_mode_enabled()` now consults) accepts an optional `skip_infra_modes` param, so you can separately query whether there are "any modes at all" or only "any user modes".

`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).
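
To make the new TLS semantics concrete, here's a minimal, self-contained Python sketch that mirrors the behavior described above (names are illustrative; the real implementation is the C++ in `TorchDispatchModeTLS.cpp`, shown in the diff below):

```python
from enum import Enum
from typing import List, Optional, Tuple

class ModeKey(Enum):  # mirrors TorchDispatchModeKey
    FAKE = 0
    PROXY = 1

class ModeTLSSketch:
    def __init__(self) -> None:
        self.stack: List[object] = []                                # user modes
        self.infra: List[Optional[object]] = [None] * len(ModeKey)   # indexed by ModeKey

    def push(self, mode: object, mode_key: Optional[ModeKey] = None) -> None:
        if mode_key is None:
            self.stack.append(mode)            # plain user mode: normal stack
        else:
            assert self.infra[mode_key.value] is None, "infra mode already set"
            self.infra[mode_key.value] = mode  # infra mode: dedicated slot

    def pop(self) -> Tuple[object, Optional[ModeKey]]:
        # User modes are popped before any infra mode...
        if self.stack:
            return self.stack.pop(), None
        # ...then the highest-priority infra mode, returning its key so the
        # caller knows which slot to use if the mode gets pushed back later.
        for key in reversed(list(ModeKey)):
            if self.infra[key.value] is not None:
                mode, self.infra[key.value] = self.infra[key.value], None
                return mode, key
        raise RuntimeError("trying to pop from empty mode stack")

    def any_modes_set(self, skip_infra_modes: bool = False) -> bool:
        if self.stack:
            return True
        return not skip_infra_modes and any(m is not None for m in self.infra)

    def get_mode(self, mode_key: ModeKey) -> Optional[object]:
        return self.infra[mode_key.value]

    def unset_mode(self, mode_key: ModeKey) -> Optional[object]:
        out, self.infra[mode_key.value] = self.infra[mode_key.value], None
        return out
```

The key property is that infra modes never live on the user stack, so user modes always sit "above" them.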

There were also some mild codegen changes to support the new enum.

(2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py`

The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C._TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter__/__exit__()` (in `_python_dispatch.py`) now check whether the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instruct `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`.

I also had to change both modes' `__enter__`/`__exit__` to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack and save it in `enter_stack`, so that on exit we can restore the state properly.
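
Roughly, the new enter/exit handling looks like this (a simplified sketch building on the `ModeTLSSketch` above, not the actual `_python_dispatch.py` code):

```python
class InfraModeSketch:
    """Stand-in for FakeTensorMode / ProxyTorchDispatchMode."""

    def __init__(self, tls: "ModeTLSSketch", mode_key: "ModeKey") -> None:
        self.tls = tls
        self.mode_key = mode_key
        self.enter_stack: List[Optional[object]] = []

    def __enter__(self) -> "InfraModeSketch":
        # Only one fake/proxy mode may be active at a time, so stash whatever
        # ambient mode of this kind is currently installed...
        prior = self.tls.unset_mode(self.mode_key)
        self.enter_stack.append(prior)
        # ...and install ourselves in its slot.
        self.tls.push(self, mode_key=self.mode_key)
        return self

    def __exit__(self, *exc) -> None:
        popped = self.tls.unset_mode(self.mode_key)
        assert popped is self
        # Restore whatever mode was ambient before we entered.
        prior = self.enter_stack.pop()
        if prior is not None:
            self.tls.push(prior, mode_key=self.mode_key)
```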

(3) Dispatching logic in `python_arg_parser.cpp`

This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode()  # try user modes first (where the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses)
```

Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this already works today: if the proxy/fake modes run in step (a) and see a user subclass, they return `NotImplemented`, allowing us to redispatch to the user subclass.

How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (`FakeTensor`, and later `FunctionalTensor`) are required to have a `_mode_key` hidden on the subclass - so we filter on whether or not the arguments carry a `_mode_key`.
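
A rough Python sketch of the overall ordering (illustrative only - the real logic is the C++ in `python_arg_parser.cpp`; `run_torch_dispatch` here is a placeholder standing in for invoking a mode's or subclass's `__torch_dispatch__`):

```python
def run_torch_dispatch(handler, func, args):
    # Placeholder: the real code calls __torch_dispatch__ on the mode or
    # subclass type, which may return NotImplemented.
    return handler(func, args)

def dispatch_on_mode(func, args, logical_mode_stack):
    # The logical stack already keeps infra modes at the bottom, so user
    # modes are tried first and fake/proxy run last.
    for mode in reversed(logical_mode_stack):
        out = run_torch_dispatch(mode, func, args)
        if out is not NotImplemented:
            return out
    return NotImplemented

def dispatch_on_subclass(func, args, *, infra: bool):
    # Infra subclasses (FakeTensor, later FunctionalTensor) are identified by
    # the _mode_key attribute hidden on the class.
    for arg in args:
        if hasattr(type(arg), "_mode_key") != infra:
            continue
        out = run_torch_dispatch(type(arg), func, args)
        if out is not NotImplemented:
            return out
    return NotImplemented

def dispatch(func, args, logical_mode_stack):
    out = dispatch_on_mode(func, args, logical_mode_stack)        # (a)
    if out is NotImplemented:
        out = dispatch_on_subclass(func, args, infra=False)       # (b) user subclasses
    if out is NotImplemented:
        out = dispatch_on_subclass(func, args, infra=True)        # (c) infra subclasses
    return out
```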

(4) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that `DoubleTensor` would be easily confused with `torch.FloatTensor` and friends).

----- original description below -----

The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit

Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that.

**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**

This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.

I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set.

This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.

**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**

This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs.

**(3) New Python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` APIs**

Before, if you wanted to see if proxy or fake modes were active in Python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new APIs to directly check if either proxy or fake modes are active.

**(4) Allow traceable tensor subclasses to access storages from Python**

This is convenient later in the stack, where AOTAutograd needs to detect aliasing between inputs and outputs that might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to Python if the tensor subclass that you are passing in is "traceable".
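
For example, something like the following should now work for a traceable subclass (a sketch, using the `TwoTensor` test subclass added later in this stack):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = TwoTensor(torch.zeros(4), torch.zeros(4))
# Previously this raised for any tensor subclass; now traceable subclasses
# are allowed to hand a storage back to Python.
storage = t.untyped_storage()
```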

**(5) Fixed subclass fakeification**

@wanchaol recently added support for fakeifying tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.

**(6) Make tensor subclasses resizeable**

Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.

**(7) Added a basic DoubleTensor subclass for testing**

I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR.
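
For reference, a minimal usage sketch of the subclass (named `TwoTensor` in the final version, per the update above):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = TwoTensor(torch.ones(3), torch.zeros(3))
out = t + 1  # __torch_dispatch__ runs the op on both inner tensors and rewraps
assert isinstance(out, TwoTensor)
```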

Pull Request resolved: pytorch#104482
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#107415
bdhirsh authored and pytorchmergebot committed Aug 29, 2023
1 parent 5efd63b commit da54f3c
Showing 15 changed files with 703 additions and 178 deletions.
151 changes: 135 additions & 16 deletions c10/core/impl/TorchDispatchModeTLS.cpp
@@ -2,6 +2,7 @@
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/irange.h>

#include <utility>

@@ -10,8 +11,23 @@ namespace impl {

thread_local TorchDispatchModeTLS torchDispatchModeState;

void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
if (torchDispatchModeState.stack_.empty()) {
bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
if (!torchDispatchModeState.stack_.empty())
return true;
if (!skip_infra_modes) {
for (const auto i : c10::irange(
static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
return true;
}
}
}
return false;
}

void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
std::shared_ptr<SafePyObject> mode) {
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
@@ -20,30 +36,122 @@ void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
}

const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
TORCH_CHECK(
!torchDispatchModeState.stack_.empty(),
"trying to pop from empty mode stack");
std::shared_ptr<SafePyObject> out = torchDispatchModeState.stack_.back();
torchDispatchModeState.stack_.pop_back();

if (torchDispatchModeState.stack_.empty()) {
std::shared_ptr<SafePyObject> out;
if (!torchDispatchModeState.stack_.empty()) {
out = torchDispatchModeState.stack_.back();
torchDispatchModeState.stack_.pop_back();
} else {
for (int64_t i =
static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
i >= 0;
--i) {
if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
out = std::move(torchDispatchModeState.infra_modes_[i].value());
torchDispatchModeState.infra_modes_[i] = c10::nullopt;
break;
}
}
}
TORCH_CHECK(out, "trying to pop from empty mode stack");
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
return out;
}
const std::tuple<std::shared_ptr<SafePyObject>, TorchDispatchModeKey>
TorchDispatchModeTLS::pop_highest_infra_mode() {
for (int64_t i = static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
i >= 0;
--i) {
if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
auto out_mode = torchDispatchModeState.infra_modes_[i].value();
torchDispatchModeState.infra_modes_[i] = c10::nullopt;
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
return std::make_tuple(
std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
}
}
TORCH_CHECK(
false, "Called pop_highest_infra_mode, but no infra modes were active.")
}

const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
int64_t idx) {
TORCH_CHECK(
idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
"Tried to get stack at idx that's too big");
return torchDispatchModeState.stack_[idx];
TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
// Our "logical" stack includes both:
// - any user modes (the entire torchDispatchModeState.stack_)
// - any infra modes (members of torchDispatchModeState.infra_modes_ that are
// not None)

// idx == 0 means the "bottom" of the stack, which starts with any infra
// modes (iterating from lowest-priority to highest-priority).
auto curr_idx = idx;
for (const auto i :
c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
if (curr_idx == 0) {
return torchDispatchModeState.infra_modes_[i].value();
}
curr_idx -= 1;
}
}
// At this point, we're guaranteed that curr_idx < stack_.size()
return torchDispatchModeState.stack_[curr_idx];
}

int64_t TorchDispatchModeTLS::stack_len() {
return static_cast<int64_t>(torchDispatchModeState.stack_.size());
auto stack_len = static_cast<int64_t>(torchDispatchModeState.stack_.size());
int64_t infra_modes_len = 0;
for (const auto i :
c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) {
infra_modes_len += 1;
}
}
return stack_len + infra_modes_len;
}

const c10::optional<std::shared_ptr<SafePyObject>> TorchDispatchModeTLS::
get_mode(TorchDispatchModeKey mode_key) {
return torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
}

void TorchDispatchModeTLS::set_mode(
const std::shared_ptr<SafePyObject>& mode,
TorchDispatchModeKey mode_key) {
TORCH_CHECK(
torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] ==
c10::nullopt,
"trying to set the current ",
to_string(mode_key),
", but one already exists");

if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}

torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] = mode;
}

const c10::optional<std::shared_ptr<SafePyObject>> TorchDispatchModeTLS::
unset_mode(TorchDispatchModeKey mode_key) {
auto out = torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] =
c10::nullopt;
if (out.has_value() && !any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
return out;
}

const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
@@ -52,7 +160,7 @@ const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {

void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
torchDispatchModeState = std::move(state);
if (torchDispatchModeState.stack_.empty()) {
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
@@ -67,7 +175,18 @@ void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {

bool dispatch_mode_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
TorchDispatchModeTLS::stack_len() > 0;
TorchDispatchModeTLS::any_modes_set();
}

std::string to_string(TorchDispatchModeKey mode_key) {
switch (mode_key) {
case TorchDispatchModeKey::PROXY:
return "ProxyTorchDispatchMode";
case TorchDispatchModeKey::FAKE:
return "FakeTensorMode";
default:
return "UNKNOWN_MODE";
}
}

} // namespace impl
37 changes: 36 additions & 1 deletion c10/core/impl/TorchDispatchModeTLS.h
@@ -6,20 +6,55 @@
namespace c10 {
namespace impl {

enum class TorchDispatchModeKey : int8_t { FAKE, PROXY, NUM_MODE_KEYS };

struct C10_API TorchDispatchModeTLS {
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
// This API is NOT invariant safe.
// It must not take in an infra mode that uses TorchDispatchModeKey
// If you're pushing an infra mode onto the stack, we expect
// you to use set_mode
static void push_non_infra_mode_onto_stack(
std::shared_ptr<SafePyObject> mode);
// Pops the top mode of the stack,
// giving precedence to user modes before attempting to pop
// any infra modes
static const std::shared_ptr<SafePyObject> pop_stack();
// Returns the highest-priority infra mode on the stack,
// along with its mode key.
static const std::tuple<std::shared_ptr<SafePyObject>, TorchDispatchModeKey>
pop_highest_infra_mode();

static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
static int64_t stack_len();

static const c10::optional<std::shared_ptr<SafePyObject>> get_mode(
TorchDispatchModeKey mode_key);
static const c10::optional<std::shared_ptr<SafePyObject>> unset_mode(
TorchDispatchModeKey mode_key);
static void set_mode(
const std::shared_ptr<SafePyObject>& mode,
TorchDispatchModeKey mode_key);

static const TorchDispatchModeTLS& get_state();
static void set_state(TorchDispatchModeTLS state);

static bool any_modes_set(bool skip_infra_modes = false);

private:
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
// Users are allowed to push multiple ProxyTorchDispatchMode objects onto the
// stack
// However, we only allow a single FakeTensorMode onto the stack at a time
// (Pushing additional FakeTensorModes onto the stack is a no-op)
std::array<
c10::optional<std::shared_ptr<c10::SafePyObject>>,
static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS)>
infra_modes_;
};

C10_API bool dispatch_mode_enabled();

C10_API std::string to_string(TorchDispatchModeKey mode_key);

} // namespace impl
} // namespace c10
1 change: 1 addition & 0 deletions test/test_functionalization.py
@@ -1515,6 +1515,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
"test_view_clone_view_inplace",
"test_view_inplace",
])
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well")
class TestCrossRefFunctionalization(TestFunctionalization):
crossref = True

32 changes: 31 additions & 1 deletion test/test_python_dispatch.py
@@ -13,6 +13,7 @@
from torch.utils._mode_utils import no_dispatch, all_same_mode
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
log_input, capture_logs, capture_logs_with_logging_tensor_mode
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._pytree import tree_map, tree_map_only
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack
from torch._custom_op.functional import register_functional_op
@@ -562,7 +563,6 @@ def test_register_fallthrough(self):
# default behavior should have been restored
self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)


class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
with capture_logs() as logs:
Expand Down Expand Up @@ -894,6 +894,36 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

def test_make_fx_with_subclass(self) -> None:
def f(x, y):
# Returns (TwoTensor, Tensor)
return x * y, y + y
x_a = torch.zeros(4)
x_b = torch.zeros(4)
y = torch.ones(4)

# make_fx() is not responsible for unwrapping tensor subclass inputs,
# so we do it manually here.
# Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
# convention as f(*args). Unwrapping tensor subclass inputs can potentially change
# the number of input args to the graph, breaking that assumption
def f_to_trace(x_a, x_b, y):
x = TwoTensor(x_a, x_b)
out1, out2 = f(x, y)
out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)
fx_g = make_fx(f_to_trace, tracing_mode='fake')(x_a, x_b, y)
self.assertExpectedInline(fx_g.code, """\
def forward(self, x_a_1, x_b_1, y_1):
mul = torch.ops.aten.mul.Tensor(x_a_1, y_1); x_a_1 = None
mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1); x_b_1 = None
add = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None
return (mul, mul_1, add)
""")

def test_make_wrapper_subclass_propagates_metadata(self) -> None:
class WrapperTensor(torch.Tensor):
elem: torch.Tensor
6 changes: 5 additions & 1 deletion tools/pyi/gen_pyi.py
@@ -10,7 +10,7 @@
)
from torchgen.gen import parse_native_yaml, parse_tags_yaml

from torchgen.model import DispatchKey, Variant
from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
from torchgen.utils import FileManager

from tools.autograd.gen_python_functions import (
@@ -1227,6 +1227,9 @@ def replace_special_case(hint: str) -> str:
# Dispatch key hints
# ~~~~~~~~~~~~~~~~~~
dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey]
torch_dispatch_mode_key_hints = [
f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey
]

# Tags Enum type hints
# ~~~~~~~~~~~~~~~~~~~~
@@ -1247,6 +1250,7 @@ def replace_special_case(hint: str) -> str:
"legacy_storage_base_hints": legacy_storage_base_hints,
"dtype_class_hints": dtype_class_hints,
"dispatch_key_hints": dispatch_key_hints,
"torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints,
"all_directive": all_directive,
"tag_attributes": tag_attributes,
}
6 changes: 5 additions & 1 deletion torch/_C/__init__.pyi.in
@@ -1181,7 +1181,8 @@ def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...
def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: Any) -> None: ...
def _pop_torch_dispatch_stack() -> Any: ...
def _pop_torch_dispatch_stack(mode_key: Optional[torch._C._TorchDispatchModeKey] = None) -> Any: ...
def _get_dispatch_mode(mode_key: Optional[torch._C._TorchDispatchModeKey]) -> Any: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...

@@ -1432,6 +1433,9 @@ def _are_functorch_transforms_active() -> _bool: ...
# Define in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...

class _TorchDispatchModeKey(Enum):
${torch_dispatch_mode_key_hints}

# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig:
num_calling_threads: _int