Skip to content

Commit

Permalink
[luci/pass] Revise FuseInstanceNormPass V6
Browse files Browse the repository at this point in the history
This will revise FuseInstanceNormPass version_6 to have const_as_beta as zero values.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark committed Dec 2, 2024
1 parent fb7db64 commit 7ef7f8c
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions compiler/luci/pass/src/FuseInstanceNormPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ namespace
* V V V |
* mul_as_scaled_ifm mul_as_scaled_mean <-------------+
* | |
* | const_as_beta |
* | const_zero |
* | | V
* | +------> sub
* V |
Expand Down Expand Up @@ -721,16 +721,6 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
CHECK_OR_FALSE(is_instance_mean_v2(mean_of_ifm));
CHECK_OR_FALSE(ifm == mean_of_ifm->input());

// If const_as_beta has shape of '1 x chennel x (1 or input last dimension)'
uint32_t input_channel = ifm_circle->dim(1).value();
uint32_t input_last_dim = ifm_circle->dim(2).value();
const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
CHECK_OR_FALSE(const_as_beta);
CHECK_OR_FALSE(const_as_beta->rank() == 3);
CHECK_OR_FALSE(
const_as_beta->dim(0).value() == 1 && const_as_beta->dim(1).value() == input_channel &&
(const_as_beta->dim(2).value() == 1 || const_as_beta->dim(2).value() == input_last_dim));

luci::CircleRsqrt *rsqrt_should_be = nullptr;
luci::CircleMean *mean_of_ifm_should_be = nullptr;

Expand All @@ -741,11 +731,12 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);

// mul_gamma is absent
// const_as_gamma assume to be 1.0
// create 1.0 gamma and 0.0 beta
auto graph = add_as_terminal->graph();
const_as_gamma = make_const_one(graph, 1.0f);
const_as_gamma->name(add_as_terminal->name() + "/gamma");
const_as_beta = make_const_one(graph, 0.0f);
const_as_beta->name(add_as_terminal->name() + "/beta");

_matched = true;
return true;
Expand Down

0 comments on commit 7ef7f8c

Please sign in to comment.