Skip to content

Commit

Permalink
DRAFT CFe fuse Mul Add to Fullyconnected
Browse files Browse the repository at this point in the history
Ongoing draft to fuse Mul and Add into FullyConnected.

Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark committed Jul 15, 2024
1 parent 8e15af3 commit c433bc8
Show file tree
Hide file tree
Showing 29 changed files with 1,431 additions and 3 deletions.
6 changes: 6 additions & 0 deletions compiler/circle2circle-dredd-recipe-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ Add(MaxPoolWithArgmax_000 PASS resolve_customop_max_pool_with_argmax)
Add(MaxPoolWithArgmax_001 PASS resolve_customop_max_pool_with_argmax)
Add(MaxPoolWithArgmax_002 PASS resolve_customop_max_pool_with_argmax)
Add(Net_Add_FloorMod_Gather_000 PASS remove_gather_guard)
Add(Net_Add_FullyConnected_000 PASS fuse_add_to_fullyconnected_bias)
Add(Net_Add_FullyConnected_001 PASS fuse_add_to_fullyconnected_bias)
Add(Net_Add_FullyConnected_002 PASS fuse_add_to_fullyconnected_bias)
Add(Net_BroadcastTo_AddV2_000 PASS resolve_customop_add)
Add(Net_BroadcastTo_AddV2_001 PASS resolve_customop_add)
Add(Net_BroadcastTo_AddV2_002 PASS resolve_customop_add)
Expand Down Expand Up @@ -61,6 +64,9 @@ Add(Net_Mul_Add_002 PASS remove_unnecessary_add)
Add(Net_Mul_Add_003 PASS remove_unnecessary_add)
Add(Net_Mul_Div_000 PASS fuse_mul_with_div)
Add(Net_Mul_Div_001 PASS fuse_mul_with_div)
Add(Net_Mul_FullyConnected_000 PASS fuse_mul_to_fullyconnected_weights fold_mul)
Add(Net_Mul_FullyConnected_001 PASS fuse_mul_to_fullyconnected_weights fold_mul)
Add(Net_Mul_FullyConnected_002 PASS fuse_mul_to_fullyconnected_weights fold_mul)
Add(Net_Preactivation_BN_000 PASS fuse_preactivation_batchnorm)
Add(Net_Reshape_Reshape_000 PASS remove_redundant_reshape)
Add(Net_Shape_Add_000 PASS fold_shape)
Expand Down
11 changes: 11 additions & 0 deletions compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ int entry(int argc, char **argv)
add_switch(arser, "--fold_fully_connected",
"This will fold FullyConnected operator with constant inputs");
add_switch(arser, "--fold_gather", "This will fold Gather operator");
add_switch(arser, "--fold_mul", "This will fold Mul operator");
add_switch(arser, "--fold_reshape", "This will fold Reshape operator");
add_switch(arser, "--fold_shape", "This will fold Shape operator");
add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator");
Expand All @@ -105,6 +106,10 @@ int entry(int argc, char **argv)
add_switch(arser, "--fuse_batchnorm_with_tconv",
"This will fuse BatchNorm operators to Transposed Convolution operator");
add_switch(arser, "--fuse_bcq", "This will fuse operators and apply Binary Coded Quantization");
add_switch(arser, "--fuse_add_to_fullyconnected_bias",
"This will fuse Add to following FullyConnected bias");
add_switch(arser, "--fuse_mul_to_fullyconnected_weights",
"This will fuse Mul to following FullyConnected weights");
add_switch(arser, "--fuse_instnorm", "This will fuse operators to InstanceNorm operator");
add_switch(arser, "--fuse_mean_with_mean",
"This will fuse two Mean operations when they follow one by one. This will fold them "
Expand Down Expand Up @@ -275,6 +280,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FoldFullyConnected);
if (arser.get<bool>("--fold_gather"))
options->enable(Algorithms::FoldGather);
if (arser.get<bool>("--fold_mul"))
options->enable(Algorithms::FoldMul);
if (arser.get<bool>("--fold_reshape"))
options->enable(Algorithms::FoldReshape);
if (arser.get<bool>("--fold_shape"))
Expand Down Expand Up @@ -303,6 +310,10 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FuseBatchNormWithDwConv);
if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
options->enable(Algorithms::FuseBatchNormWithTConv);
if (arser.get<bool>("--fuse_add_to_fullyconnected_bias"))
options->enable(Algorithms::FuseAddToFullyConnectedBias);
if (arser.get<bool>("--fuse_mul_to_fullyconnected_weights"))
options->enable(Algorithms::FuseMulToFullyConnectedWeights);
if (arser.get<bool>("--fuse_slice_with_tconv"))
options->enable(Algorithms::FuseSliceWithTConv);
if (arser.get<bool>("--fuse_bcq"))
Expand Down
8 changes: 7 additions & 1 deletion compiler/luci-pass-value-py-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
# Format:
# eval(MODEL PASS)
# MODEL: tflite model file name in build/compiler/common-artifacts folder.
# PASS: Optimization Pass to test. Supports only one Pass for now.
# PASS: Optimization Pass to test. Supports one or more Passes.
#

