From 6fb3c7b497417580b19615aed957b99e5ba3a2d1 Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Tue, 27 Aug 2024 17:35:52 +0900 Subject: [PATCH] [luci/pass] Support InstanceNorm with 3D input (#13778) This commit will support InstanceNorm with 3D input. ONE-DCO-1.0-Signed-off-by: ragmani --- .../luci/pass/src/FuseInstanceNormPass.cpp | 189 +++++++++++++++++- 1 file changed, 188 insertions(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.cpp index 10a651e35e7..5427e1fe69c 100644 --- a/compiler/luci/pass/src/FuseInstanceNormPass.cpp +++ b/compiler/luci/pass/src/FuseInstanceNormPass.cpp @@ -85,6 +85,46 @@ bool is_instance_mean_v1(luci::CircleMean *mean) return mean->keep_dims(); } +bool is_instance_mean_v2(luci::CircleMean *mean) +{ + // + // CHECK 1) input is rank 3 + // + auto input = loco::must_cast(mean->input()); + if (input->shape_status() != luci::ShapeStatus::VALID) + return false; + if (input->rank() != 3) + return false; + + // + // CHECK 2) 'reduction indices' is CircleConst of value [2], that is last dim of rank 3 + // + // TODO Support non-Const case? + auto red_indices = dynamic_cast(mean->reduction_indices()); + if (not red_indices) + return false; + if (red_indices->rank() != 1) + return false; + std::set red_indices_set; + { + // TODO Currently only support S32, support other types + assert(red_indices->dtype() == loco::DataType::S32); + for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i) + red_indices_set.insert(red_indices->at(i)); + } + if (red_indices_set.size() != 1) + return false; + if (red_indices_set.find(2) == red_indices_set.end()) + return false; + + // + // CHECK 3) keep_dims == true (?) + // + // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false' + // TODO Check this fact, and if true, return true regardless of keep_dims + return mean->keep_dims(); +} + /// @return true When node has the shape of 1D channel_size bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size) { @@ -296,6 +336,40 @@ namespace * | * V * [Out] + *------------------------------------------------------------------- + * Version_6 (Same as Version_5, For only 3D I/O) + * [In] + * | + * V + * +----------- ifm -----+ (reduction indicies) + * | | | | + * | | V V + * | | mean_of_ifm ----------------+ + * | V | | + * | sqdiff <--+ (reduction indicies) | + * | | | | + * | V | | + * | mean_as_variance <---+ const_as_epsilon | + * | | | | + * | V | | + * | add_as_variance <--------+ | + * | | | + * | V | + * | rsqrt | + * | | | + * | +--+--+ | + * | | | | + * V V V | + * mul_as_scaled_ifm mul_as_scaled_mean <-------------+ + * | | + * | const_as_beta | + * | | V + * | +------> sub + * V | + * add_as_terminal <----------+ + * | + * V + * [Out] */ class InstanceNormPattern final { @@ -308,6 +382,7 @@ class InstanceNormPattern final Version_3, Version_4, Version_5, + Version_6, // For only 3D I/O }; InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv) @@ -615,6 +690,67 @@ template <> bool InstanceNormPattern::match bool InstanceNormPattern::match() +{ + CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal)); + CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm)); + + auto ifm_circle = loco::must_cast(ifm); + CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID); + CHECK_OR_FALSE(ifm_circle->rank() == 3); + CHECK_OR_FALSE((ifm_circle->dim(1).known())); + + add_as_variance = dynamic_cast(rsqrt->x()); + CHECK_OR_FALSE(add_as_variance); + + CHECK_OR_FALSE( + luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance)); + + CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); + // TODO Support regarding broadcast + CHECK_OR_FALSE(const_as_epsilon->size() == 1); + + CHECK_OR_FALSE(is_instance_mean_v2(mean_as_variance)); + + sqdiff = dynamic_cast(mean_as_variance->input()); + CHECK_OR_FALSE(sqdiff); + + loco::Node *ifm_should_be = nullptr; + CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff)); + CHECK_OR_FALSE(ifm == ifm_should_be); + CHECK_OR_FALSE(is_instance_mean_v2(mean_of_ifm)); + CHECK_OR_FALSE(ifm == mean_of_ifm->input()); + + // If const_as_beta has shape of '1 x chennel x (1 or input last dimension)' + uint32_t input_channel = ifm_circle->dim(1).value(); + uint32_t input_last_dim = ifm_circle->dim(2).value(); + const_as_beta = dynamic_cast(sub->x()); + CHECK_OR_FALSE(const_as_beta); + CHECK_OR_FALSE(const_as_beta->rank() == 3); + CHECK_OR_FALSE( + const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel && + (const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim)); + + luci::CircleRsqrt *rsqrt_should_be = nullptr; + luci::CircleMean *mean_of_ifm_should_be = nullptr; + + mul_as_scaled_mean = dynamic_cast(sub->y()); + CHECK_OR_FALSE(mul_as_scaled_mean); + CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be) + .with_commutative_args_of(mul_as_scaled_mean)); + CHECK_OR_FALSE(rsqrt == rsqrt_should_be); + CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be); + + // mul_gamma is absent + // const_as_gamma assume to be 1.0 + auto graph = add_as_terminal->graph(); + const_as_gamma = make_const_one(graph, 1.0f); + const_as_gamma->name(add_as_terminal->name() + "/gamma"); + + _matched = true; + return true; +} + bool InstanceNormPattern::matched() { if (_matched) @@ -634,6 +770,8 @@ bool InstanceNormPattern::matched() return match(); case PatternVersion::Version_5: return match(); + case PatternVersion::Version_6: + return match(); default: break; @@ -843,6 +981,31 @@ template <> void FuseInstanceNorm::apply void FuseInstanceNorm::apply() +{ + auto graph = _p.add_as_terminal->graph(); + + reshape_gamma_beta(); + + auto instance_norm = create_inst_norm(graph); + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p.mean_of_ifm), + luci::get_origin(_p.sqdiff), + luci::get_origin(_p.mean_as_variance), + luci::get_origin(_p.add_as_variance), + luci::get_origin(_p.rsqrt), + luci::get_origin(_p.mul_as_scaled_ifm), + luci::get_origin(_p.mul_as_scaled_mean), + luci::get_origin(_p.sub), + luci::get_origin(_p.add_as_terminal)}; + + luci::add_origin(instance_norm, luci::composite_origin(origin_vec)); + + replace(_p.add_as_terminal).with(instance_norm); +} + void FuseInstanceNorm::apply() { assert(_p.matched()); @@ -864,6 +1027,9 @@ void FuseInstanceNorm::apply() case InstanceNormPattern::PatternVersion::Version_5: apply(); break; + case InstanceNormPattern::PatternVersion::Version_6: + apply(); + break; default: break; @@ -1006,12 +1172,33 @@ bool is_add_input_mul_const(luci::CircleAdd *add) return luci::fill(&p_mul, &p_const).with_commutative_args_of(add); } +bool is_add_input_mul_sub3d(luci::CircleAdd *add) +{ + luci::CircleMul *p_mul = nullptr; + luci::CircleSub *p_sub = nullptr; + + if (!luci::fill(&p_mul, &p_sub).with_commutative_args_of(add)) + return false; + + auto sub = dynamic_cast(add->y()); + if (sub == nullptr) + return false; + + auto const_as_beta = dynamic_cast(sub->x()); + if (const_as_beta == nullptr || const_as_beta->rank() != 3) + return false; + + return true; +} + bool fuse_instance_norm(luci::CircleAdd *add) { InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1; if (is_add_input_mul_const(add)) pv = InstanceNormPattern::PatternVersion::Version_2; + else if (is_add_input_mul_sub3d(add)) + pv = InstanceNormPattern::PatternVersion::Version_6; InstanceNormPattern pattern(add, pv); if (pattern.matched()) @@ -1080,7 +1267,7 @@ bool FuseInstanceNormPass::run(loco::Graph *g) { bool changed = false; - // Check Version_1, Version_2, Version_3, Version_5 + // Check Version_1, Version_2, Version_3, Version_5, Version_6 for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto add = dynamic_cast(node);