From e70ca61e7f977e7b291cae9e2785296e29ba9a4a Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 23 Dec 2024 04:53:27 +0000 Subject: [PATCH] Fix naming for contrib op check, add test for channels last and change inputs for generated onnx to handle input activation. TODO Need to finish parser tests --- src/onnx/parse_groupnorm.cpp | 13 ++++--- test/onnx/gen_onnx.py | 11 ++++-- test/onnx/group_norm_contrib_3d_test.onnx | Bin 265 -> 267 bytes ...up_norm_contrib_channels_last_3d_test.onnx | Bin 0 -> 295 bytes .../onnx/group_norm_contrib_silu_3d_test.onnx | Bin 275 -> 277 bytes ...oup_norm_contrib_channels_last_3d_test.cpp | 35 ++++++++++++++++++ 6 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 test/onnx/group_norm_contrib_channels_last_3d_test.onnx create mode 100644 test/onnx/parse/group_norm_contrib_channels_last_3d_test.cpp diff --git a/src/onnx/parse_groupnorm.cpp b/src/onnx/parse_groupnorm.cpp index a2f673c4b98..ed84d4d7d1f 100644 --- a/src/onnx/parse_groupnorm.cpp +++ b/src/onnx/parse_groupnorm.cpp @@ -34,7 +34,7 @@ struct parse_groupnorm : op_parser { std::vector operators() const { - return {{"GroupNormalization"}, {"GroupNorm"}}; + return {{"GroupNormalization", "GroupNormalization"}, {"GroupNorm", "Contrib_GroupNorm"}}; } instruction_ref parse(const op_desc& opd, @@ -42,7 +42,8 @@ struct parse_groupnorm : op_parser const onnx_parser::node_info& info, std::vector args) const { - auto is_contrib = opd.op_name == "GroupNorm"; + bool is_contrib = (!opd.op_name.compare("Contrib_GroupNorm")); + float epsilon = 1e-5f; if(contains(info.attributes, "epsilon")) { @@ -52,7 +53,9 @@ struct parse_groupnorm : op_parser if(contains(info.attributes, "num_groups") or contains(info.attributes, "groups")) { if (is_contrib) - num_groups = parser.parse_value(info.attributes.at("groups")).at(); + { + num_groups = parser.parse_value(info.attributes.at("num_groups")).at(); + } else num_groups = parser.parse_value(info.attributes.at("num_groups")).at(); } @@ -90,7 +93,7 @@ struct parse_groupnorm : op_parser auto x = args.at(0); if(is_nhwc and is_contrib) { - x = info.add_instruction(make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x); + x = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), x); } auto scale = args.at(1); //gamma in the GroupNorm contrib case @@ -164,7 +167,7 @@ struct parse_groupnorm : op_parser // Convert to NCHW -> NHWC for contrib GroupNorm if(is_nhwc and is_contrib) { - output = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), output); + output = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), output); } return output; } diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 09e8874bc5d..5cbd449393e 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -4747,9 +4747,9 @@ def group_norm_contrib_test(x_dims, beta_dims, y_dims, num_groups, - eps_value=1e-5, activation=0, channels_last=0, + eps_value=1e-5, dtype=TensorProto.FLOAT): x = helper.make_tensor_value_info('x', dtype, x_dims) gamma = helper.make_tensor_value_info('gamma', dtype, gamma_dims) @@ -4759,7 +4759,7 @@ def group_norm_contrib_test(x_dims, node = onnx.helper.make_node('GroupNorm', inputs=['x', 'gamma', 'beta'], outputs=['y'], - activaction=activation, + activation=activation, channels_last=channels_last, num_groups=num_groups, epsilon=eps_value) @@ -4769,7 +4769,7 @@ def group_norm_contrib_test(x_dims, @onnx_test() def group_norm_contrib_3d_test(): - return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 1, 0, 0) + return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 0, 0) @onnx_test() @@ -4777,6 +4777,11 @@ def group_norm_contrib_silu_3d_test(): return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 1, 0) +@onnx_test() +def group_norm_contrib_channels_last_3d_test(): + return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 0, 1) + + @onnx_test() def group_norm_contrib_no_activation_attr_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 2]) diff --git a/test/onnx/group_norm_contrib_3d_test.onnx b/test/onnx/group_norm_contrib_3d_test.onnx index 3b152febf10f82f8f66aeaea0cc4cd010b261c63..e08a937508d7086140db20e9dd23472f50a9174f 100644 GIT binary patch delta 83 zcmeBV>Ski%;F{Pcqigkykt>&rv4V>=Jux>ok&7iMwIor9u~LcCy(qu5z%RcjS4)tK nD>1nwvn+9gEd=87EfvGKv8JwNV%J delta 81 zcmeBX>SSW#;G8%?M$hUABUd&TV+9v$dSY&FA{R?iYDuCHW2F+Odr^LAfnR=6u9grN kcVco$W?5qLL~jLtJ}&mug5u1a{5%PU1&mA+YkL`m0e*KDga7~l diff --git a/test/onnx/group_norm_contrib_channels_last_3d_test.onnx b/test/onnx/group_norm_contrib_channels_last_3d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1388a7d754ba6164e3c490e2783f23f2b9c60a92 GIT binary patch literal 295 zcmd`JMkXy0E?yMl zBp5-Wf?VvW1;v>;`FWyij;fn4U}Oa8%`44~2Ro-&f(fXKc(+A~105zL#3jJND8$3X o#K8!}EI`Z@B@A`05EmB*P!b|21a&-;SQ5~LN^B-LF$stP06#rTdjJ3c literal 0 HcmV?d00001 diff --git a/test/onnx/group_norm_contrib_silu_3d_test.onnx b/test/onnx/group_norm_contrib_silu_3d_test.onnx index a945ab9f5470fb41375db9d5b6beac15e3562c15..098cc54432394f42c68e84e1eaea543b0da82cdb 100644 GIT binary patch delta 84 zcmbQtG?j^ogKOesMOLfVj9j@BHRP=YxwsOOOESw6OEUBGBp4SkGHHo$@g`>^=H;d4 m6vyWz7MDmcfJ6nk*i#FNGjsCuMAsZuH($WWII&?8qZk0)8yG_X delta 82 zcmbQrG?|HsgLC3kMOLeqj9l3hHRLUYxVRIOOESwqM1GzG!vaPoEfFr> +#include + +TEST_CASE(group_norm_contrib_channels_last_3d_test) +{ + migraphx::program p = make_group_norm( + {1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::float_type, "gamma", "beta"); + + auto prog = optimize_onnx("group_norm_contrib_channels_last_3d_test.onnx"); + EXPECT(p == prog); +}