# eval(Net_Preactivation_BN_000 fuse_preactivation_batchnorm) : value diff exist
# --> https://github.com/Samsung/ONE/issues/5782
eval(FullyConnected_007 replace_non_const_fc_with_batch_matmul)
eval(HardSwish_001 decompose_hardswish)
eval(Net_Add_FloorMod_Gather_000 remove_gather_guard)
eval(Net_Add_FullyConnected_000 fuse_add_to_fullyconnected_bias)
eval(Net_Add_FullyConnected_001 fuse_add_to_fullyconnected_bias)
eval(Net_Add_FullyConnected_002 fuse_add_to_fullyconnected_bias)
eval(Net_Conv_Add_000 fuse_add_with_conv)
eval(Net_Conv_Add_001 fuse_add_with_conv)
# eval(Net_Conv_Add_002 fuse_add_with_conv) --> Conv2D w/o bias fails in tflite interpreter
Expand Down Expand Up @@ -40,6 +43,9 @@ eval(Net_Mul_Add_002 remove_unnecessary_add)
eval(Net_Mul_Add_003 remove_unnecessary_add)
eval(Net_Mul_Div_000 fuse_mul_with_div)
eval(Net_Mul_Div_001 fuse_mul_with_div)
eval(Net_Mul_FullyConnected_000 fuse_mul_to_fullyconnected_weights)
eval(Net_Mul_FullyConnected_001 fuse_mul_to_fullyconnected_weights)
eval(Net_Mul_FullyConnected_002 fuse_mul_to_fullyconnected_weights)
eval(Net_Reshape_Mean_000 forward_reshape_to_unaryop)
eval(Net_Reshape_Neg_000 forward_reshape_to_unaryop)
eval(Net_Reshape_Reshape_000 remove_redundant_reshape)
Expand Down
9 changes: 7 additions & 2 deletions compiler/luci-pass-value-py-test/test_luci_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,13 @@ def luci_eval_verify(test_name,
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
elif output_details["dtype"] == np.float32:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32), err_msg
diff_comp = np.allclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32)
if not diff_comp:
print("\r\ntflite:\r\n", intp_output_data, flush=True)
print("\r\ncircle:\r\n", luci_output_data, flush=True)
print("\r\nDiff:\r\n", intp_output_data - luci_output_data, flush=True)
assert diff_comp, err_msg
elif output_details["dtype"] == np.int64:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
Expand Down
3 changes: 3 additions & 0 deletions compiler/luci/pass/include/luci/CircleOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ class CircleOptimizer final
{
enum Algorithm
{
FuseAddToFullyConnectedBias,
FuseAddWithConv,
FuseAddWithFullyConnected,
FuseAddWithTConv,
FuseBatchNormWithConv,
FuseBatchNormWithDwConv,
FuseBatchNormWithTConv,
FuseMulToFullyConnectedWeights,
FuseSliceWithTConv,
FuseBCQ,
FuseHorizontalFullyConnected,
Expand All @@ -61,6 +63,7 @@ class CircleOptimizer final
FoldFullyConnected,
FoldDequantize,
FoldGather,
FoldMul,
FoldReshape,
FoldShape,
FoldSparseToDense,
Expand Down
38 changes: 38 additions & 0 deletions compiler/luci/pass/include/luci/Pass/FoldMulPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __LUCI_FOLD_MUL_PASS_H__
#define __LUCI_FOLD_MUL_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
 * @brief  Pass to fold a Mul operation into a constant tensor
 *
 * @note   Enabled by the circle2circle "--fold_mul" option; paired with
 *         fuse_mul_to_fullyconnected_weights in the dredd recipe tests.
 */
struct FoldMulPass final : public logo::Pass
{
  const char *name(void) const final { return "luci::FoldMulPass"; }

  // Returns true if the graph was modified (standard logo::Pass contract)
  bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FOLD_MUL_PASS_H__
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __LUCI_FUSE_ADD_TO_FULLY_CONNECTED_BIAS_PASS_H__
#define __LUCI_FUSE_ADD_TO_FULLY_CONNECTED_BIAS_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
 * @brief  Pass to fuse a preceding Add operation into the bias of the
 *         following FullyConnected operation
 *
 * @note   Enabled by the circle2circle "--fuse_add_to_fullyconnected_bias" option.
 */
struct FuseAddToFullyConnectedBiasPass final : public logo::Pass
{
  const char *name(void) const final { return "luci::FuseAddToFullyConnectedBiasPass"; }

  // Returns true if the graph was modified (standard logo::Pass contract)
  bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FUSE_ADD_TO_FULLY_CONNECTED_BIAS_PASS_H__
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__
#define __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
 * @brief  Pass to fuse a preceding Mul operation into the weights of the
 *         following FullyConnected operation
 *
 * @note   Enabled by the circle2circle "--fuse_mul_to_fullyconnected_weights" option.
 */
struct FuseMulToFullyConnectedWeightsPass final : public logo::Pass
{
  const char *name(void) const final { return "luci::FuseMulToFullyConnectedWeightsPass"; }

  // Returns true if the graph was modified (standard logo::Pass contract)
  bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__
15 changes: 15 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,23 @@
#include "luci/Pass/FoldDequantizePass.h"
#include "luci/Pass/FoldFullyConnectedPass.h"
#include "luci/Pass/FoldGatherPass.h"
#include "luci/Pass/FoldMulPass.h"
#include "luci/Pass/FoldReshapePass.h"
#include "luci/Pass/FoldShapePass.h"
#include "luci/Pass/FoldSparseToDensePass.h"
#include "luci/Pass/FoldSqueezePass.h"
#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
#include "luci/Pass/ForwardTransposeOpPass.h"
#include "luci/Pass/FuseActivationFunctionPass.h"
#include "luci/Pass/FuseAddToFullyConnectedBiasPass.h"
#include "luci/Pass/FuseAddWithConvPass.h"
#include "luci/Pass/FuseAddWithFullyConnectedPass.h"
#include "luci/Pass/FuseAddWithTConvPass.h"
#include "luci/Pass/FuseBatchNormWithConvPass.h"
#include "luci/Pass/FuseBatchNormWithDwConvPass.h"
#include "luci/Pass/FuseBatchNormWithTConvPass.h"
#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/FuseMeanWithMeanPass.h"
#include "luci/Pass/FuseMulWithConvPass.h"
Expand Down Expand Up @@ -333,6 +336,14 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
}
if (_options->query(Options::Algorithm::FuseAddToFullyConnectedBias))
{
phase.emplace_back(std::make_unique<FuseAddToFullyConnectedBiasPass>());
}
if (_options->query(Options::Algorithm::FuseMulToFullyConnectedWeights))
{
phase.emplace_back(std::make_unique<FuseMulToFullyConnectedWeightsPass>());
}
if (_options->query(Options::Algorithm::FusePRelu))
{
phase.emplace_back(std::make_unique<FusePReluPass>());
Expand Down Expand Up @@ -381,6 +392,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
}
if (_options->query(Options::Algorithm::FoldMul))
{
phase.emplace_back(std::make_unique<luci::FoldMulPass>());
}
if (_options->query(Options::Algorithm::FoldReshape))
{
phase.emplace_back(std::make_unique<luci::FoldReshapePass>());
Expand Down
Loading

0 comments on commit c433bc8

Please sign in to comment.