Skip to content

Commit

Permalink
DRAFT FuseInstanceNorm with 3D 1NM shape
Browse files Browse the repository at this point in the history
fix for FuseInstanceNorm with 3D 1NM shape.

Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark committed Dec 2, 2024
1 parent 39c39a3 commit 78ca308
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/circle2circle-dredd-recipe-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Add(Net_InstanceNorm_005 PASS fuse_instnorm)
Add(Net_InstanceNorm_006 PASS fuse_instnorm)
Add(Net_InstanceNorm_007 PASS fuse_instnorm)
Add(Net_InstanceNorm_008 PASS fuse_instnorm)
Add(Net_InstanceNorm_009 PASS fuse_instnorm)
Add(Net_Maximum_Minimum_000 PASS transform_min_max_to_relu6)
Add(Net_Mul_Add_000 PASS remove_unnecessary_add)
Add(Net_Mul_Add_001 PASS remove_unnecessary_add)
Expand Down
1 change: 1 addition & 0 deletions compiler/luci-pass-value-py-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ eval(Net_InstanceNorm_001 fuse_instnorm)
eval(Net_InstanceNorm_002 fuse_instnorm)
eval(Net_InstanceNorm_003 fuse_instnorm)
eval(Net_InstanceNorm_008 fuse_instnorm)
eval(Net_InstanceNorm_009 fuse_instnorm)
eval(Net_Mul_Add_000 remove_unnecessary_add)
eval(Net_Mul_Add_001 remove_unnecessary_add)
eval(Net_Mul_Add_002 remove_unnecessary_add)
Expand Down
12 changes: 12 additions & 0 deletions compiler/luci/pass/src/FuseInstanceNormPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/CircleNodeClone.h>
#include <luci/Service/Nodes/CircleConst.h>

#include <cassert>
#include <set>
Expand Down Expand Up @@ -741,6 +742,12 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);

// make clone for shared beta node that gets reshaped in reshape_gamma_beta()
auto beta_origin = luci::get_origin(const_as_beta);
const_as_beta = luci::clone(const_as_beta);
luci::add_origin(const_as_beta, beta_origin);
// NOTE no need to set different name as numbered suffix will be added at export

// mul_gamma is absent
// const_as_gamma assume to be 1.0
auto graph = add_as_terminal->graph();
Expand Down Expand Up @@ -1075,6 +1082,11 @@ uint32_t PostFusion::input_channel(void)
if (input_rank < 1)
return 0;

if (input_rank == 3)
{
// use dim 1
return input->dim(1).value();
}
// assume channel-last
return input->dim(input_rank - 1).value();
}
Expand Down
184 changes: 184 additions & 0 deletions res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.recipe
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#
# This was copied from Net_InstanceNorm_008
# with last dim value > 1
#

operand {
name: "Hole"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 8
}
}
operand {
name: "InstanceNorm/beta"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 1
}
filler {
tag: "gaussian"
arg: "0.0"
arg: "1.0"
}
}
operand {
name: "InstanceNorm/instancenorm/add/y"
type: FLOAT32
shape {
}
filler {
tag: "explicit"
arg: "1e-06"
}
}
operand {
name: "InstanceNorm/moments/variance/reduction_indices"
type: INT32
shape {
dim: 1
}
filler {
tag: "explicit"
arg: "2"
}
}
operand {
name: "InstanceNorm/moments/mean"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 1
}
}
operand {
name: "InstanceNorm/moments/SquaredDifference"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 8
}
}
operand {
name: "InstanceNorm/moments/variance"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 1
}
}
operand {
name: "InstanceNorm/instancenorm/add"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 8
}
}
operand {
name: "InstanceNorm/instancenorm/Rsqrt"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 1
}
}
operand {
name: "InstanceNorm/instancenorm/mul_1"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 1
}
}
operand {
name: "InstanceNorm/instancenorm/mul_2"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 8
}
}
operand {
name: "InstanceNorm/instancenorm/sub"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 1
}
}
operand {
name: "InstanceNorm/instancenorm/add_1"
type: FLOAT32
shape {
dim: 1 dim: 4 dim: 8
}
}
operation {
type: "Mean"
input: "Hole"
input: "InstanceNorm/moments/variance/reduction_indices"
output: "InstanceNorm/moments/mean"
mean_options {
keep_dims: true
}
}
operation {
type: "SquaredDifference"
input: "Hole"
input: "InstanceNorm/moments/mean"
output: "InstanceNorm/moments/SquaredDifference"
}
operation {
type: "Mean"
input: "InstanceNorm/moments/SquaredDifference"
input: "InstanceNorm/moments/variance/reduction_indices"
output: "InstanceNorm/moments/variance"
mean_options {
keep_dims: true
}
}
operation {
type: "Add"
input: "InstanceNorm/moments/variance"
input: "InstanceNorm/instancenorm/add/y"
output: "InstanceNorm/instancenorm/add"
add_options {
activation: NONE
}
}
operation {
type: "Rsqrt"
input: "InstanceNorm/instancenorm/add"
output: "InstanceNorm/instancenorm/Rsqrt"
}
operation {
type: "Mul"
input: "Hole"
input: "InstanceNorm/instancenorm/Rsqrt"
output: "InstanceNorm/instancenorm/mul_1"
mul_options {
activation: NONE
}
}
operation {
type: "Mul"
input: "InstanceNorm/moments/mean"
input: "InstanceNorm/instancenorm/Rsqrt"
output: "InstanceNorm/instancenorm/mul_2"
mul_options {
activation: NONE
}
}
operation {
type: "Sub"
input: "InstanceNorm/beta"
input: "InstanceNorm/instancenorm/mul_2"
output: "InstanceNorm/instancenorm/sub"
sub_options {
activation: NONE
}
}
operation {
type: "Add"
input: "InstanceNorm/instancenorm/mul_1"
input: "InstanceNorm/instancenorm/sub"
output: "InstanceNorm/instancenorm/add_1"
add_options {
activation: NONE
}
}
input: "Hole"
output: "InstanceNorm/instancenorm/add_1"
13 changes: 13 additions & 0 deletions res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.rule
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# To check if this network is converted to circle InstanceNorm op

RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1

RULE "INSTANCE_NORM_EXIST" $(op_count INSTANCE_NORM) '=' 1
RULE "NO_ADD" $(op_count ADD) '=' 0
RULE "NO_MUL" $(op_count MUL) '=' 0
RULE "NO_POW" $(op_count POW) '=' 0
RULE "NO_DIV" $(op_count DIV) '=' 0
RULE "NO_SQUARED_DIFF" $(op_count SQUARED_DIFFERENCE) '=' 0
RULE "NO_MEAN" $(op_count MEAN) '=' 0
RULE "NO_RSQRT" $(op_count RSQRT) '=' 0
RULE "NO_SUB" $(op_count SUB) '=' 0

0 comments on commit 78ca308

Please sign in to comment.