Skip to content

Commit

Permalink
[luci/pass] Introduce FoldMulPass (#13440)
Browse files Browse the repository at this point in the history
This will introduce FoldMulPass which will fold Mul to Constant if possible.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark authored Jul 16, 2024
1 parent 7de74b8 commit ab68724
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/pass/include/luci/CircleOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class CircleOptimizer final
FoldFullyConnected,
FoldDequantize,
FoldGather,
FoldMul,
FoldReshape,
FoldShape,
FoldSparseToDense,
Expand Down
38 changes: 38 additions & 0 deletions compiler/luci/pass/include/luci/Pass/FoldMulPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_FOLD_MUL_PASS_H__
#define __LUCI_FOLD_MUL_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to fold Mul to a constant tensor
*
*/
struct FoldMulPass final : public logo::Pass
{
const char *name(void) const final { return "luci::FoldMulPass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FOLD_MUL_PASS_H__
5 changes: 5 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "luci/Pass/FoldDequantizePass.h"
#include "luci/Pass/FoldFullyConnectedPass.h"
#include "luci/Pass/FoldGatherPass.h"
#include "luci/Pass/FoldMulPass.h"
#include "luci/Pass/FoldReshapePass.h"
#include "luci/Pass/FoldShapePass.h"
#include "luci/Pass/FoldSparseToDensePass.h"
Expand Down Expand Up @@ -391,6 +392,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
}
if (_options->query(Options::Algorithm::FoldMul))
{
phase.emplace_back(std::make_unique<luci::FoldMulPass>());
}
if (_options->query(Options::Algorithm::FoldReshape))
{
phase.emplace_back(std::make_unique<luci::FoldReshapePass>());
Expand Down
127 changes: 127 additions & 0 deletions compiler/luci/pass/src/FoldMulPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FoldMulPass.h"

#include <luci/IR/CircleNodes.h>

#include <algorithm>

#define CHECK_OR_FALSE(condition) \
if (not(condition)) \
return false;

namespace
{

/**
* @return higher rank of x, y or nullptr if not compatible
*/
const luci::CircleConst *compatible_shape(const luci::CircleConst *x, const luci::CircleConst *y)
{
if (x->rank() >= y->rank())
{
uint32_t d = x->rank() - y->rank();
for (uint32_t i = 0; i < y->rank(); i++)
{
// NOTE dim() has only '==' operator
if (!(x->dim(i + d) == y->dim(i)))
return nullptr;
}
return x;
}
else
{
uint32_t d = y->rank() - x->rank();
for (uint32_t i = 0; i < x->rank(); i++)
{
if (!(x->dim(i) == y->dim(i + d)))
return nullptr;
}
return y;
}
}

/**
* Fold Mul to const if both inputs are const
**/
bool fold_mul(luci::CircleMul *mul)
{
CHECK_OR_FALSE(mul);
CHECK_OR_FALSE(mul->dtype() == loco::DataType::FLOAT32);

// Check inputs are const and compatible
auto x = dynamic_cast<luci::CircleConst *>(mul->x());
auto y = dynamic_cast<luci::CircleConst *>(mul->y());
CHECK_OR_FALSE(x);
CHECK_OR_FALSE(y);
CHECK_OR_FALSE(x->dtype() == y->dtype());
const auto xy = compatible_shape(x, y);
CHECK_OR_FALSE(xy);

auto name_x = x->name();
auto name_y = y->name();
assert(name_x.length() > 0);
assert(name_y.length() > 0);
auto folded_const = mul->graph()->nodes()->create<luci::CircleConst>();
folded_const->dtype(xy->dtype());
folded_const->rank(xy->rank());
for (uint32_t i = 0; i < xy->rank(); i++)
folded_const->dim(i).set(xy->dim(i).value());

const auto size_x = x->size<loco::DataType::FLOAT32>();
const auto size_y = y->size<loco::DataType::FLOAT32>();
const auto size_xy = xy->size<loco::DataType::FLOAT32>();
folded_const->size<loco::DataType::FLOAT32>(size_xy);
for (uint32_t i = 0; i < size_xy; i++)
{
auto xv = x->at<loco::DataType::FLOAT32>(i % size_x);
auto yv = y->at<loco::DataType::FLOAT32>(i % size_y);
folded_const->at<loco::DataType::FLOAT32>(i) = xv * yv;
}

folded_const->shape_status(luci::ShapeStatus::VALID);
folded_const->name(name_x + "_" + name_y);

loco::replace(mul).with(folded_const);

return true;
}

} // namespace

namespace luci
{

/**
* Constant Folding for Mul Op
**/
bool FoldMulPass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
if (auto mul = dynamic_cast<luci::CircleMul *>(node))
{
if (fold_mul(mul))
changed = true;
}
}

return changed;
}

} // namespace luci
133 changes: 133 additions & 0 deletions compiler/luci/pass/src/FoldMulPass.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FoldMulPass.h"
#include "PassTestGraphs.h"

#include <luci/IR/CircleNodes.h>

#include <gtest/gtest.h>

namespace
{

/**
* Graph has an Mul Op with constant inputs
*
* BEFORE
*
* [CircleConst] [CircleConst]
* | |
* [CircleMul]
* |
* [CircleNode]
* AFTER
* [CircleConst] [CircleConst]
* | |
* [CircleConst] [CircleMul]
* |
* [CircleNode]
*/

template <loco::DataType T> class FoldMulTest : public luci::ConstantFoldingAddTestGraph
{
public:
FoldMulTest(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T)
{
_mul = _g.nodes()->template create<luci::CircleMul>();
_x = _g.nodes()->template create<luci::CircleConst>();
_y = _g.nodes()->template create<luci::CircleConst>();

_mul->dtype(T);
_x->dtype(T);
_y->dtype(T);

_mul->shape(shape);
_x->shape(shape);
_y->shape(shape);

uint32_t num_elems = 1;
for (auto dim = shape.begin(); dim != shape.end(); dim++)
num_elems *= *dim;

_x->size<T>(num_elems);
_y->size<T>(num_elems);

for (uint32_t i = 0; i < num_elems; i++)
{
_x->at<T>(i) = i + 1;
_y->at<T>(i) = i + 1;
}

_mul->x(_x);
_mul->y(_y);
_mul->name("mul");
_x->name("x");
_y->name("y");
}

loco::Node *createFoldedPattern() override { return _mul; }

virtual ~FoldMulTest() = default;

protected:
luci::CircleMul *_mul = nullptr;
luci::CircleConst *_x = nullptr;
luci::CircleConst *_y = nullptr;
};

class FoldF32MulTest : public FoldMulTest<loco::DataType::FLOAT32>, public ::testing::Test
{
public:
FoldF32MulTest() : FoldMulTest<loco::DataType::FLOAT32>({3}) {}

virtual void SetUp() { init(); }
};

} // namespace

TEST_F(FoldF32MulTest, name)
{
luci::FoldMulPass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

TEST_F(FoldF32MulTest, fold_mul)
{
luci::FoldMulPass pass;
while (pass.run(graph()))
;

auto folded_const = getFoldedPattern();
EXPECT_NE(nullptr, folded_const);

// Check type, shape, values of folded const
EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
EXPECT_EQ(1, folded_const->rank());
EXPECT_EQ(3, folded_const->dim(0).value());
EXPECT_EQ(1, folded_const->at<loco::DataType::FLOAT32>(0));
EXPECT_EQ(4, folded_const->at<loco::DataType::FLOAT32>(1));
EXPECT_EQ(9, folded_const->at<loco::DataType::FLOAT32>(2));
}

TEST_F(FoldF32MulTest, input_type_mismatch_NEG)
{
_x->dtype(loco::DataType::U4);

luci::FoldMulPass pass;
EXPECT_FALSE(pass.run(graph()));
}

0 comments on commit ab68724

Please sign in to comment.