Commit

!9083 AddRmsNorm support onnx
Merge pull request !9083 from TYWZ22259/add_rms_norm_onnx
TYWZ22259 authored and it-is-a-robot committed Jan 20, 2024
1 parent cd0580f commit 0507300
Showing 4 changed files with 71 additions and 0 deletions.
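
The new torch_npu.npu_add_rms_norm operator fuses an element-wise add with RMSNorm and returns three tensors: the normalized output y, the reciprocal standard deviation rstd, and the summed input x (see the fake-tensor test and meta registration below). For orientation only, here is a plain-PyTorch sketch of the intended semantics, assuming gamma spans just the last dimension; this is an illustrative reference, not the Ascend kernel:

import torch

def add_rms_norm_reference(x1, x2, gamma, epsilon=1e-6):
    # Fuse the residual add with RMSNorm over the last dimension.
    x = x1 + x2
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    y = x * rstd * gamma
    return y, rstd, x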
23 changes: 23 additions & 0 deletions test/onnx/test_wrapper_onnx_ops.py
@@ -1250,6 +1250,29 @@ def export_onnx(onnx_model_name):
        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
                                            onnx_model_name)))

    @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `npu_add_rms_norm` is only supported on 910B, skip this ut!")
    def test_wrapper_npu_add_rms_norm(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()

            def forward(self, x1, x2, gamma):
                epsilon = 1e-6
                x = torch_npu.npu_add_rms_norm(x1, x2, gamma, epsilon)
                return x

        def export_onnx(onnx_model_name):
            x1 = torch.rand(10, 1024).uniform_(-3, 3).npu().half()
            x2 = torch.rand(10, 1024).uniform_(-3, 3).npu().half()
            gamma = torch.rand(1024).uniform_(-3, 3).npu().half()
            model = Model().to("npu")
            model(x1, x2, gamma)
            self.onnx_export(model, (x1, x2, gamma), onnx_model_name)
        onnx_model_name = "model_npu_add_rms_norm.onnx"
        export_onnx(onnx_model_name)
        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
                                            onnx_model_name)))

    @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `RotaryMul` is only supported on 910B, skip this ut!")
    def test_wrapper_npu_rotary_mul(self):
        class Model(torch.nn.Module):
19 changes: 19 additions & 0 deletions test/test_fake_tensor.py
@@ -1448,6 +1448,25 @@ def test_npu_rms_norm_backward(self):
        self.assertEqual(dw.device, npu_gamma.device)


class TestNpuAddRmsNorm(TestCase):
    def test_npu_add_rms_norm(self):
        with FakeTensorMode():
            npu_x1 = torch.randn((2, 3), dtype=torch.float32).npu()
            npu_x2 = torch.randn((2, 3), dtype=torch.float32).npu()
            npu_gamma = torch.randn((3,), dtype=torch.float32).npu()

            result_y, result_rstd, result_x = torch_npu.npu_add_rms_norm(npu_x1, npu_x2, npu_gamma)

            self.assertEqual(result_y.dtype, npu_x1.dtype)
            self.assertEqual(result_y.shape, npu_x1.shape)
            self.assertEqual(result_y.device, npu_x1.device)
            self.assertEqual(result_rstd.shape, torch.Size([2, 1]))
            self.assertEqual(result_rstd.device, npu_x1.device)
            self.assertEqual(result_x.dtype, npu_x1.dtype)
            self.assertEqual(result_x.shape, npu_x1.shape)
            self.assertEqual(result_x.device, npu_x1.device)


class TestFFN(TestCase):
    def test_npu_ffn_meta(self):
        with FakeTensorMode():
13 changes: 13 additions & 0 deletions torch_npu/meta/meta_registrations.py
@@ -107,6 +107,19 @@ def npu_rms_norm_meta(self, gamma, epsilon=1e-6):
    return (torch.empty_like(self, dtype=self.dtype), torch.empty_like(rstd))


@impl(m, "npu_add_rms_norm")
def npu_add_rms_norm_meta(x1, x2, gamma, epsilon=1e-6):
    rstd_dim = x1.dim() - gamma.dim()
    ret = []
    for i in range(x1.dim()):
        if i < rstd_dim:
            ret.append(x1.size(i))
        else:
            ret.append(1)
    rstd = torch.empty(ret, dtype=torch.float32, device='meta')
    return (torch.empty_like(x1, dtype=x1.dtype), torch.empty_like(rstd), torch.empty_like(x1, dtype=x1.dtype))


@impl(m, "npu_rms_norm_backward")
def npu_rms_norm_backward_meta(dy, self, gamma, rstd):
    return (torch.empty_like(self, dtype=self.dtype), torch.empty_like(gamma, dtype=gamma.dtype))
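
To make the rstd shape logic in npu_add_rms_norm_meta concrete: it keeps the leading dimensions of x1 and collapses the dimensions covered by gamma to 1, so x1 of shape (2, 3) with a 1-D gamma of length 3 gives rstd of shape (2, 1), which is exactly what the fake-tensor test above asserts. A quick sketch under FakeTensorMode with larger, purely illustrative shapes (assumes torch and torch_npu are importable):

import torch
import torch_npu
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x1 = torch.randn(4, 16, 1024).npu()
    x2 = torch.randn(4, 16, 1024).npu()
    gamma = torch.randn(1024).npu()
    y, rstd, x = torch_npu.npu_add_rms_norm(x1, x2, gamma)
    # Per the meta above: y.shape == x.shape == (4, 16, 1024) and rstd.shape == (4, 16, 1)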
16 changes: 16 additions & 0 deletions torch_npu/onnx/wrapper_onnx_ops.py
@@ -194,6 +194,17 @@ def symbolic(g, self: Tensor, gamma: Tensor, epsilon: float = 1e-6):
        return g.op("npu::NPURmsNorm", self, gamma, epsilon_f=epsilon, outputs=2)


class NPUAddRmsNormOP(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *args, **kwargs):
        return torch.ops.npu.npu_add_rms_norm(*args, **kwargs)

    @staticmethod
    def symbolic(g, x1: Tensor, x2: Tensor, gamma: Tensor, epsilon: float = 1e-6):
        return g.op("npu::NPUAddRmsNorm", x1, x2, gamma, epsilon_f=epsilon, outputs=3)


class NPUDiouOP(torch.autograd.Function):

    @staticmethod
@@ -872,6 +883,10 @@ def wrapper_npu_rms_norm(self, gamma, epsilon=1e-6):
    return NPURmsNormOP.apply(self, gamma, epsilon)


def wrapper_npu_add_rms_norm(x1, x2, gamma, epsilon=1e-6):
    return NPUAddRmsNormOP.apply(x1, x2, gamma, epsilon)


def wrapper_npu_nms_v4(self, scores, max_output_size, iou_threshold, scores_threshold,
                       pad_to_max_output_size=False):
    return NPUNmsV4OP.apply(self, scores, max_output_size,
@@ -1039,6 +1054,7 @@ def add_onnx_ops():
    torch_npu.npu_scatter = wrapper_npu_scatter
    torch_npu.npu_lstm = wrapper_npu_lstm
    torch_npu.npu_rms_norm = wrapper_npu_rms_norm
    torch_npu.npu_add_rms_norm = wrapper_npu_add_rms_norm
    torch_npu.npu_lstm_cell = wrapper_npu_lstm_cell
    torch_npu.npu_gru = wrapper_npu_gru
    torch_npu.npu_dropout_with_add_softmax = wrapper_npu_dropout_with_add_softmax
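
Once add_onnx_ops() has installed wrapper_npu_add_rms_norm, exporting a module that calls torch_npu.npu_add_rms_norm should lower to a single three-output custom node in the npu domain via NPUAddRmsNormOP.symbolic. Below is a hedged export sketch along the lines of the new unit test (file name and shapes are illustrative; assumes an Ascend 910B device and that torch_npu has applied its ONNX wrappers):

import torch
import torch_npu

class AddRmsNorm(torch.nn.Module):
    def forward(self, x1, x2, gamma):
        # Returns (y, rstd, x) from the fused add + RMSNorm.
        return torch_npu.npu_add_rms_norm(x1, x2, gamma, 1e-6)

x1 = torch.rand(10, 1024).npu().half()
x2 = torch.rand(10, 1024).npu().half()
gamma = torch.rand(1024).npu().half()
torch.onnx.export(AddRmsNorm().to("npu"), (x1, x2, gamma),
                  "model_npu_add_rms_norm.onnx", opset_version=11)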
