Skip to content

Commit

Permalink
[luci/pass] Refactor FuseAddWithTConvPass (#13745)
Browse files Browse the repository at this point in the history
This commit changes the order of searching for the pattern and adds tests for the pass.

ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz <[email protected]>
  • Loading branch information
jiwaszki authored Aug 28, 2024
1 parent 96039fc commit 543d655
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 53 deletions.
93 changes: 41 additions & 52 deletions compiler/luci/pass/src/FuseAddWithTConvPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@

#include "luci/Pass/FuseAddWithTConvPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

namespace
{

#define RETURN_FALSE_UNLESS(cond) \
if (not(cond)) \
return false;

/**
* Fuse Add to TransposeConv if possible
*
Expand All @@ -42,89 +49,74 @@ namespace
*
* Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
bool fuse_add_with_tconv(luci::CircleAdd *add)
{
// skip if tconv has fused activation
if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;
// check whether it has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
if (not bias)
return false;

// get weight of tconv
auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
if (not filter)
return false;
if (filter->dtype() != loco::DataType::FLOAT32)
return false;

// get add node
auto tconv_output = loco::succs(tconv);
assert(tconv_output.size() == 1);
auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
if (not add)
return false;
if (add->dtype() != loco::DataType::FLOAT32)
return false;
if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
add->fusedActivationFunction() != luci::FusedActFunc::RELU)
return false;

// get addition
// Allow Add node only with FLOAT32 data type.
RETURN_FALSE_UNLESS(add->dtype() == loco::DataType::FLOAT32);
// Allow Add node only with specific activations.
RETURN_FALSE_UNLESS(add->fusedActivationFunction() == luci::FusedActFunc::NONE ||
add->fusedActivationFunction() == luci::FusedActFunc::RELU6 ||
add->fusedActivationFunction() == luci::FusedActFunc::RELU);
// Find the pattern of Add(TransposeConv, CircleConst):
luci::CircleTransposeConv *tconv = nullptr;
luci::CircleConst *addition = nullptr;
if (add->x() == tconv)
addition = dynamic_cast<luci::CircleConst *>(add->y());
else
addition = dynamic_cast<luci::CircleConst *>(add->x());
RETURN_FALSE_UNLESS(luci::fill(&tconv, &addition).with_commutative_args_of(add));

if (not addition)
return false;
RETURN_FALSE_UNLESS(loco::succs(tconv).size() == 1);

// Skip if tconv has fused activation.
RETURN_FALSE_UNLESS(tconv->fusedActivationFunction() == luci::FusedActFunc::NONE);
// Check whether tconv has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
RETURN_FALSE_UNLESS(bias);
// Get weights of tconv:
auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
RETURN_FALSE_UNLESS(filter);
RETURN_FALSE_UNLESS(filter->dtype() == loco::DataType::FLOAT32);

// addition dim(0) == tconv filter channel dim
if (addition->rank() != 1)
return false;
RETURN_FALSE_UNLESS(addition->rank() == 1);

auto addition_dim = addition->dim(0).value();
auto filter_channel_dim = filter->dim(0).value();
if (filter_channel_dim != addition_dim)
return false;
RETURN_FALSE_UNLESS(filter_channel_dim == addition_dim);

// fuse addition with transposed conv
// Fuse addition with transposed conv:
tconv->bias(addition);

if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
{
auto name = addition->name();
assert(name.length() > 0);
// separate relu op from add op
// Separate relu op from add op:
auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
relu->features(tconv);
relu->name(name + "/Relu6");
luci::add_origin(relu, luci::get_origin(add));

// remove add node
// Remove add node.
replace(add).with(relu);
}
else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
{
auto name = addition->name();
assert(name.length() > 0);
// separate relu op from add op
// Separate relu op from add op:
auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
relu->features(tconv);
relu->name(name + "/Relu");
luci::add_origin(relu, luci::get_origin(add));

// remove add node
// Remove add node.
replace(add).with(relu);
}
else
{
// Remove add node.
replace(add).with(tconv);
}

// set origin
// Set new origin.
luci::add_origin(tconv, luci::get_origin(add));

return true;
Expand All @@ -140,12 +132,9 @@ bool FuseAddWithTConvPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
if (not tconv)
continue;

if (fuse_add_with_tconv(tconv))
changed = true;
if (auto add = dynamic_cast<luci::CircleAdd *>(node))
if (fuse_add_with_tconv(add))
changed = true;
}

return changed;
Expand Down
189 changes: 188 additions & 1 deletion compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,196 @@

#include "luci/Pass/FuseAddWithTConvPass.h"

#include "helpers/CreateCircleConst.h"

#include <luci/IR/CircleNodes.h>
#include <luci/test/TestIOGraph.h>

#include <gtest/gtest.h>

TEST(FuseAddWithTConvPassTest, name)
#define ADD_VAL 5.0f
namespace
{

using namespace luci::test;

/**
* Graph for this test
*
* BEFORE (without extra_successor)
*
* |
* [CircleConst] [CircleTransposeConv]
* \ |
* [CircleAdd w/ Relu]
* |
*
* BEFORE (with extra_successor)
*
* |
* [CircleConst] [CircleTransposeConv]
* \ | |
* [CircleAdd w/ Relu] [extra FC]
* | |
*
* AFTER (if pass was successful)
*
* |
* [CircleConst as bias] |
* \ |
* [CircleTransposeConv]
* |
* ([CircleRelu/Relu])
* |
*
*/
class TConvAddGraphlet
{
public:
void init(loco::Graph *g, luci::FusedActFunc tconv_activation, bool use_bias,
bool extra_successor)
{
_tconv = g->nodes()->create<luci::CircleTransposeConv>();

std::vector<float> input_sizes_val = {1, 4, 4, 1};
_tconv_i = luci::create_const_node(g, loco::DataType::FLOAT32, {4}, input_sizes_val);
_tconv->inputSizes(_tconv_i);

std::vector<float> filter_val(18);
for (uint32_t i = 0; i < 18; i++)
filter_val.at(i) = i;

_tconv_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 3, 3, 2}, filter_val);
_tconv->filter(_tconv_f);

if (use_bias)
{
std::vector<float> bias_val(1, 3.0f);
_tconv_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, bias_val);
}
else
{
// Create CircleOutputExclude -- no bias
_tconv_b = g->nodes()->create<luci::CircleOutputExclude>();
}
_tconv->bias(_tconv_b);

_tconv->padding(luci::Padding::VALID);
auto _stride = _tconv->stride();
_stride->w(1);
_stride->h(1);
_tconv->fusedActivationFunction(tconv_activation);
_tconv->dtype(loco::DataType::FLOAT32);
_tconv->shape({1, 4, 4, 1});
_tconv->name("tconv");

if (extra_successor)
{
_extra_succ = g->nodes()->create<luci::CircleFullyConnected>();
// Set previous TConv as input to bump number of successors for it:
_extra_succ->input(_tconv);
std::vector<float> weights_val(8);
_extra_f = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 8}, weights_val);
_extra_succ->weights(_extra_f);
_extra_succ->bias(nullptr);
_extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE);
_extra_succ->dtype(loco::DataType::FLOAT32);
_extra_succ->shape({1, 4, 4, 1});
_extra_succ->name("extra_fc");
}

