Skip to content

Commit

Permalink
[fake_tensor] Don't run fallback for fbgemm ops (pytorch#106210)
Browse files Browse the repository at this point in the history
Summary:
This diff also adds more warning messages around allowing a namespace into the
fallback. We need to grandfather in an operator to actually merge this diff.

Test Plan: - existing tests

Differential Revision: D47873841

Pull Request resolved: pytorch#106210
Approved by: https://github.com/eellison
  • Loading branch information
zou3519 authored and pytorchmergebot committed Jul 28, 2023
1 parent 505dd31 commit f3d165b
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,22 +1420,34 @@ def dispatch(self, func, types, args=(), kwargs=None):
if op_impl_out != NotImplemented:
return op_impl_out

def can_run_unsafe_fallback(func: OpOverload):
    """Return True if ``func`` may be run via the real-kernel fallback.

    The fallback executes the actual device kernel to recover output
    metadata when no meta kernel is registered. It is only trusted for
    operators that live in the pytorch/pytorch tree (or were explicitly
    grandfathered in), because for arbitrary custom ops it can crash or
    produce unreadable errors.

    Note: ``self`` here is the enclosing FakeTensorMode (this is a
    closure inside ``dispatch``), and ``func`` is an OpOverload.
    """
    # Global kill-switch on the mode: no fallback of any kind.
    if not self.allow_fallback_kernels:
        return False
    # It's OK to try the fallback for built-in ops (e.g. aten, prims)
    # because we control and test these but the fallback leads to unexpected behavior
    # in user-defined custom ops
    #
    # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators
    # outside of the pytorch/pytorch library! Any pre-existing things here
    # are either in the pytorch/pytorch library or have been grandfathered in.
    # The fallback does not always work and MAY CRASH and emit unreadable error messages
    # so it should not be allowed by default.
    allowed_namespaces = {
        "debugprims",
        "prims",
        "aten",
        "xla",
        "vision",
        "torchtext",
        "torchaudio",
    }
    # Individual operators allowed by full qualified name even though their
    # namespace is not blanket-allowed (e.g. "fbgemm" as a whole is no
    # longer permitted — only this grandfathered op is).
    grandfathered_ops_FIXME = {
        "fbgemm::gmm",
    }
    return (
        func.namespace in allowed_namespaces
        or func.name() in grandfathered_ops_FIXME
    )

# run kernel registered to meta for func, which include
# python meta registrations, prims, decomps, and c++ meta fns (structured kernels)
Expand All @@ -1444,7 +1456,7 @@ def can_fallback(func: OpOverload):
r = func(*args, **kwargs)
except NotImplementedError as not_implemented_error:
# no meta kernel registered, fallback to kernel for the device
if has_symbolic_sizes or not can_fallback(func):
if has_symbolic_sizes or not can_run_unsafe_fallback(func):
raise UnsupportedOperatorException(func)
return run_fallback_kernel(self, func, args, kwargs, not_implemented_error)

Expand Down

0 comments on commit f3d165b

Please sign in to comment.