From b410378d93bc57b98cc5ea56ceee8f95865b0ba5 Mon Sep 17 00:00:00 2001
From: Runming Lu
Date: Wed, 15 Jan 2025 19:40:42 +0000
Subject: [PATCH] Register nonzero for meta device for FBLSim (#144727)

Summary:
Fix the `nonzero is not registered to meta` issue:

```
NotImplementedError: aten::nonzero: attempted to run this operator with Meta
tensors, but there was no fake impl or Meta kernel registered.
```

Reviewed By: ezyang

Differential Revision: D66525640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144727
Approved by: https://github.com/ezyang
---
 test/test_meta.py                | 11 +++++++++++
 torch/_meta_registrations.py     | 20 ++++++++++++++++++++
 torch/fx/experimental/_config.py |  6 ++++++
 3 files changed, 37 insertions(+)

diff --git a/test/test_meta.py b/test/test_meta.py
index 61fabc513f5cc4..b2f322740b8fcc 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -13,6 +13,7 @@
 import torch.utils._python_dispatch
 from torch._dispatch.python import enable_python_dispatcher
 from torch._ops import OpOverload, OpOverloadPacket
+from torch.fx.experimental import _config as exp_config
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import unMarkDynamoStrictTest
 from torch.testing._internal.common_utils import (
@@ -1794,6 +1795,16 @@ def run(device):
         elif cpu_err is not None and meta_err is None:
             raise RuntimeError("cpu failed, but meta didn't.") from cpu_err
 
+    def test_nonzero(self):
+        t = torch.randn(2, 3, 4, device='meta')
+        with exp_config.patch(meta_nonzero_assume_all_nonzero=True):
+            nz = t.nonzero()
+            self.assertEqual(nz.dtype, torch.int64)
+            self.assertEqual(nz.device.type, 'meta')
+            self.assertEqual(nz.shape, torch.Size([24, 3]))
+            self.assertEqual(nz.stride(), torch.Size([1, 24]))
+
+
 instantiate_device_type_tests(TestMeta, globals())

diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 53d96708fd3d9b..8108c4278802df 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -36,6 +36,7 @@
     out_wrapper,
 )
 from torch._refs import _broadcast_shapes, _maybe_broadcast
+from torch.fx.experimental import _config as exp_config
 from torch.utils import _pytree as pytree
 
 
@@ -3115,6 +3116,25 @@ def nonzero_static(self, *, size: int, fill_value: int = -1):
     return self.new_empty((size, self.dim()), dtype=torch.long)
 
 
+@register_meta([torch.ops.aten.nonzero.default, torch.ops.aten.nonzero.out])
+@out_wrapper()
+def nonzero(self):
+    torch._check_not_implemented(
+        exp_config.meta_nonzero_assume_all_nonzero,
+        lambda: "The register_meta function for torch.nonzero() raises unimplemented by default, "
+        "as a correct data-independent implementation does not exist. This implementation "
+        "returns a fake value, assuming all elements of the tensor are non-zero. "
+        "To enable this registration, please set "
+        "'torch.fx.experimental._config.meta_nonzero_assume_all_nonzero' to True.",
+    )
+    return torch.empty_strided(
+        (self.numel(), self.dim()),
+        (1, self.numel()),
+        dtype=torch.long,
+        device=self.device,
+    )
+
+
 @register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
 def meta_index_Tensor(self, indices):
     torch._check(bool(indices), lambda: "at least one index must be provided")

diff --git a/torch/fx/experimental/_config.py b/torch/fx/experimental/_config.py
index f901919fb45c65..fe92494092113d 100644
--- a/torch/fx/experimental/_config.py
+++ b/torch/fx/experimental/_config.py
@@ -82,6 +82,12 @@
 # This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
 use_duck_shape = True
 
+# Controls the registration of torch.nonzero() on the meta device.
+# When True, nonzero returns a tensor with shape (self.numel(), self.dim()),
+# assuming all elements are non-zero.
+# Default is False to prevent unintended registration. Set to True to enable.
+meta_nonzero_assume_all_nonzero = False
+
 from torch.utils._config_module import install_config_module
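
Usage note: a minimal sketch of how the new flag gates the meta kernel,
assuming this patch is applied. `exp_config.patch` is the context manager
installed by `install_config_module`, the same helper the test above uses.

```
import torch
from torch.fx.experimental import _config as exp_config

t = torch.randn(2, 3, 4, device="meta")

# With the flag at its default (False), the meta kernel refuses to guess a
# data-dependent output shape and raises NotImplementedError.
try:
    t.nonzero()
except NotImplementedError:
    print("meta nonzero is disabled by default")

# With the flag enabled, the kernel assumes every element is non-zero, so
# the result is a (numel(), dim()) == (24, 3) tensor of int64 indices.
with exp_config.patch(meta_nonzero_assume_all_nonzero=True):
    nz = t.nonzero()
    print(nz.shape, nz.stride())  # torch.Size([24, 3]) (1, 24)
```

The (numel(), dim()) shape is the worst case, reached only when every element
really is non-zero, which is why the registration is opt-in rather than
enabled by default.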