Skip to content

Commit

Permalink
[luci] Forward transpose over binary Op (Samsung#14323)
Browse files Browse the repository at this point in the history
This forwards transpose Ops over binary Op (add, mul).

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Nov 20, 2024
1 parent 0746897 commit ffd7416
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 1 deletion.
120 changes: 120 additions & 0 deletions compiler/luci/pass/src/ForwardTransposeOpPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,61 @@ bool check_perm(const CircleTranspose *t)
return true;
}

// Read the permutation stored in a CircleConst node as a vector of int32_t
//
// Only S32 constants are supported for now.
// An empty vector is returned if
//  - dtype of the node is not S32, or
//  - any element is out of [0, number of elements), or
//  - the node has no element
// so the caller can treat an empty result as "not supported".
std::vector<int32_t> get_perm_data(const CircleConst *node)
{
  assert(node); // FIX_CALLER_UNLESS

  std::vector<int32_t> perm_data;
  switch (node->dtype())
  {
    case loco::DataType::S32:
    {
      // The number of elements does not change inside the loop; read it once
      const auto num_elements = node->size<loco::DataType::S32>();
      perm_data.reserve(num_elements);

      for (uint32_t i = 0; i < num_elements; i++)
      {
        const auto data = node->at<loco::DataType::S32>(i);

        // Unsupported: value cannot be an index into a perm of this length
        if (data < 0 or data >= static_cast<int32_t>(num_elements))
          return {};

        perm_data.emplace_back(data);
      }
      break;
    }
    // TODO Support S64 data type
    default:
      break;
  }

  return perm_data;
}

// Return true if below conditions are met
// 1. lhs->perm() and rhs->perm() are CircleConst
// 2. Both perm's values are the same
bool check_same_perm(const CircleTranspose *lhs, const CircleTranspose *rhs)
{
auto lhs_perm = dynamic_cast<CircleConst *>(lhs->perm());
if (not lhs_perm)
return false;

auto rhs_perm = dynamic_cast<CircleConst *>(rhs->perm());
if (not rhs_perm)
return false;

std::vector<int32_t> lhs_perm_data = get_perm_data(lhs_perm);
if (lhs_perm_data.empty())
return false;

std::vector<int32_t> rhs_perm_data = get_perm_data(rhs_perm);
if (rhs_perm_data.empty())
return false;

if (lhs_perm_data != rhs_perm_data)
return false;

return true;
}

// Create new Transpose Op including perm
// Never return nullptr
CircleTranspose *create_cloned_transpose(CircleTranspose *transpose)
Expand Down Expand Up @@ -289,6 +344,46 @@ class EBOWithConstPattern final : public CircleNodeMutableVisitor<bool>
}
};

// Elementwise Binary Operator (no const input)
//
// Visitor that rewrites
//   Transpose(a) (op) Transpose(b)  -->  Transpose(a (op) b)
// for the Ops that have a dedicated visit() below.
class EBOPattern final : public CircleNodeMutableVisitor<bool>
{
private:
  // Forward the Transpose Ops on the inputs of 'node' to the output of 'node'
  // Return true if the graph was changed, false if the pattern did not match
  template <typename CIRCLE_OP_PTR> bool has_transpose_xy(CIRCLE_OP_PTR node)
  {
    luci::CircleTranspose *transpose_lhs = nullptr;
    luci::CircleTranspose *transpose_rhs = nullptr;

    // Bind the two arguments of 'node' to Transpose pointers
    // NOTE(review) exact matching rule is up to luci::fill
    RETURN_FALSE_UNLESS(luci::fill(&transpose_lhs, &transpose_rhs).with_args_of(node));

    // Both Transpose Ops must have the same perm
    RETURN_FALSE_UNLESS(check_same_perm(transpose_lhs, transpose_rhs));

    // perm values are equal, so the Transpose cloned from lhs stands for both
    auto forwarded = create_cloned_transpose(transpose_lhs);
    assert(forwarded); // FIX_ME_UNLESS

    // Let 'node' read the un-transposed inputs directly
    node->x(transpose_lhs->a());
    node->y(transpose_rhs->a());

    // Users of 'node' are redirected to the new Transpose first,
    // then the new Transpose takes 'node' as its input
    loco::replace(node).with(forwarded);
    forwarded->a(node);

    // Inputs of 'node' changed: do shape inference for this node again.
    node->shape_status(luci::ShapeStatus::UNDEFINED);

    return true;
  }

public:
  // Default: nothing to do for Ops without a dedicated visit()
  bool visit(luci::CircleNode *) { return false; }

  bool visit(luci::CircleAdd *node) { return has_transpose_xy(node); }

