-
Notifications
You must be signed in to change notification settings - Fork 159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[luci/pass] Introduce FuseMulToFullyConnectedWeightsPass #13439
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ | ||
#define __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ | ||
|
||
#include <logo/Pass.h> | ||
|
||
namespace luci | ||
{ | ||
|
||
/** | ||
* @brief Class to fuse Mul into following FullyConnected | ||
*/ | ||
struct FuseMulToFullyConnectedWeightsPass final : public logo::Pass | ||
{ | ||
const char *name(void) const final { return "luci::FuseMulToFullyConnectedWeightsPass"; } | ||
|
||
bool run(loco::Graph *g) final; | ||
}; | ||
|
||
} // namespace luci | ||
|
||
#endif // __LUCI_FUSE_MUL_TO_FULLY_CONNECTED_WEIGHTS_PASS_H__ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
#include <luci/Profile/CircleNodeOrigin.h> | ||
|
||
#include "helpers/NodeFiller.h" | ||
|
||
#define CHECK_OR_FALSE(condition) \ | ||
if (not(condition)) \ | ||
return false; | ||
|
||
namespace | ||
{ | ||
|
||
/** | ||
* Fuse Mul to following FullyConnected if possible | ||
* | ||
* BEFORE | ||
* | | ||
* [CircleMul] [CircleConst] [CircleConst] | ||
* | | | | ||
* [CircleFullyConnected] ----------+ | ||
* | | ||
* | ||
* AFTER | ||
* | | ||
* | [CircleConst] [CircleConst] | ||
* | | | | ||
* | [CircleMul] [CircleConst] [CircleMul] | ||
* | | | | ||
* [CircleFullyConnected] ------------+ | ||
* | | ||
* | ||
*/ | ||
bool fuse_fc_with_mul(luci::CircleFullyConnected *fc) | ||
{ | ||
CHECK_OR_FALSE(fc); | ||
|
||
// check input is Mul | ||
auto mul = dynamic_cast<luci::CircleMul *>(fc->input()); | ||
CHECK_OR_FALSE(mul); | ||
// conditions of Mul, FC: to expect constant folding, support only F32 | ||
CHECK_OR_FALSE(mul->dtype() == loco::DataType::FLOAT32); | ||
CHECK_OR_FALSE(mul->fusedActivationFunction() == luci::FusedActFunc::NONE); | ||
CHECK_OR_FALSE(fc->dtype() == loco::DataType::FLOAT32); | ||
// support weight with constant | ||
auto weights = dynamic_cast<luci::CircleConst *>(fc->weights()); | ||
CHECK_OR_FALSE(weights); | ||
|
||
// Check multiplication of Mul is constant | ||
luci::CircleNode *mul_input = nullptr; | ||
luci::CircleConst *mul_scale = nullptr; | ||
CHECK_OR_FALSE(luci::fill(&mul_input, &mul_scale).with_commutative_args_of(mul)); | ||
// support only 1D constant | ||
CHECK_OR_FALSE(mul_scale->rank() == 1); | ||
|
||
auto graph = fc->graph(); | ||
|
||
auto fc_weights = graph->nodes()->create<luci::CircleMul>(); | ||
fc_weights->x(weights); | ||
fc_weights->y(mul_scale); | ||
fc_weights->fusedActivationFunction(luci::FusedActFunc::NONE); | ||
fc_weights->name(mul->name() + "_" + fc->name() + "_weight"); | ||
luci::add_origin(fc_weights, | ||
luci::composite_origin({luci::get_origin(mul), luci::get_origin(weights), | ||
luci::get_origin(mul_scale)})); | ||
|
||
auto fc_new = graph->nodes()->create<luci::CircleFullyConnected>(); | ||
fc_new->input(mul_input); | ||
fc_new->weights(fc_weights); | ||
fc_new->bias(fc->bias()); | ||
fc_new->weights_format(fc->weights_format()); | ||
fc_new->keep_num_dims(fc->keep_num_dims()); | ||
fc_new->fusedActivationFunction(fc->fusedActivationFunction()); | ||
fc_new->name(fc->name()); | ||
luci::add_origin(fc_new, luci::get_origin(fc)); | ||
|
||
replace(fc).with(fc_new); | ||
|
||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
namespace luci | ||
{ | ||
|
||
bool FuseMulToFullyConnectedWeightsPass::run(loco::Graph *g) | ||
{ | ||
bool changed = false; | ||
for (auto node : loco::active_nodes(loco::output_nodes(g))) | ||
{ | ||
auto fc = dynamic_cast<luci::CircleFullyConnected *>(node); | ||
if (not fc) | ||
continue; | ||
|
||
if (fuse_fc_with_mul(fc)) | ||
changed = true; | ||
} | ||
|
||
return changed; | ||
} | ||
|
||
} // namespace luci |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,160 @@ | ||||||||||||||||||||||||||||||||||||||||||
/* | ||||||||||||||||||||||||||||||||||||||||||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||||||||||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||||||||||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||
* you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||
* You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||||||||||
* http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||||||||||
* Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||
* distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||
* See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||
* limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
#include "luci/Pass/FuseMulToFullyConnectedWeightsPass.h" | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
#include <luci/IR/CircleNodes.h> | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
#include <luci/test/TestIOGraph.h> | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
#include <gtest/gtest.h> | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
namespace | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
using namespace luci::test; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
template <loco::DataType DT> class FuseMulToFullyConnectedWeightsPassTestGraph : public TestIOGraph | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
public: | ||||||||||||||||||||||||||||||||||||||||||
FuseMulToFullyConnectedWeightsPassTestGraph() = default; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
void init(void) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
TestIOGraph::init({3, 4}, {3, 6}); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_mul = g()->nodes()->create<luci::CircleMul>(); | ||||||||||||||||||||||||||||||||||||||||||
_mul_s = g()->nodes()->create<luci::CircleConst>(); | ||||||||||||||||||||||||||||||||||||||||||
_fc = g()->nodes()->create<luci::CircleFullyConnected>(); | ||||||||||||||||||||||||||||||||||||||||||
_fc_w = g()->nodes()->create<luci::CircleConst>(); | ||||||||||||||||||||||||||||||||||||||||||
_fc_b = g()->nodes()->create<luci::CircleConst>(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_mul->name("mul"); | ||||||||||||||||||||||||||||||||||||||||||
_mul_s->name("mul_s"); | ||||||||||||||||||||||||||||||||||||||||||
_fc->name("fc"); | ||||||||||||||||||||||||||||||||||||||||||
_fc_w->name("fc_w"); | ||||||||||||||||||||||||||||||||||||||||||
_fc_b->name("fc_b"); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_mul->dtype(DT); | ||||||||||||||||||||||||||||||||||||||||||
_fc->dtype(DT); | ||||||||||||||||||||||||||||||||||||||||||
_mul->fusedActivationFunction(luci::FusedActFunc::NONE); | ||||||||||||||||||||||||||||||||||||||||||
_fc->fusedActivationFunction(luci::FusedActFunc::NONE); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_mul_s->rank(1); | ||||||||||||||||||||||||||||||||||||||||||
_mul_s->dim(0) = 3; | ||||||||||||||||||||||||||||||||||||||||||
_mul_s->dtype(DT); | ||||||||||||||||||||||||||||||||||||||||||
_mul_s->size<DT>(3); | ||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < 3; ++i) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_mul_s->at<DT>(0) = 1.0f; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_fc_w->rank(2); | ||||||||||||||||||||||||||||||||||||||||||
_fc_w->dim(0) = 3; | ||||||||||||||||||||||||||||||||||||||||||
_fc_w->dim(1) = 4; | ||||||||||||||||||||||||||||||||||||||||||
_fc_w->dtype(DT); | ||||||||||||||||||||||||||||||||||||||||||
_fc_w->size<DT>(4 * 6); | ||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < 4 * 6; ++i) | ||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+66
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The shape of
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please check about FullyConnected Op. with torch, import torch
tensor1 = torch.randn(3, 4)
FC = torch.nn.Linear(4, 6)
output = FC(tensor1)
print(output)
print(output.shape) gives something like
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. more simple case is with vector, tensor1 = torch.randn(4)
FC = torch.nn.Linear(4, 6)
output = FC(tensor1)
print(output)
print(output.shape) gives
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, should be
|
||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_fc_w->at<DT>(0) = 1.0f; | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto..? |
||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_fc_b->rank(1); | ||||||||||||||||||||||||||||||||||||||||||
_fc_b->dim(0) = 6; | ||||||||||||||||||||||||||||||||||||||||||
_fc_b->dtype(DT); | ||||||||||||||||||||||||||||||||||||||||||
_fc_b->size<DT>(6); | ||||||||||||||||||||||||||||||||||||||||||
for (uint32_t i = 0; i < 6; ++i) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_fc_b->at<DT>(0) = 1.0f; | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. diditto..? |
||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_mul->x(input()); | ||||||||||||||||||||||||||||||||||||||||||
_mul->y(_mul_s); | ||||||||||||||||||||||||||||||||||||||||||
_fc->input(_mul); | ||||||||||||||||||||||||||||||||||||||||||
_fc->weights(_fc_b); | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should it be |
||||||||||||||||||||||||||||||||||||||||||
_fc->bias(_fc_b); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
output()->from(_fc); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
luci::CircleMul *_mul = nullptr; | ||||||||||||||||||||||||||||||||||||||||||
luci::CircleFullyConnected *_fc = nullptr; | ||||||||||||||||||||||||||||||||||||||||||
luci::CircleConst *_mul_s = nullptr; | ||||||||||||||||||||||||||||||||||||||||||
luci::CircleConst *_fc_w = nullptr; | ||||||||||||||||||||||||||||||||||||||||||
luci::CircleConst *_fc_b = nullptr; | ||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
class FuseMulToFullyConnectedWeightsPassTest : public ::testing::Test | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
public: | ||||||||||||||||||||||||||||||||||||||||||
FuseMulToFullyConnectedWeightsPassTest() = default; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
protected: | ||||||||||||||||||||||||||||||||||||||||||
FuseMulToFullyConnectedWeightsPassTestGraph<loco::DataType::FLOAT32> _graph; | ||||||||||||||||||||||||||||||||||||||||||
luci::FuseMulToFullyConnectedWeightsPass _pass; | ||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
class FuseMulToFullyConnectedWeightsPassS32Test : public ::testing::Test | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
public: | ||||||||||||||||||||||||||||||||||||||||||
FuseMulToFullyConnectedWeightsPassS32Test() = default; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
protected: | ||||||||||||||||||||||||||||||||||||||||||
FuseMulToFullyConnectedWeightsPassTestGraph<loco::DataType::S32> _graph; | ||||||||||||||||||||||||||||||||||||||||||
luci::FuseMulToFullyConnectedWeightsPass _pass; | ||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
} // namespace | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
TEST_F(FuseMulToFullyConnectedWeightsPassTest, name) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
auto const name = _pass.name(); | ||||||||||||||||||||||||||||||||||||||||||
ASSERT_NE(nullptr, name); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
TEST_F(FuseMulToFullyConnectedWeightsPassTest, fuse_mul_to_fc_weights) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_graph.init(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
EXPECT_TRUE(_pass.run(_graph.g())); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
TEST_F(FuseMulToFullyConnectedWeightsPassTest, mul_fused_act_NEG) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_graph.init(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_graph._mul->fusedActivationFunction(luci::FusedActFunc::RELU); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
EXPECT_FALSE(_pass.run(_graph.g())); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
TEST_F(FuseMulToFullyConnectedWeightsPassTest, mul_d2_NEG) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_graph.init(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
_graph._mul_s->rank(2); | ||||||||||||||||||||||||||||||||||||||||||
_graph._mul_s->dim(0) = 1; | ||||||||||||||||||||||||||||||||||||||||||
_graph._mul_s->dim(1) = 3; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
EXPECT_FALSE(_pass.run(_graph.g())); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
TEST_F(FuseMulToFullyConnectedWeightsPassS32Test, dtype_s32_NEG) | ||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||
_graph.init(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
EXPECT_FALSE(_pass.run(_graph.g())); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+155
to
+160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test seems same with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, it was |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i
index is not used infor
statement..!?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you intend below code?