Skip to content

Commit

Permalink
Register nonzero for meta device for FBLSim (pytorch#144727)
Browse files Browse the repository at this point in the history
Summary:
Fix `nonzero is not registered to meta` issue:
```
"NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered".
```

Reviewed By: ezyang

Differential Revision: D66525640

Pull Request resolved: pytorch#144727
Approved by: https://github.com/ezyang
  • Loading branch information
lurunming authored and pytorchmergebot committed Jan 15, 2025
1 parent 834086c commit b410378
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch.utils._python_dispatch
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx.experimental import _config as exp_config
from torch.testing import make_tensor
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -1794,6 +1795,16 @@ def run(device):
elif cpu_err is not None and meta_err is None:
raise RuntimeError("cpu failed, but meta didn't.") from cpu_err

def test_nonzero(self):
    """nonzero() on a meta tensor yields an all-nonzero-shaped index tensor.

    With ``meta_nonzero_assume_all_nonzero`` enabled, the meta kernel
    assumes every element is non-zero, so for a (2, 3, 4) input the
    result is an int64 tensor of shape (numel, ndim) = (24, 3) with
    column-major strides (1, 24).
    """
    src = torch.randn(2, 3, 4, device='meta')
    with exp_config.patch(meta_nonzero_assume_all_nonzero=True):
        result = src.nonzero()
    for actual, expected in (
        (result.dtype, torch.int64),
        (result.device.type, 'meta'),
        (result.shape, torch.Size([24, 3])),
        (result.stride(), torch.Size([1, 24])),
    ):
        self.assertEqual(actual, expected)



instantiate_device_type_tests(TestMeta, globals())

Expand Down
20 changes: 20 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
out_wrapper,
)
from torch._refs import _broadcast_shapes, _maybe_broadcast
from torch.fx.experimental import _config as exp_config
from torch.utils import _pytree as pytree


Expand Down Expand Up @@ -3115,6 +3116,25 @@ def nonzero_static(self, *, size: int, fill_value: int = -1):
return self.new_empty((size, self.dim()), dtype=torch.long)


@register_meta([torch.ops.aten.nonzero.default, torch.ops.aten.nonzero.out])
@out_wrapper()
def nonzero(self):
    """Meta kernel for aten::nonzero.

    A correct data-independent shape for nonzero cannot exist (the output
    size depends on tensor values), so this kernel is gated behind
    ``torch.fx.experimental._config.meta_nonzero_assume_all_nonzero``.
    When enabled, it pretends every element is non-zero and returns an
    empty strided long tensor of shape ``(self.numel(), self.dim())``.
    """
    torch._check_not_implemented(
        exp_config.meta_nonzero_assume_all_nonzero,
        lambda: "The register_meta function for torch.nonzero() raises unimplemented by default, "
        "as a correct data-independent implementation does not exist. This implementation "
        "returns a fake value, assuming all elements of the tensor are non-zero. "
        "To enable this registration, please set "
        "'torch.fx.experimental._config.meta_nonzero_assume_all_nonzero' to True.",
    )
    # Column-major layout: strides (1, numel) for the (numel, ndim) result.
    total = self.numel()
    return torch.empty_strided(
        (total, self.dim()),
        (1, total),
        dtype=torch.long,
        device=self.device,
    )


@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
def meta_index_Tensor(self, indices):
torch._check(bool(indices), lambda: "at least one index must be provided")
Expand Down
6 changes: 6 additions & 0 deletions torch/fx/experimental/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@
# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
use_duck_shape = True

# Controls the registration of torch.nonzero() on the meta device.
# When True, nonzero returns a tensor with shape (self.numel(), self.dim()),
# assuming all elements are non-zero.
# Default is False to prevent unintended registration. Set to True to enable.
meta_nonzero_assume_all_nonzero = False

from torch.utils._config_module import install_config_module


Expand Down

0 comments on commit b410378

Please sign in to comment.