  bool visit(luci::CircleMul *node) { return has_transpose_xy(node); }
};

// Elementwise Unary Operator
class EwUnaryPattern final : public CircleNodeMutableVisitor<bool>
{
Expand Down Expand Up @@ -333,6 +428,16 @@ namespace luci

/**
* BEFORE
*
* [CircleNode] [CircleNode]
* | |
* [CircleTranspose] [CircleTranspose]
* | /
* [(BinaryOp)]
* |
*
* BinaryOp: CircleAdd, CircleMul, ...
*
* |
* [CircleNode] [CircleConst]
* | /
Expand All @@ -358,6 +463,16 @@ namespace luci
* UnaryOp: CircleAbs, ...
*
* AFTER
*
* [CircleNode] [CircleNode]
* | /
* [(BinaryOp)]
* |
* [CircleTranspose]
* |
*
* BinaryOp: CircleAdd, CircleMul, ...
*
* |
* [CircleConst] [CircleNode] [CircleConst(updated)]
* | / | /
Expand All @@ -383,6 +498,9 @@ namespace luci
bool ForwardTransposeOpPass::run(loco::Graph *g)
{
bool changed = false;

// TODO Revisit pattern interface
EBOPattern ebo;
EBOWithConstPattern eboc;
EwUnaryPattern ewu;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
Expand All @@ -392,6 +510,8 @@ bool ForwardTransposeOpPass::run(loco::Graph *g)
changed = true;
else if (circle_node->accept(&ewu))
changed = true;
else if (circle_node->accept(&ebo))
changed = true;
}
return changed;
}
Expand Down
168 changes: 167 additions & 1 deletion compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,88 @@ namespace

using namespace luci::test;

template <typename T> class BothTransposeBinaryOpGraphlet
{
public:
BothTransposeBinaryOpGraphlet() = default;

public:
virtual ~BothTransposeBinaryOpGraphlet() = default;

public:
// TODO Rename shape_in to shape_const
void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeI32 perm)
{
std::vector<uint32_t> shape_in_v = shape_in;
std::vector<int32_t> perm_v = perm;

_perm = g->nodes()->create<luci::CircleConst>();
_transpose = g->nodes()->create<luci::CircleTranspose>();
_binary = g->nodes()->create<T>();

_perm->dtype(loco::DataType::S32);
_perm->rank(1);
_perm->dim(0).set(perm_v.size());
_perm->shape_status(luci::ShapeStatus::VALID);

// values
const auto size = perm_v.size();
_perm->size<loco::DataType::S32>(size);
for (uint32_t i = 0; i < size; i++)
_perm->at<loco::DataType::S32>(i) = perm_v[i];

_perm->name("transpose_perm");
_transpose->name("transpose");
_binary->name("binary");
}

luci::CircleTranspose *transpose(void) { return _transpose; }

protected:
luci::CircleTranspose *_transpose = nullptr;
T *_binary = nullptr;
luci::CircleConst *_perm = nullptr;
};

// Graphlets specialized for the binary Ops under test (Add, Mul)
using BothTransposeAddGraphlet = BothTransposeBinaryOpGraphlet<luci::CircleAdd>;
using BothTransposeMulGraphlet = BothTransposeBinaryOpGraphlet<luci::CircleMul>;

// Test graph
//
//   [input] -> [Transpose] -> [Add] -> [output]
//
// where the Transpose is connected to both x and y of Add
class ForwardBothTransposeToAddGraph : public TestIOGraph, public BothTransposeAddGraphlet
{
public:
  void init(const ShapeU32 shape_in, const ShapeU32 shape_out, const ShapeI32 shape_perm)
  {
    TestIOGraph::init(shape_in, shape_out);
    BothTransposeAddGraphlet::init(g(), shape_in, shape_perm);

    // Transpose takes the graph input and the perm const
    _transpose->perm(_perm);
    _transpose->a(input());

    // The same Transpose feeds both operands of the binary Op
    _binary->x(_transpose);
    _binary->y(_transpose);

    output()->from(_binary);
  }
};

// Test graph
//
//   [input] -> [Transpose] -> [Mul] -> [output]
//
// where the Transpose is connected to both x and y of Mul
class ForwardBothTransposeToMulGraph : public TestIOGraph, public BothTransposeMulGraphlet
{
public:
  void init(const ShapeU32 shape_in, const ShapeU32 shape_out, const ShapeI32 shape_perm)
  {
    TestIOGraph::init(shape_in, shape_out);
    BothTransposeMulGraphlet::init(g(), shape_in, shape_perm);

    // Transpose takes the graph input and the perm const
    _transpose->perm(_perm);
    _transpose->a(input());

    // The same Transpose feeds both operands of the binary Op
    _binary->x(_transpose);
    _binary->y(_transpose);

    output()->from(_binary);
  }
};

template <typename T> class TransposeBinaryOpGraphlet
{
public:
Expand Down Expand Up @@ -130,7 +212,7 @@ class ForwardTransposeToAddInvalidGraph : public TestIOGraph, public TransposeAd
_transpose->a(input());
_transpose->perm(_perm);
_binary->x(_transpose);
_binary->y(_transpose);
_binary->y(input());

output()->from(_binary);
}
Expand Down Expand Up @@ -204,6 +286,24 @@ void run_phase(loco::Graph *g)
phase_runner.run(phase);
}

class ForwardBothTransposeToAddGraphTest : public ::testing::Test
{
public:
void run_pass(void) { run_phase(_graph.g()); }

protected:
ForwardBothTransposeToAddGraph _graph;
};

class ForwardBothTransposeToMulGraphTest : public ::testing::Test
{
public:
void run_pass(void) { run_phase(_graph.g()); }

protected:
ForwardBothTransposeToMulGraph _graph;
};

class ForwardTransposeToAddGraphTest : public ::testing::Test
{
public:
Expand Down Expand Up @@ -314,6 +414,72 @@ TEST_F(ForwardTransposeToAddGraphTest, forward_add_yx)
EXPECT_EQ(1, mul_const->dim(3).value());
}

// Transpose connected to both inputs of Add is forwarded to the output of Add
TEST_F(ForwardBothTransposeToAddGraphTest, forward_add)
{
  _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

  run_pass();

  // Graph output is now produced by Transpose
  // ASSERT instead of EXPECT: the pointer is dereferenced right below
  auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
  ASSERT_NE(nullptr, transpose);
  EXPECT_EQ(4, transpose->rank());
  EXPECT_EQ(1, transpose->dim(0).value());
  EXPECT_EQ(1, transpose->dim(1).value());
  EXPECT_EQ(51, transpose->dim(2).value());
  EXPECT_EQ(64, transpose->dim(3).value());

  // Add sits right before Transpose and works on the un-transposed shape
  auto add = dynamic_cast<luci::CircleAdd *>(transpose->a());
  ASSERT_NE(nullptr, add);
  EXPECT_EQ(4, add->rank());
  EXPECT_EQ(1, add->dim(0).value());
  EXPECT_EQ(64, add->dim(1).value());
  EXPECT_EQ(51, add->dim(2).value());
  EXPECT_EQ(1, add->dim(3).value());
}

// Pass must not change the graph when perm of Transpose is not the const
TEST_F(ForwardBothTransposeToAddGraphTest, forward_add_NEG)
{
  _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

  // Replace perm with the graph input
  _graph.transpose()->perm(_graph.input());

  luci::ForwardTransposeOpPass pass;
  const bool changed = pass.run(_graph.g());
  EXPECT_FALSE(changed);
}

// Transpose connected to both inputs of Mul is forwarded to the output of Mul
TEST_F(ForwardBothTransposeToMulGraphTest, forward_mul)
{
  _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

  run_pass();

  // Graph output is now produced by Transpose
  // ASSERT instead of EXPECT: the pointer is dereferenced right below
  auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
  ASSERT_NE(nullptr, transpose);
  EXPECT_EQ(4, transpose->rank());
  EXPECT_EQ(1, transpose->dim(0).value());
  EXPECT_EQ(1, transpose->dim(1).value());
  EXPECT_EQ(51, transpose->dim(2).value());
  EXPECT_EQ(64, transpose->dim(3).value());

  // Mul sits right before Transpose and works on the un-transposed shape
  auto mul = dynamic_cast<luci::CircleMul *>(transpose->a());
  ASSERT_NE(nullptr, mul);
  EXPECT_EQ(4, mul->rank());
  EXPECT_EQ(1, mul->dim(0).value());
  EXPECT_EQ(64, mul->dim(1).value());
  EXPECT_EQ(51, mul->dim(2).value());
  EXPECT_EQ(1, mul->dim(3).value());
}

// Pass must not change the graph when perm of Transpose is not the const
TEST_F(ForwardBothTransposeToMulGraphTest, forward_mul_NEG)
{
  _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

  // Replace perm with the graph input
  _graph.transpose()->perm(_graph.input());

  luci::ForwardTransposeOpPass pass;
  const bool changed = pass.run(_graph.g());
  EXPECT_FALSE(changed);
}

TEST_F(ForwardTransposeToMulGraphTest, forward_mul_xy)
{
_graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
Expand Down

0 comments on commit ffd7416

Please sign in to comment.