better support for fakeifying and dynamoing through torch_dispatch subclasses (with dynamic shapes) (pytorch#107415)

There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:
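
For context, here is a minimal sketch of the kind of traceable wrapper subclass this PR targets, using the `__tensor_flatten__`/`__tensor_unflatten__` protocol that the diffs below rely on (a list of inner-tensor attribute names plus optional context on the way out, and a dict of inner tensors on the way back). The class name and its single `elem` field are illustrative, not part of this PR:

```python
import torch
import torch.utils._pytree as pytree


class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # The wrapper owns no real storage; its metadata mirrors the inner tensor.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.shape,
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )

    def __init__(self, elem):
        self.elem = elem

    def __tensor_flatten__(self):
        # Names of the inner-tensor attributes, plus optional constant context.
        return ["elem"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta):
        return WrapperTensor(inner_tensors["elem"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda t: t.elem if isinstance(t, WrapperTensor) else t
        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        # Re-wrap tensor outputs so the subclass propagates through ops.
        return WrapperTensor(out) if isinstance(out, torch.Tensor) else out


w = WrapperTensor(torch.ones(4))
assert isinstance(torch.mul(w, w), WrapperTensor)
```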

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. There is no dedicated test for this in this PR, but it is exercised pretty heavily later in my AOT Autograd tests.
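
A hedged illustration of point (1), reusing the `WrapperTensor` sketch above (this is not the PR's actual coverage, which lives in the later AOT Autograd tests): after this change, fakeifying a wrapper subclass is expected to carry the original wrapper's autograd metadata over to the fake wrapper.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Expected post-PR behaviour (sketch only): the fakeified wrapper preserves
# requires_grad and is_leaf from the real wrapper subclass.
w = WrapperTensor(torch.randn(4, requires_grad=True))
fake_w = FakeTensorMode().from_tensor(w)

assert fake_w.requires_grad == w.requires_grad
assert fake_w.is_leaf == w.is_leaf
```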

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass that represents a source coming from a tensor field on a wrapper subclass; it is used in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.
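
Guards on an inner tensor have to be spelled relative to the wrapper's source, which is why the new test below expects a source string like `L['x'].inner_elem.size()[0]`; the meta_utils.py change achieves this by composing the wrapper's source with the field name via `AttrSource(source, attr)`. Here is a rough sketch of what such a field source looks like, modeled on dynamo's existing sources (the class body is illustrative, not the PR's exact implementation):

```python
from dataclasses import dataclass

from torch._dynamo.source import LocalSource
from torch._guards import Source


@dataclass(frozen=True)
class WrapperSubclassFieldSource(Source):
    base: Source  # source of the outer wrapper tensor, e.g. L['x']
    attr: str     # name of the inner-tensor field, e.g. "inner_elem"

    def name(self) -> str:
        # Guards on the inner tensor render relative to the wrapper,
        # e.g. "L['x'].inner_elem.size()[0]".
        return f"{self.base.name()}.{self.attr}"


src = WrapperSubclassFieldSource(LocalSource("x"), "inner_elem")
assert src.name() == "L['x'].inner_elem"
```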

(3) `_make_wrapper_subclass()`: I marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass` is that it has two overloads: the first explicitly does not support dynamic shapes, and the second does not support kwargs (the new test below calls the second, positional overload for exactly this reason). We probably want to consolidate the two later, or at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR, so these smaller changes seemed like a strict improvement.

Pull Request resolved: pytorch#107415
Approved by: https://github.com/ezyang
bdhirsh authored and pytorchmergebot committed Aug 29, 2023
1 parent 378ffde commit 5efd63b
Showing 9 changed files with 252 additions and 54 deletions.
113 changes: 113 additions & 0 deletions test/dynamo/test_subclasses.py
@@ -323,6 +323,119 @@ def forward(self, l_x_):
# frame count and op count are incremented due to re-compilation
check_count_and_graph(3, 3, 3, expected_graph)

def test_wrapper_subclass_guards_on_inner_tensor(self):
# Holds an inner tensor that has a distinct shape from the outer wrapper tensor.
# Also adds additional guards on the inner tensor's sizes:
# when the inner tensor of the first input to an op has shape[0] > 3, we add 2 to the output,
# which inserts an extra add node into the graph.
class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
@staticmethod
def __new__(cls, inner):
# Double the outer-most dimension
outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
return torch.Tensor._make_wrapper_subclass(
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
# Calling the overload that has kwargs causes us to go down the first overload path,
# which will **always** specialize sizes.
# We should probably eventually fix this so that the first overload can just handle dynamic shapes.
cls,
outer_shape,
inner.stride(),
None,
None,
inner.dtype,
inner.layout,
inner.device,
False,
inner.requires_grad,
)

def __init__(self, inner):
self.inner_elem = inner

def __tensor_flatten__(self):
return ["inner_elem"], None

@staticmethod
def __tensor_unflatten__(inner_tensors, _):
return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])

def __repr__(self):
return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})"

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}

args_inner = torch.utils._pytree.tree_map_only(
DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
)
out_inner = func(*args_inner, **kwargs)

# Add guards on the inner tensor's sizes
if args_inner[0].shape[0] > 3:
out_inner += 2

return DoubleSizeMaybeAddGeThreeTensor(out_inner)

lower_bound_str = None
upper_bound_str = None
curr_var_to_val = None
curr_var_to_sources = None

def backend(gm, args):
print(gm.code)
context = torch._guards.TracingContext.get()
val_to_guards = list(context.fake_mode.shape_env.var_to_guards.values())

# Grab info on sources and guards from the shapenv
nonlocal lower_bound_str
nonlocal upper_bound_str
nonlocal curr_var_to_val
nonlocal curr_var_to_sources

lower_bound_str = str(val_to_guards[0][0].expr)
upper_bound_str = str(val_to_guards[0][1].expr)
curr_var_to_val = {
str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
}
curr_var_to_sources = {
str(k): v[0].name()
for k, v in context.fake_mode.shape_env.var_to_sources.items()
}
return gm

@torch.compile(backend=backend)
def fn(x):
if x.shape[0] < 10:
return torch.mul(x, x)
else:
return torch.div(x, x)

inp = torch.ones(4, 4)

x = DoubleSizeMaybeAddGeThreeTensor(inp)
torch._dynamo.mark_dynamic(x, 0)
res = fn(x)
# During fakeifying, we end up allocating a separate symint
# for the outer and inner tensor (in this test, s0 is unused).
expected_var_to_val = {
"s0": 8,
"s1": 4,
}
expected_var_to_sources = {
"s0": "L['x'].size()[0]",
"s1": "L['x'].inner_elem.size()[0]",
}
# lower bound comes from code underneath torch_dispatch (operating on the inner tensor size)
expected_lower_bound = "s1 > 3"
# upper bound comes from user code (operating on the wrapper size)
expected_upper_bound = "2*s1 < 10"
self.assertEqual(curr_var_to_val, expected_var_to_val)
self.assertEqual(curr_var_to_sources, expected_var_to_sources)
self.assertEqual(lower_bound_str, expected_lower_bound)
self.assertEqual(upper_bound_str, expected_upper_bound)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
2 changes: 1 addition & 1 deletion torch/_dynamo/utils.py
@@ -557,7 +557,7 @@ def clone_tensor(x):
def clone_input(x, *, dtype=None):
"""copy while preserving strides"""
# TODO: this is questionable
if isinstance(x, torch._subclasses.FakeTensor):
if is_fake(x):
# this func fails on fake tensors in __torch_dispatch__
return x

17 changes: 15 additions & 2 deletions torch/_subclasses/fake_tensor.py
@@ -171,8 +171,9 @@ def _is_tensor_constructor(func: OpOverload):
def is_fake(x):
if isinstance(x, FakeTensor):
return True
elif is_traceable_wrapper_subclass(x):
flattened_tensors, _ = type(x).__tensor_flatten__(x)
if is_traceable_wrapper_subclass(x):
attrs, _ = type(x).__tensor_flatten__(x)
flattened_tensors = [getattr(x, attr) for attr in attrs]
# need to recurse because we could have nested subclasses
all_fake = all(is_fake(x) for x in flattened_tensors)
any_fake = any(is_fake(x) for x in flattened_tensors)
@@ -185,6 +186,18 @@ def is_fake(x):
return False


def maybe_get_fake_mode(t):
if isinstance(t, FakeTensor):
return t.fake_mode
if is_traceable_wrapper_subclass(t):
inner_tensors, _ = t.__tensor_flatten__()
modes = [maybe_get_fake_mode(x) for x in inner_tensors]
m = modes[0]
assert all(m is x for x in modes)
return m
return None


@functools.lru_cache(None)
def get_schema_info(func):
return torch._C._SchemaInfo(func._schema) # type: ignore[attr-defined]
89 changes: 59 additions & 30 deletions torch/_subclasses/meta_utils.py
@@ -229,11 +229,15 @@ def meta_tensor(
if shape_env is not None:
maybe_suppress = shape_env.suppress_guards

def sym_sizes_strides_storage_offset(t):
def sym_sizes_strides_storage_offset(t, src):
if shape_env is not None:
return shape_env.create_symbolic_sizes_strides_storage_offset(
t,
source,
src,
# Assume that the set of dims that are dynamic are the same between
# the wrapper tensor and any inner tensors.
# We can revisit this if this assumption does not hold
# for any important subclasses later.
dynamic_dims=dynamic_dims,
constraint_dims=constraint_dims,
)
@@ -278,7 +282,7 @@ def sym_sizes_strides_storage_offset(t):
elif t.is_mkldnn:
is_leaf = safe_is_leaf(t)
sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
t
t, source
)
r = callback(
lambda: torch.empty_strided(
@@ -365,7 +369,7 @@ def is_c_of_r(complex_dtype, real_dtype):
sizes,
strides,
storage_offset,
) = sym_sizes_strides_storage_offset(t)
) = sym_sizes_strides_storage_offset(t, source)

if safe_is_leaf(t):
# Leaf views that track view metadata are created by
@@ -402,12 +406,51 @@ def is_c_of_r(complex_dtype, real_dtype):

else:
is_leaf = safe_is_leaf(t)
sizes, strides, storage_offset = sym_sizes_strides_storage_offset(t)
r = callback(
lambda: torch.empty_strided(
sizes, strides, dtype=t.dtype, device="meta"
)
sizes, strides, storage_offset = sym_sizes_strides_storage_offset(
t, source
)

def empty_create(inner_t, inner_src):
(
inner_sizes,
inner_strides,
inner_storage_offset,
) = sym_sizes_strides_storage_offset(inner_t, inner_src)
return torch.empty_strided(
inner_sizes,
inner_strides,
dtype=inner_t.dtype,
device="meta",
)

# If we have a subclass that desugars into dense tensors,
# perform our callback on each inner tensor.
if is_traceable_wrapper_subclass(t):
# Note: transform_subclass will use __tensor_unflatten__ to generate
# a fresh subclass wrapper, which is why sizes/strides are not passed in
# to the creation function here.
# We assume that if the inner tensors of the subclass are given symbolic sizes,
# their sizes will be used to construct the (symbolic) sizes of the wrapper tensor.
from torch._dynamo.source import AttrSource

r = transform_subclass(
t,
lambda attr, inner_t: callback(
lambda: empty_create(
inner_t,
AttrSource(source, attr),
)
),
)
else:
r = callback(
lambda: torch.empty_strided(
sizes,
strides,
dtype=t.dtype,
device="meta",
)
)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = t.requires_grad
@@ -463,12 +506,13 @@ def is_c_of_r(complex_dtype, real_dtype):
# test/dynamo/test_dynamic_shapes.py
maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
from torch._subclasses.fake_tensor import (
FakeTensor,
in_kernel_invocation_manager,
maybe_get_fake_mode,
)

if isinstance(r, FakeTensor):
maybe_fake_mgr = in_kernel_invocation_manager(r.fake_mode)
mb_fake_mode = maybe_get_fake_mode(r)
if mb_fake_mode is not None:
maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
with maybe_fake_mgr, torch.no_grad():
r.set_(r_s, storage_offset, sizes, strides)

@@ -510,6 +554,7 @@ def __call__(
type(t) is torch.Tensor
or type(t) is torch.nn.Parameter
or (ignore_subclass and isinstance(t, torch.Tensor))
or is_traceable_wrapper_subclass(t)
or isinstance(t, FakeTensor)
):
if t.device.type != "xla" and any(
@@ -602,24 +647,8 @@ def __call__(
r._is_param = True
return r
elif torch.overrides.is_tensor_like(t):
if is_traceable_wrapper_subclass(t):
# convert traceable wrapper subclasses to meta by converting
# the underlying tensor to meta
out = transform_subclass(
t,
lambda t: self.meta_tensor(
t, shape_env=shape_env, callback=callback, source=source
),
)
return out
else:
# Blindly converting tensor subclasses to meta can cause
# unpredictable problems; e.g., FX tests will trace meta
# tensors into their trace / some subclasses don't correctly
# support meta. Trying to YOLO this is more trouble than it's
# worth.
self.miss += 1
return NotImplemented
self.miss += 1
return NotImplemented
else:
# non-Tensor types don't count as hit or miss
return t
10 changes: 8 additions & 2 deletions torch/csrc/autograd/python_variable.cpp
@@ -697,8 +697,14 @@ static PyObject* THPVariable_make_wrapper_subclass(
AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
tracer::impl::NoTracerDispatchMode tracer_guard{};

// We shouldn't need storage
Storage storage{Storage::use_byte_size_t{}, 0, at::DataPtr{}};
// We use storages **only** to track aliasing of subclasses during tracing.
// The actual data pointers are not valid.
Storage storage{
Storage::use_byte_size_t{},
0,
at::DataPtr{},
/*allocator=*/c10::GetAllocator(c10::kMeta),
/*resizeable=*/true};

tensor = at::detail::make_tensor<TensorImpl>(
std::move(storage), options.computeDispatchKey(), options.dtype());
9 changes: 9 additions & 0 deletions torch/distributed/_functional_collectives.py
@@ -329,6 +329,15 @@ def __new__(cls, elem: torch.Tensor):
r.elem = elem
return r

def __tensor_flatten__(self):
return ["elem"], None

@staticmethod
def __tensor_unflatten__(inner_tensors, meta):
assert meta is None
elem = inner_tensors["elem"]
return AsyncCollectiveTensor(elem)

def __repr__(self):
wait_tensor(self.elem)
return f"AsyncCollectiveTensor({self.elem})"
5 changes: 3 additions & 2 deletions torch/distributed/_tensor/api.py
@@ -216,13 +216,14 @@ def __tensor_flatten__(self):
protocol to inform how to flatten a DTensor to local tensor
for PT2 tracing
"""
return self._local_tensor, self._spec
return ["_local_tensor"], self._spec

@staticmethod
def __tensor_unflatten__(local_tensor, spec):
def __tensor_unflatten__(inner_tensors, spec):
assert (
spec is not None
), "Expecting spec to be not None from `__tensor_flatten__` return value!"
local_tensor = inner_tensors["_local_tensor"]
return DTensor(
local_tensor,
spec.mesh,
