diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index 2775aafe7d2..6699b60b504 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -81,6 +81,7 @@ int entry(int argc, char **argv) add_switch(arser, "--fold_fully_connected", "This will fold FullyConnected operator with constant inputs"); add_switch(arser, "--fold_gather", "This will fold Gather operator"); + add_switch(arser, "--fold_shape", "This will fold Shape operator"); add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator"); add_switch(arser, "--forward_reshape_to_unaryop", "This will move Reshape after UnaryOp for centain condition"); @@ -263,6 +264,8 @@ int entry(int argc, char **argv) options->enable(Algorithms::FoldFullyConnected); if (arser.get("--fold_gather")) options->enable(Algorithms::FoldGather); + if (arser.get("--fold_shape")) + options->enable(Algorithms::FoldShape); if (arser.get("--fold_sparse_to_dense")) options->enable(Algorithms::FoldSparseToDense); if (arser.get("--forward_reshape_to_unaryop")) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index b6285812225..a017ad8b49a 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -60,6 +60,7 @@ class CircleOptimizer final FoldFullyConnected, FoldDequantize, FoldGather, + FoldShape, FoldSparseToDense, ForwardReshapeToUnaryOp, ForwardTransposeOp, diff --git a/compiler/luci/pass/include/luci/Pass/FoldShapePass.h b/compiler/luci/pass/include/luci/Pass/FoldShapePass.h new file mode 100644 index 00000000000..89ed410dc07 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldShapePass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_SHAPE_PASS_H__ +#define __LUCI_FOLD_SHAPE_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fold Shape to a constant tensor + * + */ +struct FoldShapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldShapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_SHAPE_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 53983fcabe0..4aedbc3a683 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -26,6 +26,7 @@ #include "luci/Pass/FoldDequantizePass.h" #include "luci/Pass/FoldFullyConnectedPass.h" #include "luci/Pass/FoldGatherPass.h" +#include "luci/Pass/FoldShapePass.h" #include "luci/Pass/FoldSparseToDensePass.h" #include "luci/Pass/ForwardReshapeToUnaryOpPass.h" #include "luci/Pass/ForwardTransposeOpPass.h" @@ -365,6 +366,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FoldShape)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FoldSparseToDense)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FoldShapePass.cpp b/compiler/luci/pass/src/FoldShapePass.cpp new file mode 100644 index 00000000000..ab0e840a744 --- /dev/null +++ b/compiler/luci/pass/src/FoldShapePass.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldShapePass.h" + +#include + +namespace +{ + +luci::CircleConst *shape_of(luci::CircleNode *input_tensor, loco::DataType out_dtype) +{ + auto name = input_tensor->name(); + auto shape_status = input_tensor->shape_status(); + auto rank = input_tensor->rank(); + assert(name.length() > 0); + assert(shape_status == luci::ShapeStatus::VALID); + assert(rank > 0); + + auto shape = input_tensor->graph()->nodes()->create(); + shape->name(name + "_Folded"); + shape->dtype(out_dtype); + shape->rank(1); + shape->dim(0).set(rank); + + return shape; +} + +/** + * Fold Shape to const if the input shape is fixed + **/ +bool fold_shape(luci::CircleShape *shape) +{ + auto input_tensor = dynamic_cast(shape->input()); + auto out_dtype = shape->out_type(); + + auto tensor_shape = shape_of(input_tensor, out_dtype); + if (not tensor_shape) + return false; + + loco::replace(shape).with(tensor_shape); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for Shape Op + **/ +bool FoldShapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto shape = dynamic_cast(node)) + { + if (fold_shape(shape)) + { + changed = true; + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldShapePass.test.cpp b/compiler/luci/pass/src/FoldShapePass.test.cpp new file mode 100644 index 00000000000..77e15d70d6c --- /dev/null +++ b/compiler/luci/pass/src/FoldShapePass.test.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldShapePass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +template +class FoldShapeGraph : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldShapeGraph(std::vector input_shape) + : luci::ConstantFoldingAddTestGraph(input_shape, ToT) + { + _x = _g.nodes()->template create(); + _x->name("x"); + _x->dtype(FromT); + _x->rank(input_shape.size()); + _x->shape_status(luci::ShapeStatus::VALID); + + _shape = _g.nodes()->template create(); + _shape->name("shape"); + _shape->out_type(ToT); + _shape->input(_x); + _shape->shape({4}); + _shape->rank(1); + for (int i = 0; i < _shape->rank(); ++i) + { + _shape->dim(i).set(_x->dim(i).value()); + } + } + + loco::Node *createFoldedPattern() override { return _shape; } + +protected: + luci::CircleShape *_shape = nullptr; + luci::CircleConst *_x = nullptr; +}; + +/** + * Graph that has a Shape Op + * + * BEFORE + * + * [Input] + * | + * [Shape] + * | + * [Output] + * + * AFTER + * + * [CircleConst] + * + */ +class FoldShapeGraphTest : public FoldShapeGraph, + public ::testing::Test +{ +public: + FoldShapeGraphTest() : FoldShapeGraph({1, 8, 8, 64}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST(FoldShapePassTest, name) +{ + luci::FoldShapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FoldShapeGraphTest, fold_shape) +{ + luci::FoldShapePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded shape + EXPECT_EQ(loco::DataType::S32, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(1, folded_const->dim(0).value()); +} diff --git a/compiler/one-cmds/how-to-use-one-commands.txt b/compiler/one-cmds/how-to-use-one-commands.txt index 7683d06dc52..343ec59064f 100644 --- a/compiler/one-cmds/how-to-use-one-commands.txt +++ b/compiler/one-cmds/how-to-use-one-commands.txt @@ -160,6 +160,7 @@ Current transformation options are - fold_dequantize : This removes Dequantize operation which can be folded - fold_dwconv : This folds Depthwise Convolution operation which can be folded - fold_gather : This removes Gather operation which can be folded +- fold_shape : This removes Shape operation which can be folded - fold_sparse_to_dense : This removes SparseToDense operation which can be folded - forward_reshape_to_unaryop: This will move Reshape after UnaryOp for centain condition - fuse_add_with_fully_connected: This fuses Add operator with the preceding FullyConnected operator if possible diff --git a/compiler/one-cmds/onelib/constant.py b/compiler/one-cmds/onelib/constant.py index d540ecc79ec..bfac5ccf6d9 100644 --- a/compiler/one-cmds/onelib/constant.py +++ b/compiler/one-cmds/onelib/constant.py @@ -29,6 +29,7 @@ class CONSTANT: 'fold_dwconv', 'fold_fully_connected', 'fold_gather', + 'fold_shape', 'fold_sparse_to_dense', # Operator fusion