std::vector<float> add_values(1, ADD_VAL);
_add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, add_values);
_add_c->name("const_c");

_add = g->nodes()->create<luci::CircleAdd>();
_add->x(_tconv);
_add->y(_add_c);
_add->fusedActivationFunction(luci::FusedActFunc::RELU);
_add->dtype(loco::DataType::FLOAT32);
_add->shape({1, 4, 4, 1});

_add->name("add");
}

protected:
luci::CircleTransposeConv *_tconv = nullptr;
luci::CircleConst *_tconv_i = nullptr;
luci::CircleConst *_tconv_f = nullptr;
luci::CircleNode *_tconv_b = nullptr;
luci::CircleAdd *_add = nullptr;
luci::CircleConst *_add_c = nullptr;
luci::CircleFullyConnected *_extra_succ = nullptr;
luci::CircleConst *_extra_f = nullptr;
};

class FuseAddWithTConvTestGraph : public TestIOGraph, public TConvAddGraphlet
{
public:
void init(luci::FusedActFunc tconv_activation, bool use_bias, bool extra_successor)
{
TestIOGraph::init({1, 2, 2, 2}, {1, 4, 4, 1});
TConvAddGraphlet::init(g(), tconv_activation, use_bias, extra_successor);

_tconv->outBackprop(input());

output()->from(_add);
}
};

class FuseAddWithTConvPassTest : public ::testing::Test
{
public:
FuseAddWithTConvTestGraph g;
luci::FuseAddWithTConvPass pass;
};

} // namespace

TEST_F(FuseAddWithTConvPassTest, tconv_add_fuse)
{
g.init(luci::FusedActFunc::NONE, false /* use_bias */, false /* extra_successor */);

EXPECT_EQ(true, pass.run(g.g()));

auto relu = dynamic_cast<luci::CircleRelu *>(g.output()->from());
EXPECT_NE(nullptr, relu);
EXPECT_STREQ(relu->name().c_str(), "const_c/Relu");

auto tconv = dynamic_cast<luci::CircleTransposeConv *>(relu->features());
EXPECT_NE(nullptr, tconv);

auto bias = loco::must_cast<luci::CircleConst *>(tconv->bias());
EXPECT_NE(nullptr, bias);

for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
{
EXPECT_EQ(ADD_VAL, bias->at<loco::DataType::FLOAT32>(i));
}
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_bias_NEG)
{
g.init(luci::FusedActFunc::NONE, true /* use_bias */, false /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_activation_NEG)
{
g.init(luci::FusedActFunc::RELU, false /* use_bias */, false /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, tconv_with_extra_successor_NEG)
{
g.init(luci::FusedActFunc::NONE, false /* use_bias */, true /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseAddWithTConvPassTest, name)
{
luci::FuseAddWithTConvPass pass;
auto const name = pass.name();
Expand Down

0 comments on commit 543d655

Please sign in to comment.