diff --git a/src/onnx/parse_groupnorm.cpp b/src/onnx/parse_groupnorm.cpp index a2f673c4b98..ed84d4d7d1f 100644 --- a/src/onnx/parse_groupnorm.cpp +++ b/src/onnx/parse_groupnorm.cpp @@ -34,7 +34,7 @@ struct parse_groupnorm : op_parser { std::vector operators() const { - return {{"GroupNormalization"}, {"GroupNorm"}}; + return {{"GroupNormalization", "GroupNormalization"}, {"GroupNorm", "Contrib_GroupNorm"}}; } instruction_ref parse(const op_desc& opd, @@ -42,7 +42,8 @@ struct parse_groupnorm : op_parser const onnx_parser::node_info& info, std::vector args) const { - auto is_contrib = opd.op_name == "GroupNorm"; + bool is_contrib = (!opd.op_name.compare("Contrib_GroupNorm")); + float epsilon = 1e-5f; if(contains(info.attributes, "epsilon")) { @@ -52,7 +53,9 @@ struct parse_groupnorm : op_parser if(contains(info.attributes, "num_groups") or contains(info.attributes, "groups")) { if (is_contrib) - num_groups = parser.parse_value(info.attributes.at("groups")).at(); + { + num_groups = parser.parse_value(info.attributes.at("num_groups")).at(); + } else num_groups = parser.parse_value(info.attributes.at("num_groups")).at(); } @@ -90,7 +93,7 @@ struct parse_groupnorm : op_parser auto x = args.at(0); if(is_nhwc and is_contrib) { - x = info.add_instruction(make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x); + x = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), x); } auto scale = args.at(1); //gamma in the GroupNorm contrib case @@ -164,7 +167,7 @@ struct parse_groupnorm : op_parser // Convert to NCHW -> NHWC for contrib GroupNorm if(is_nhwc and is_contrib) { - output = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), output); + output = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), output); } return output; } diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 09e8874bc5d..5cbd449393e 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -4747,9 +4747,9 @@ def group_norm_contrib_test(x_dims, beta_dims, y_dims, num_groups, - eps_value=1e-5, activation=0, channels_last=0, + eps_value=1e-5, dtype=TensorProto.FLOAT): x = helper.make_tensor_value_info('x', dtype, x_dims) gamma = helper.make_tensor_value_info('gamma', dtype, gamma_dims) @@ -4759,7 +4759,7 @@ def group_norm_contrib_test(x_dims, node = onnx.helper.make_node('GroupNorm', inputs=['x', 'gamma', 'beta'], outputs=['y'], - activaction=activation, + activation=activation, channels_last=channels_last, num_groups=num_groups, epsilon=eps_value) @@ -4769,7 +4769,7 @@ def group_norm_contrib_test(x_dims, @onnx_test() def group_norm_contrib_3d_test(): - return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 1, 0, 0) + return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 0, 0) @onnx_test() @@ -4777,6 +4777,11 @@ def group_norm_contrib_silu_3d_test(): return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 1, 0) +@onnx_test() +def group_norm_contrib_channels_last_3d_test(): + return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 0, 1) + + @onnx_test() def group_norm_contrib_no_activation_attr_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 2]) diff --git a/test/onnx/group_norm_contrib_3d_test.onnx b/test/onnx/group_norm_contrib_3d_test.onnx index 3b152febf10..e08a937508d 100644 Binary files a/test/onnx/group_norm_contrib_3d_test.onnx and b/test/onnx/group_norm_contrib_3d_test.onnx differ diff --git a/test/onnx/group_norm_contrib_channels_last_3d_test.onnx b/test/onnx/group_norm_contrib_channels_last_3d_test.onnx new file mode 100644 index 00000000000..1388a7d754b Binary files /dev/null and b/test/onnx/group_norm_contrib_channels_last_3d_test.onnx differ diff --git a/test/onnx/group_norm_contrib_silu_3d_test.onnx b/test/onnx/group_norm_contrib_silu_3d_test.onnx index a945ab9f547..098cc544323 100644 Binary files a/test/onnx/group_norm_contrib_silu_3d_test.onnx and b/test/onnx/group_norm_contrib_silu_3d_test.onnx differ diff --git a/test/onnx/parse/group_norm_contrib_channels_last_3d_test.cpp b/test/onnx/parse/group_norm_contrib_channels_last_3d_test.cpp new file mode 100644 index 00000000000..a522224fc66 --- /dev/null +++ b/test/onnx/parse/group_norm_contrib_channels_last_3d_test.cpp @@ -0,0 +1,35 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +TEST_CASE(group_norm_contrib_channels_last_3d_test) +{ + migraphx::program p = make_group_norm( + {1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::float_type, "gamma", "beta"); + + auto prog = optimize_onnx("group_norm_contrib_channels_last_3d_test.onnx"); + EXPECT(p == prog); +}