From 05073004a6e16583f584a4c576f24a94dd7cd673 Mon Sep 17 00:00:00 2001
From: TYWZ22259
Date: Sat, 20 Jan 2024 12:03:38 +0000
Subject: [PATCH] !9083 AddRmsNorm support onnx

Merge pull request !9083 from TYWZ22259/add_rms_norm_onnx
---
 test/onnx/test_wrapper_onnx_ops.py   | 23 +++++++++++++++++++++++
 test/test_fake_tensor.py             | 19 +++++++++++++++++++
 torch_npu/meta/meta_registrations.py | 13 +++++++++++++
 torch_npu/onnx/wrapper_onnx_ops.py   | 16 ++++++++++++++++
 4 files changed, 71 insertions(+)

diff --git a/test/onnx/test_wrapper_onnx_ops.py b/test/onnx/test_wrapper_onnx_ops.py
index 780db1baf..ab5fa5138 100644
--- a/test/onnx/test_wrapper_onnx_ops.py
+++ b/test/onnx/test_wrapper_onnx_ops.py
@@ -1250,6 +1250,29 @@ def export_onnx(onnx_model_name):
         assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
                                             onnx_model_name)))
 
+    @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `npu_add_rms_norm` is only supported on 910B, skip this ut!")
+    def test_wrapper_npu_add_rms_norm(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+
+            def forward(self, x1, x2, gamma):
+                epsilon = 1e-6
+                x = torch_npu.npu_add_rms_norm(x1, x2, gamma, epsilon)
+                return x
+
+        def export_onnx(onnx_model_name):
+            x1 = torch.rand(10, 1024).uniform_(-3, 3).npu().half()
+            x2 = torch.rand(10, 1024).uniform_(-3, 3).npu().half()
+            gamma = torch.rand(1024).uniform_(-3, 3).npu().half()
+            model = Model().to("npu")
+            model(x1, x2, gamma)
+            self.onnx_export(model, (x1, x2, gamma), onnx_model_name)
+        onnx_model_name = "model_npu_add_rms_norm.onnx"
+        export_onnx(onnx_model_name)
+        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
+                                            onnx_model_name)))
+
     @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `RotaryMul` is only supported on 910B, skip this ut!")
     def test_wrapper_npu_rotary_mul(self):
         class Model(torch.nn.Module):
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 76555a9d5..d60ce633c 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1448,6 +1448,25 @@ def test_npu_rms_norm_backward(self):
         self.assertEqual(dw.device, npu_gamma.device)
 
 
+class TestNpuAddRmsNorm(TestCase):
+    def test_npu_add_rms_norm(self):
+        with FakeTensorMode():
+            npu_x1 = torch.randn((2, 3), dtype=torch.float32).npu()
+            npu_x2 = torch.randn((2, 3), dtype=torch.float32).npu()
+            npu_gamma = torch.randn((3,), dtype=torch.float32).npu()
+
+            result_y, result_rstd, result_x = torch_npu.npu_add_rms_norm(npu_x1, npu_x2, npu_gamma)
+
+            self.assertEqual(result_y.dtype, npu_x1.dtype)
+            self.assertEqual(result_y.shape, npu_x1.shape)
+            self.assertEqual(result_y.device, npu_x1.device)
+            self.assertEqual(result_rstd.shape, torch.Size([2, 1]))
+            self.assertEqual(result_rstd.device, npu_x1.device)
+            self.assertEqual(result_x.dtype, npu_x1.dtype)
+            self.assertEqual(result_x.shape, npu_x1.shape)
+            self.assertEqual(result_x.device, npu_x1.device)
+
+
 class TestFFN(TestCase):
     def test_npu_ffn_meta(self):
         with FakeTensorMode():
diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index 8c775116e..37c40accf 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -107,6 +107,19 @@ def npu_rms_norm_meta(self, gamma, epsilon=1e-6):
     return (torch.empty_like(self, dtype=self.dtype), torch.empty_like(rstd))
 
 
+@impl(m, "npu_add_rms_norm")
+def npu_add_rms_norm_meta(x1, x2, gamma, epsilon=1e-6):
+    rstd_dim = x1.dim() - gamma.dim()
+    ret = []
+    for i in range(x1.dim()):
+        if i < rstd_dim:
+            ret.append(x1.size(i))
+        else:
+            ret.append(1)
+    rstd = torch.empty(ret, dtype=torch.float32, device='meta')
+    return (torch.empty_like(x1, dtype=x1.dtype), torch.empty_like(rstd), torch.empty_like(x1, dtype=x1.dtype))
+
+
 @impl(m, "npu_rms_norm_backward")
 def npu_rms_norm_backward_meta(dy, self, gamma, rstd):
     return (torch.empty_like(self, dtype=self.dtype), torch.empty_like(gamma, dtype=gamma.dtype))
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 508826fa9..5595ed857 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -194,6 +194,17 @@ def symbolic(g, self: Tensor, gamma: Tensor, epsilon: float = 1e-6):
         return g.op("npu::NPURmsNorm", self, gamma, epsilon_f=epsilon, outputs=2)
 
 
+class NPUAddRmsNormOP(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch.ops.npu.npu_add_rms_norm(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, x1: Tensor, x2: Tensor, gamma: Tensor, epsilon: float = 1e-6):
+        return g.op("npu::NPUAddRmsNorm", x1, x2, gamma, epsilon_f=epsilon, outputs=3)
+
+
 class NPUDiouOP(torch.autograd.Function):
 
     @staticmethod
@@ -872,6 +883,10 @@ def wrapper_npu_rms_norm(self, gamma, epsilon=1e-6):
     return NPURmsNormOP.apply(self, gamma, epsilon)
 
 
+def wrapper_npu_add_rms_norm(x1, x2, gamma, epsilon=1e-6):
+    return NPUAddRmsNormOP.apply(x1, x2, gamma, epsilon)
+
+
 def wrapper_npu_nms_v4(self, scores, max_output_size, iou_threshold,
                        scores_threshold, pad_to_max_output_size=False):
     return NPUNmsV4OP.apply(self, scores, max_output_size,
@@ -1039,6 +1054,7 @@ def add_onnx_ops():
     torch_npu.npu_scatter = wrapper_npu_scatter
     torch_npu.npu_lstm = wrapper_npu_lstm
     torch_npu.npu_rms_norm = wrapper_npu_rms_norm
+    torch_npu.npu_add_rms_norm = wrapper_npu_add_rms_norm
     torch_npu.npu_lstm_cell = wrapper_npu_lstm_cell
     torch_npu.npu_gru = wrapper_npu_gru
     torch_npu.npu_dropout_with_add_softmax = wrapper_npu_dropout_with_add_softmax
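
For reviewers, a minimal end-to-end sketch of how the new wrapper would be
exercised outside the test harness, assuming an Ascend 910B environment with
this patch applied. The AddRmsNorm module, the output file name, and the
export options are illustrative, not taken from the patch; the shapes mirror
the unit test, and add_onnx_ops() is the registration hook added above.

import torch
import torch_npu
from torch_npu.onnx.wrapper_onnx_ops import add_onnx_ops


class AddRmsNorm(torch.nn.Module):
    def forward(self, x1, x2, gamma):
        # Returns (y, rstd, x): the normalized output, the reciprocal
        # standard deviation, and the intermediate sum x1 + x2.
        return torch_npu.npu_add_rms_norm(x1, x2, gamma, 1e-6)


add_onnx_ops()  # route npu_add_rms_norm through NPUAddRmsNormOP for export

x1 = torch.rand(10, 1024).npu().half()
x2 = torch.rand(10, 1024).npu().half()
gamma = torch.rand(1024).npu().half()  # matches the normalized last dim
model = AddRmsNorm().to("npu")

y, rstd, x = model(x1, x2, gamma)  # per the meta function, rstd is (10, 1)

# Export goes through NPUAddRmsNormOP.symbolic; ONNX_FALLTHROUGH keeps the
# custom npu:: domain node in the graph, which is presumably what the test
# harness's onnx_export helper arranges as well. "add_rms_norm.onnx" is an
# illustrative path.
torch.onnx.export(
    model, (x1, x2, gamma), "add_rms_norm.onnx",
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)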