From b700147921296a9371dc76410f276aba8a83f1f1 Mon Sep 17 00:00:00 2001 From: SaeHie Park Date: Mon, 15 Jul 2024 23:54:02 +0000 Subject: [PATCH] [luci/pass] Introduce FuseMulToFullyConnectedWeightsPass This will introduce FuseMulToFullyConnectedWeightsPass which will fuse Mul to following FullyConnected weights if possible. ONE-DCO-1.0-Signed-off-by: SaeHie Park --- .../luci/pass/include/luci/CircleOptimizer.h | 1 + .../Pass/FuseMulToFullyConnectedWeightsPass.h | 37 ++++ compiler/luci/pass/src/CircleOptimizer.cpp | 5 + .../FuseMulToFullyConnectedWeightsPass.cpp | 120 +++++++++++++ ...useMulToFullyConnectedWeightsPass.test.cpp | 160 ++++++++++++++++++ 5 files changed, 323 insertions(+) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseMulToFullyConnectedWeightsPass.h create mode 100644 compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.cpp create mode 100644 compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.test.cpp diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 01b43a72844..611fe3fd6e5 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -40,6 +40,7 @@ class CircleOptimizer final FuseBatchNormWithConv, FuseBatchNormWithDwConv, FuseBatchNormWithTConv, + FuseMulToFullyConnectedWeights, FuseSliceWithTConv, FuseBCQ, FuseHorizontalFullyConnected, diff --git a/compiler/luci/pass/include/luci/Pass/FuseMulToFullyConnectedWeightsPass.h b/compiler/luci/pass/include/luci/Pass/FuseMulToFullyConnectedWeightsPass.h new file mode 100644 index 00000000000..583f21ef82c --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseMulToFullyConnectedWeightsPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ +#define __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Mul into following FullyConnected + */ +struct FuseMulToFullyConnectedWeightsPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseMulToFullyConnectedWeightsPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 3c94311c8d5..09a830e10d4 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -40,6 +40,7 @@ #include "luci/Pass/FuseBatchNormWithDwConvPass.h" #include "luci/Pass/FuseBatchNormWithTConvPass.h" #include "luci/Pass/FuseBCQPass.h" +#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h" #include "luci/Pass/FuseInstanceNormPass.h" #include "luci/Pass/FuseMeanWithMeanPass.h" #include "luci/Pass/FuseMulWithConvPass.h" @@ -333,6 +334,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseMulToFullyConnectedWeights)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FusePRelu)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.cpp b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.cpp new file mode 100644 index 00000000000..1f41d16f05b --- /dev/null +++ b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.cpp @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h" + +#include +#include + +#include "helpers/NodeFiller.h" + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +namespace +{ + +/** + * Fuse Mul to following FullyConnected if possible + * + * BEFORE + * | + * [CircleMul] [CircleConst] [CircleConst] + * | | | + * [CircleFullyConnected] ----------+ + * | + * + * AFTER + * | + * | [CircleConst] [CircleConst] + * | | | + * | [CircleMul] [CircleConst] [CircleMul] + * | | | + * [CircleFullyConnected] ------------+ + * | + * + */ +bool fuse_fc_with_mul(luci::CircleFullyConnected *fc) +{ + CHECK_OR_FALSE(fc); + + // check input is Mul + auto mul = dynamic_cast(fc->input()); + CHECK_OR_FALSE(mul); + // conditions of Mul, FC: to expect constant folding, support only F32 + CHECK_OR_FALSE(mul->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(mul->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(fc->dtype() == loco::DataType::FLOAT32); + // support weight with constant + auto weights = dynamic_cast(fc->weights()); + CHECK_OR_FALSE(weights); + + // Check multiplication of Mul is constant + luci::CircleNode *mul_input = nullptr; + luci::CircleConst *mul_scale = nullptr; + CHECK_OR_FALSE(luci::fill(&mul_input, &mul_scale).with_commutative_args_of(mul)); + // support only 1D constant + CHECK_OR_FALSE(mul_scale->rank() == 1); + + auto graph = fc->graph(); + + auto fc_weights = graph->nodes()->create(); + fc_weights->x(weights); + fc_weights->y(mul_scale); + fc_weights->fusedActivationFunction(luci::FusedActFunc::NONE); + fc_weights->name(mul->name() + "_" + fc->name() + "_weight"); + luci::add_origin(fc_weights, + luci::composite_origin({luci::get_origin(mul), luci::get_origin(weights), + luci::get_origin(mul_scale)})); + + auto fc_new = graph->nodes()->create(); + fc_new->input(mul_input); + fc_new->weights(fc_weights); + fc_new->bias(fc->bias()); + fc_new->weights_format(fc->weights_format()); + fc_new->keep_num_dims(fc->keep_num_dims()); + fc_new->fusedActivationFunction(fc->fusedActivationFunction()); + fc_new->name(fc->name()); + luci::add_origin(fc_new, luci::get_origin(fc)); + + replace(fc).with(fc_new); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseMulToFullyConnectedWeightsPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto fc = dynamic_cast(node); + if (not fc) + continue; + + if (fuse_fc_with_mul(fc)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.test.cpp b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.test.cpp new file mode 100644 index 00000000000..2cb7a4e9f81 --- /dev/null +++ b/compiler/luci/pass/src/FuseMulToFullyConnectedWeightsPass.test.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +template class FuseMulToFullyConnectedWeightsPassTestGraph : public TestIOGraph +{ +public: + FuseMulToFullyConnectedWeightsPassTestGraph() = default; + + void init(void) + { + TestIOGraph::init({3, 4}, {3, 6}); + + _mul = g()->nodes()->create(); + _mul_s = g()->nodes()->create(); + _fc = g()->nodes()->create(); + _fc_w = g()->nodes()->create(); + _fc_b = g()->nodes()->create(); + + _mul->name("mul"); + _mul_s->name("mul_s"); + _fc->name("fc"); + _fc_w->name("fc_w"); + _fc_b->name("fc_b"); + + _mul->dtype(DT); + _fc->dtype(DT); + _mul->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + + _mul_s->rank(1); + _mul_s->dim(0) = 3; + _mul_s->dtype(DT); + _mul_s->size
(3); + for (uint32_t i = 0; i < 3; ++i) + { + _mul_s->at
(0) = 1.0f; + } + + _fc_w->rank(2); + _fc_w->dim(0) = 3; + _fc_w->dim(1) = 4; + _fc_w->dtype(DT); + _fc_w->size
(4 * 6); + for (uint32_t i = 0; i < 4 * 6; ++i) + { + _fc_w->at
(0) = 1.0f; + } + + _fc_b->rank(1); + _fc_b->dim(0) = 6; + _fc_b->dtype(DT); + _fc_b->size
(6); + for (uint32_t i = 0; i < 6; ++i) + { + _fc_b->at
(0) = 1.0f; + } + + _mul->x(input()); + _mul->y(_mul_s); + _fc->input(_mul); + _fc->weights(_fc_b); + _fc->bias(_fc_b); + + output()->from(_fc); + } + + luci::CircleMul *_mul = nullptr; + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleConst *_mul_s = nullptr; + luci::CircleConst *_fc_w = nullptr; + luci::CircleConst *_fc_b = nullptr; +}; + +class FuseMulToFullyConnectedWeightsPassTest : public ::testing::Test +{ +public: + FuseMulToFullyConnectedWeightsPassTest() = default; + +protected: + FuseMulToFullyConnectedWeightsPassTestGraph _graph; + luci::FuseMulToFullyConnectedWeightsPass _pass; +}; + +class FuseMulToFullyConnectedWeightsPassS32Test : public ::testing::Test +{ +public: + FuseMulToFullyConnectedWeightsPassS32Test() = default; + +protected: + FuseMulToFullyConnectedWeightsPassTestGraph _graph; + luci::FuseMulToFullyConnectedWeightsPass _pass; +}; + +} // namespace + +TEST_F(FuseMulToFullyConnectedWeightsPassTest, name) +{ + auto const name = _pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FuseMulToFullyConnectedWeightsPassTest, fuse_mul_to_fc_weights) +{ + _graph.init(); + + EXPECT_TRUE(_pass.run(_graph.g())); +} + +TEST_F(FuseMulToFullyConnectedWeightsPassTest, mul_fused_act_NEG) +{ + _graph.init(); + + _graph._mul->fusedActivationFunction(luci::FusedActFunc::RELU); + + EXPECT_FALSE(_pass.run(_graph.g())); +} + +TEST_F(FuseMulToFullyConnectedWeightsPassTest, mul_d2_NEG) +{ + _graph.init(); + + _graph._mul_s->rank(2); + _graph._mul_s->dim(0) = 1; + _graph._mul_s->dim(1) = 3; + + EXPECT_FALSE(_pass.run(_graph.g())); +} + +TEST_F(FuseMulToFullyConnectedWeightsPassS32Test, dtype_s32_NEG) +{ + _graph.init(); + + EXPECT_FALSE(_pass.run(_graph.g())); +}