-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[luci/pass] Introduce FoldMulPass (#13440)
This will introduce FoldMulPass which will fold Mul to Constant if possible. ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
- Loading branch information
1 parent
7de74b8
commit ab68724
Showing
5 changed files
with
304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __LUCI_FOLD_MUL_PASS_H__ | ||
#define __LUCI_FOLD_MUL_PASS_H__ | ||
|
||
#include <logo/Pass.h> | ||
|
||
namespace luci | ||
{ | ||
|
||
/** | ||
* @brief Class to fold Mul to a constant tensor | ||
* | ||
*/ | ||
struct FoldMulPass final : public logo::Pass | ||
{ | ||
const char *name(void) const final { return "luci::FoldMulPass"; } | ||
|
||
bool run(loco::Graph *g) final; | ||
}; | ||
|
||
} // namespace luci | ||
|
||
#endif // __LUCI_FOLD_MUL_PASS_H__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/FoldMulPass.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
|
||
#include <algorithm> | ||
|
||
#define CHECK_OR_FALSE(condition) \ | ||
if (not(condition)) \ | ||
return false; | ||
|
||
namespace | ||
{ | ||
|
||
/** | ||
* @return higher rank of x, y or nullptr if not compatible | ||
*/ | ||
const luci::CircleConst *compatible_shape(const luci::CircleConst *x, const luci::CircleConst *y) | ||
{ | ||
if (x->rank() >= y->rank()) | ||
{ | ||
uint32_t d = x->rank() - y->rank(); | ||
for (uint32_t i = 0; i < y->rank(); i++) | ||
{ | ||
// NOTE dim() has only '==' operator | ||
if (!(x->dim(i + d) == y->dim(i))) | ||
return nullptr; | ||
} | ||
return x; | ||
} | ||
else | ||
{ | ||
uint32_t d = y->rank() - x->rank(); | ||
for (uint32_t i = 0; i < x->rank(); i++) | ||
{ | ||
if (!(x->dim(i) == y->dim(i + d))) | ||
return nullptr; | ||
} | ||
return y; | ||
} | ||
} | ||
|
||
/** | ||
* Fold Mul to const if both inputs are const | ||
**/ | ||
bool fold_mul(luci::CircleMul *mul) | ||
{ | ||
CHECK_OR_FALSE(mul); | ||
CHECK_OR_FALSE(mul->dtype() == loco::DataType::FLOAT32); | ||
|
||
// Check inputs are const and compatible | ||
auto x = dynamic_cast<luci::CircleConst *>(mul->x()); | ||
auto y = dynamic_cast<luci::CircleConst *>(mul->y()); | ||
CHECK_OR_FALSE(x); | ||
CHECK_OR_FALSE(y); | ||
CHECK_OR_FALSE(x->dtype() == y->dtype()); | ||
const auto xy = compatible_shape(x, y); | ||
CHECK_OR_FALSE(xy); | ||
|
||
auto name_x = x->name(); | ||
auto name_y = y->name(); | ||
assert(name_x.length() > 0); | ||
assert(name_y.length() > 0); | ||
auto folded_const = mul->graph()->nodes()->create<luci::CircleConst>(); | ||
folded_const->dtype(xy->dtype()); | ||
folded_const->rank(xy->rank()); | ||
for (uint32_t i = 0; i < xy->rank(); i++) | ||
folded_const->dim(i).set(xy->dim(i).value()); | ||
|
||
const auto size_x = x->size<loco::DataType::FLOAT32>(); | ||
const auto size_y = y->size<loco::DataType::FLOAT32>(); | ||
const auto size_xy = xy->size<loco::DataType::FLOAT32>(); | ||
folded_const->size<loco::DataType::FLOAT32>(size_xy); | ||
for (uint32_t i = 0; i < size_xy; i++) | ||
{ | ||
auto xv = x->at<loco::DataType::FLOAT32>(i % size_x); | ||
auto yv = y->at<loco::DataType::FLOAT32>(i % size_y); | ||
folded_const->at<loco::DataType::FLOAT32>(i) = xv * yv; | ||
} | ||
|
||
folded_const->shape_status(luci::ShapeStatus::VALID); | ||
folded_const->name(name_x + "_" + name_y); | ||
|
||
loco::replace(mul).with(folded_const); | ||
|
||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
namespace luci | ||
{ | ||
|
||
/** | ||
* Constant Folding for Mul Op | ||
**/ | ||
bool FoldMulPass::run(loco::Graph *g) | ||
{ | ||
bool changed = false; | ||
for (auto node : loco::active_nodes(loco::output_nodes(g))) | ||
{ | ||
if (auto mul = dynamic_cast<luci::CircleMul *>(node)) | ||
{ | ||
if (fold_mul(mul)) | ||
changed = true; | ||
} | ||
} | ||
|
||
return changed; | ||
} | ||
|
||
} // namespace luci |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/FoldMulPass.h" | ||
#include "PassTestGraphs.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace | ||
{ | ||
|
||
/** | ||
* Graph has an Mul Op with constant inputs | ||
* | ||
* BEFORE | ||
* | ||
* [CircleConst] [CircleConst] | ||
* | | | ||
* [CircleMul] | ||
* | | ||
* [CircleNode] | ||
* AFTER | ||
* [CircleConst] [CircleConst] | ||
* | | | ||
* [CircleConst] [CircleMul] | ||
* | | ||
* [CircleNode] | ||
*/ | ||
|
||
template <loco::DataType T> class FoldMulTest : public luci::ConstantFoldingAddTestGraph | ||
{ | ||
public: | ||
FoldMulTest(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T) | ||
{ | ||
_mul = _g.nodes()->template create<luci::CircleMul>(); | ||
_x = _g.nodes()->template create<luci::CircleConst>(); | ||
_y = _g.nodes()->template create<luci::CircleConst>(); | ||
|
||
_mul->dtype(T); | ||
_x->dtype(T); | ||
_y->dtype(T); | ||
|
||
_mul->shape(shape); | ||
_x->shape(shape); | ||
_y->shape(shape); | ||
|
||
uint32_t num_elems = 1; | ||
for (auto dim = shape.begin(); dim != shape.end(); dim++) | ||
num_elems *= *dim; | ||
|
||
_x->size<T>(num_elems); | ||
_y->size<T>(num_elems); | ||
|
||
for (uint32_t i = 0; i < num_elems; i++) | ||
{ | ||
_x->at<T>(i) = i + 1; | ||
_y->at<T>(i) = i + 1; | ||
} | ||
|
||
_mul->x(_x); | ||
_mul->y(_y); | ||
_mul->name("mul"); | ||
_x->name("x"); | ||
_y->name("y"); | ||
} | ||
|
||
loco::Node *createFoldedPattern() override { return _mul; } | ||
|
||
virtual ~FoldMulTest() = default; | ||
|
||
protected: | ||
luci::CircleMul *_mul = nullptr; | ||
luci::CircleConst *_x = nullptr; | ||
luci::CircleConst *_y = nullptr; | ||
}; | ||
|
||
class FoldF32MulTest : public FoldMulTest<loco::DataType::FLOAT32>, public ::testing::Test | ||
{ | ||
public: | ||
FoldF32MulTest() : FoldMulTest<loco::DataType::FLOAT32>({3}) {} | ||
|
||
virtual void SetUp() { init(); } | ||
}; | ||
|
||
} // namespace | ||
|
||
TEST_F(FoldF32MulTest, name) | ||
{ | ||
luci::FoldMulPass pass; | ||
auto const name = pass.name(); | ||
ASSERT_NE(nullptr, name); | ||
} | ||
|
||
TEST_F(FoldF32MulTest, fold_mul) | ||
{ | ||
luci::FoldMulPass pass; | ||
while (pass.run(graph())) | ||
; | ||
|
||
auto folded_const = getFoldedPattern(); | ||
EXPECT_NE(nullptr, folded_const); | ||
|
||
// Check type, shape, values of folded const | ||
EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); | ||
EXPECT_EQ(1, folded_const->rank()); | ||
EXPECT_EQ(3, folded_const->dim(0).value()); | ||
EXPECT_EQ(1, folded_const->at<loco::DataType::FLOAT32>(0)); | ||
EXPECT_EQ(4, folded_const->at<loco::DataType::FLOAT32>(1)); | ||
EXPECT_EQ(9, folded_const->at<loco::DataType::FLOAT32>(2)); | ||
} | ||
|
||
TEST_F(FoldF32MulTest, input_type_mismatch_NEG) | ||
{ | ||
_x->dtype(loco::DataType::U4); | ||
|
||
luci::FoldMulPass pass; | ||
EXPECT_FALSE(pass.run(graph())); | ||
} |