Skip to content

Commit

Permalink
[draft] Constant folding for Shape op
Browse files Browse the repository at this point in the history
This draft folds the Shape op as a constant.

Signed-off-by: jihunnn-kim <[email protected]>
  • Loading branch information
miusic committed Jan 10, 2024
1 parent 6bed623 commit ff6d25c
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 0 deletions.
3 changes: 3 additions & 0 deletions compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ int entry(int argc, char **argv)
add_switch(arser, "--fold_fully_connected",
"This will fold FullyConnected operator with constant inputs");
add_switch(arser, "--fold_gather", "This will fold Gather operator");
add_switch(arser, "--fold_shape", "This will fold Shape operator");
add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator");
add_switch(arser, "--forward_reshape_to_unaryop",
"This will move Reshape after UnaryOp for centain condition");
Expand Down Expand Up @@ -263,6 +264,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FoldFullyConnected);
if (arser.get<bool>("--fold_gather"))
options->enable(Algorithms::FoldGather);
if (arser.get<bool>("--fold_shape"))
options->enable(Algorithms::FoldShape);
if (arser.get<bool>("--fold_sparse_to_dense"))
options->enable(Algorithms::FoldSparseToDense);
if (arser.get<bool>("--forward_reshape_to_unaryop"))
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/pass/include/luci/CircleOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class CircleOptimizer final
FoldFullyConnected,
FoldDequantize,
FoldGather,
FoldShape,
FoldSparseToDense,
ForwardReshapeToUnaryOp,
ForwardTransposeOp,
Expand Down
38 changes: 38 additions & 0 deletions compiler/luci/pass/include/luci/Pass/FoldShapePass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_FOLD_SHAPE_PASS_H__
#define __LUCI_FOLD_SHAPE_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to fold Shape to a constant tensor
*
*/
struct FoldShapePass final : public logo::Pass
{
const char *name(void) const final { return "luci::FoldShapePass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FOLD_SHAPE_PASS_H__
5 changes: 5 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "luci/Pass/FoldDequantizePass.h"
#include "luci/Pass/FoldFullyConnectedPass.h"
#include "luci/Pass/FoldGatherPass.h"
#include "luci/Pass/FoldShapePass.h"
#include "luci/Pass/FoldSparseToDensePass.h"
#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
#include "luci/Pass/ForwardTransposeOpPass.h"
Expand Down Expand Up @@ -365,6 +366,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
}
if (_options->query(Options::Algorithm::FoldShape))
{
phase.emplace_back(std::make_unique<luci::FoldShapePass>());
}
if (_options->query(Options::Algorithm::FoldSparseToDense))
{
phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
Expand Down
84 changes: 84 additions & 0 deletions compiler/luci/pass/src/FoldShapePass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FoldShapePass.h"

#include <luci/IR/CircleNodes.h>

namespace
{

luci::CircleConst *shape_of(luci::CircleNode *input_tensor, loco::DataType out_dtype)
{
auto name = input_tensor->name();
auto shape_status = input_tensor->shape_status();
auto rank = input_tensor->rank();
assert(name.length() > 0);
assert(shape_status == luci::ShapeStatus::VALID);
assert(rank > 0);

auto shape = input_tensor->graph()->nodes()->create<luci::CircleConst>();
shape->name(name + "_Folded");
shape->dtype(out_dtype);
shape->rank(1);
shape->dim(0).set(rank);

return shape;
}

/**
* Fold Shape to const if the input shape is fixed
**/
bool fold_shape(luci::CircleShape *shape)
{
auto input_tensor = dynamic_cast<luci::CircleNode *>(shape->input());
auto out_dtype = shape->out_type();

auto tensor_shape = shape_of(input_tensor, out_dtype);
if (not tensor_shape)
return false;

loco::replace(shape).with(tensor_shape);

return true;
}

} // namespace

namespace luci
{

/**
* Constant Folding for Shape Op
**/
bool FoldShapePass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
if (auto shape = dynamic_cast<luci::CircleShape *>(node))
{
if (fold_shape(shape))
{
changed = true;
}
}
}

return changed;
}

} // namespace luci
106 changes: 106 additions & 0 deletions compiler/luci/pass/src/FoldShapePass.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FoldShapePass.h"
#include "PassTestGraphs.h"

#include <luci/IR/CircleNodes.h>

#include <gtest/gtest.h>

namespace
{

template <loco::DataType FromT, loco::DataType ToT>
class FoldShapeGraph : public luci::ConstantFoldingAddTestGraph
{
public:
FoldShapeGraph(std::vector<uint32_t> input_shape)
: luci::ConstantFoldingAddTestGraph(input_shape, ToT)
{
_x = _g.nodes()->template create<luci::CircleConst>();
_x->name("x");
_x->dtype(FromT);
_x->rank(input_shape.size());
_x->shape_status(luci::ShapeStatus::VALID);

_shape = _g.nodes()->template create<luci::CircleShape>();
_shape->name("shape");
_shape->out_type(ToT);
_shape->input(_x);
_shape->shape({4});
_shape->rank(1);
for (int i = 0; i < _shape->rank(); ++i)
{
_shape->dim(i).set(_x->dim(i).value());
}
}

loco::Node *createFoldedPattern() override { return _shape; }

protected:
luci::CircleShape *_shape = nullptr;
luci::CircleConst *_x = nullptr;
};

/**
* Graph that has a Shape Op
*
* BEFORE
*
* [Input]
* |
* [Shape]
* |
* [Output]
*
* AFTER
*
* [CircleConst]
*
*/
class FoldShapeGraphTest : public FoldShapeGraph<loco::DataType::S64, loco::DataType::S32>,
public ::testing::Test
{
public:
FoldShapeGraphTest() : FoldShapeGraph<loco::DataType::S64, loco::DataType::S32>({1, 8, 8, 64}) {}

virtual void SetUp() { init(); }
};

} // namespace

TEST(FoldShapePassTest, name)
{
luci::FoldShapePass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

TEST_F(FoldShapeGraphTest, fold_shape)
{
luci::FoldShapePass pass;
while (pass.run(graph()))
;

auto folded_const = getFoldedPattern();
EXPECT_NE(nullptr, folded_const);

// Check type, shape, values of folded shape
EXPECT_EQ(loco::DataType::S32, folded_const->dtype());
EXPECT_EQ(1, folded_const->rank());
EXPECT_EQ(1, folded_const->dim(0).value());
}
1 change: 1 addition & 0 deletions compiler/one-cmds/how-to-use-one-commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ Current transformation options are
- fold_dequantize : This removes Dequantize operation which can be folded
- fold_dwconv : This folds Depthwise Convolution operation which can be folded
- fold_gather : This removes Gather operation which can be folded
- fold_shape : This removes Shape operation which can be folded
- fold_sparse_to_dense : This removes SparseToDense operation which can be folded
- forward_reshape_to_unaryop: This will move Reshape after UnaryOp for centain condition
- fuse_add_with_fully_connected: This fuses Add operator with the preceding FullyConnected operator if possible
Expand Down
1 change: 1 addition & 0 deletions compiler/one-cmds/onelib/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class CONSTANT:
'fold_dwconv',
'fold_fully_connected',
'fold_gather',
'fold_shape',
'fold_sparse_to_dense',

# Operator fusion
Expand Down

0 comments on commit ff6d25c

Please sign in to comment.