Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci] Forward transpose over binary Op #14323

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions compiler/luci/pass/src/ForwardTransposeOpPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,61 @@ bool check_perm(const CircleTranspose *t)
return true;
}

// Return vector of int32_t from CircleConst node
// Return empty vector if not supported
std::vector<int32_t> get_perm_data(const CircleConst *node)
{
assert(node); // FIX_CALLER_UNLESS
std::vector<int32_t> perm_data;
switch (node->dtype())
{
case loco::DataType::S32:
for (uint32_t i = 0; i < node->size<loco::DataType::S32>(); i++)
{
auto data = node->at<loco::DataType::S32>(i);

// Unsupported
if (data < 0 or data >= static_cast<int32_t>(node->size<loco::DataType::S32>()))
return {};

perm_data.emplace_back(data);
}
break;
// TODO Support S64 data type
default:
break;
}

return perm_data;
}

// Return true if below conditions are met
// 1. lhs->perm() and rhs->perm() are CircleConst
// 2. Both perm's values are the same
bool check_same_perm(const CircleTranspose *lhs, const CircleTranspose *rhs)
{
auto lhs_perm = dynamic_cast<CircleConst *>(lhs->perm());
if (not lhs_perm)
return false;
Comment on lines +94 to +95
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(optional) maybe use RETURN_FALSE_UNLESS ?


auto rhs_perm = dynamic_cast<CircleConst *>(rhs->perm());
if (not rhs_perm)
return false;

std::vector<int32_t> lhs_perm_data = get_perm_data(lhs_perm);
if (lhs_perm_data.empty())
return false;

std::vector<int32_t> rhs_perm_data = get_perm_data(rhs_perm);
if (rhs_perm_data.empty())
return false;

if (lhs_perm_data != rhs_perm_data)
return false;

return true;
}

// Create new Transpose Op including perm
// Never return nullptr
CircleTranspose *create_cloned_transpose(CircleTranspose *transpose)
Expand Down Expand Up @@ -289,6 +344,46 @@ class EBOWithConstPattern final : public CircleNodeMutableVisitor<bool>
}
};

// Elementwise Binary Operator (no const input)
class EBOPattern final : public CircleNodeMutableVisitor<bool>
{
private:
template <typename CIRCLE_OP_PTR> bool has_transpose_xy(CIRCLE_OP_PTR node)
{
luci::CircleTranspose *lhs = nullptr;
luci::CircleTranspose *rhs = nullptr;

RETURN_FALSE_UNLESS(luci::fill(&lhs, &rhs).with_args_of(node));

// Check lhs's perm == rhs's perm
RETURN_FALSE_UNLESS(check_same_perm(lhs, rhs));

// Create cloned transpose
auto new_transpose = create_cloned_transpose(lhs);
assert(new_transpose); // FIX_ME_UNLESS

// Reconnect network
node->x(lhs->a());
node->y(rhs->a());

loco::replace(node).with(new_transpose);
new_transpose->a(node);

// Do shape inference for this node again.
node->shape_status(luci::ShapeStatus::UNDEFINED);

return true;
}

public:
// Default
bool visit(luci::CircleNode *) { return false; }

bool visit(luci::CircleAdd *node) { return has_transpose_xy(node); }

bool visit(luci::CircleMul *node) { return has_transpose_xy(node); }
};

// Elementwise Unary Operator
class EwUnaryPattern final : public CircleNodeMutableVisitor<bool>
{
Expand Down Expand Up @@ -333,6 +428,16 @@ namespace luci

/**
* BEFORE
*
* [CircleNode] [CircleNode]
* | |
* [CircleTranspose] [CircleTranspose]
* | /
* [(BinaryOp)]
* |
*
* BinaryOp: CircleAdd, CircleMul, ...
*
* |
* [CircleNode] [CircleConst]
* | /
Expand All @@ -358,6 +463,16 @@ namespace luci
* UnaryOp: CircleAbs, ...
*
* AFTER
*
* [CircleNode] [CircleNode]
* | /
* [(BinaryOp)]
* |
* [CircleTranspose]
* |
*
* BinaryOp: CircleAdd, CircleMul, ...
*
* |
* [CircleConst] [CircleNode] [CircleConst(updated)]
* | / | /
Expand All @@ -383,6 +498,9 @@ namespace luci
bool ForwardTransposeOpPass::run(loco::Graph *g)
{
bool changed = false;

// TODO Revisit pattern interface
EBOPattern ebo;
EBOWithConstPattern eboc;
EwUnaryPattern ewu;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
Expand All @@ -392,6 +510,8 @@ bool ForwardTransposeOpPass::run(loco::Graph *g)
changed = true;
else if (circle_node->accept(&ewu))
changed = true;
else if (circle_node->accept(&ebo))
changed = true;
}
return changed;
}
Expand Down
168 changes: 167 additions & 1 deletion compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,88 @@ namespace

using namespace luci::test;

template <typename T> class BothTransposeBinaryOpGraphlet
{
public:
BothTransposeBinaryOpGraphlet() = default;

public:
virtual ~BothTransposeBinaryOpGraphlet() = default;

public:
// TODO Rename shape_in to shape_const
void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeI32 perm)
{
std::vector<uint32_t> shape_in_v = shape_in;
std::vector<int32_t> perm_v = perm;

_perm = g->nodes()->create<luci::CircleConst>();
_transpose = g->nodes()->create<luci::CircleTranspose>();
_binary = g->nodes()->create<T>();

_perm->dtype(loco::DataType::S32);
_perm->rank(1);
_perm->dim(0).set(perm_v.size());
_perm->shape_status(luci::ShapeStatus::VALID);

// values
const auto size = perm_v.size();
_perm->size<loco::DataType::S32>(size);
for (uint32_t i = 0; i < size; i++)
_perm->at<loco::DataType::S32>(i) = perm_v[i];

_perm->name("transpose_perm");
_transpose->name("transpose");
_binary->name("binary");
}

luci::CircleTranspose *transpose(void) { return _transpose; }

protected:
luci::CircleTranspose *_transpose = nullptr;
T *_binary = nullptr;
luci::CircleConst *_perm = nullptr;
};

using BothTransposeAddGraphlet = BothTransposeBinaryOpGraphlet<luci::CircleAdd>;
using BothTransposeMulGraphlet = BothTransposeBinaryOpGraphlet<luci::CircleMul>;

class ForwardBothTransposeToAddGraph : public TestIOGraph, public BothTransposeAddGraphlet
{
public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out, const ShapeI32 shape_perm)
{
TestIOGraph::init(shape_in, shape_out);
BothTransposeAddGraphlet::init(g(), shape_in, shape_perm);

// connect network
_transpose->a(input());
_transpose->perm(_perm);
_binary->x(_transpose);
_binary->y(_transpose);

output()->from(_binary);
}
};

class ForwardBothTransposeToMulGraph : public TestIOGraph, public BothTransposeMulGraphlet
{
public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out, const ShapeI32 shape_perm)
{
TestIOGraph::init(shape_in, shape_out);
BothTransposeMulGraphlet::init(g(), shape_in, shape_perm);

// connect network
_transpose->a(input());
_transpose->perm(_perm);
_binary->x(_transpose);
_binary->y(_transpose);

output()->from(_binary);
}
};

template <typename T> class TransposeBinaryOpGraphlet
{
public:
Expand Down Expand Up @@ -130,7 +212,7 @@ class ForwardTransposeToAddInvalidGraph : public TestIOGraph, public TransposeAd
_transpose->a(input());
_transpose->perm(_perm);
_binary->x(_transpose);
_binary->y(_transpose);
_binary->y(input());

output()->from(_binary);
}
Expand Down Expand Up @@ -204,6 +286,24 @@ void run_phase(loco::Graph *g)
phase_runner.run(phase);
}

class ForwardBothTransposeToAddGraphTest : public ::testing::Test
{
public:
void run_pass(void) { run_phase(_graph.g()); }

protected:
ForwardBothTransposeToAddGraph _graph;
};

class ForwardBothTransposeToMulGraphTest : public ::testing::Test
{
public:
void run_pass(void) { run_phase(_graph.g()); }

protected:
ForwardBothTransposeToMulGraph _graph;
};

class ForwardTransposeToAddGraphTest : public ::testing::Test
{
public:
Expand Down Expand Up @@ -314,6 +414,72 @@ TEST_F(ForwardTransposeToAddGraphTest, forward_add_yx)
EXPECT_EQ(1, mul_const->dim(3).value());
}

TEST_F(ForwardBothTransposeToAddGraphTest, forward_add)
{
_graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

run_pass();

auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
EXPECT_NE(nullptr, transpose);
EXPECT_EQ(4, transpose->rank());
EXPECT_EQ(1, transpose->dim(0).value());
EXPECT_EQ(1, transpose->dim(1).value());
EXPECT_EQ(51, transpose->dim(2).value());
EXPECT_EQ(64, transpose->dim(3).value());

auto add = dynamic_cast<luci::CircleAdd *>(transpose->a());
EXPECT_NE(nullptr, add);
EXPECT_EQ(4, add->rank());
EXPECT_EQ(1, add->dim(0).value());
EXPECT_EQ(64, add->dim(1).value());
EXPECT_EQ(51, add->dim(2).value());
EXPECT_EQ(1, add->dim(3).value());
}

TEST_F(ForwardBothTransposeToAddGraphTest, forward_add_NEG)
{
_graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

_graph.transpose()->perm(_graph.input());

luci::ForwardTransposeOpPass pass;
EXPECT_FALSE(pass.run(_graph.g()));
}

TEST_F(ForwardBothTransposeToMulGraphTest, forward_mul)
{
_graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

run_pass();

auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
EXPECT_NE(nullptr, transpose);
EXPECT_EQ(4, transpose->rank());
EXPECT_EQ(1, transpose->dim(0).value());
EXPECT_EQ(1, transpose->dim(1).value());
EXPECT_EQ(51, transpose->dim(2).value());
EXPECT_EQ(64, transpose->dim(3).value());

auto mul = dynamic_cast<luci::CircleMul *>(transpose->a());
EXPECT_NE(nullptr, mul);
EXPECT_EQ(4, mul->rank());
EXPECT_EQ(1, mul->dim(0).value());
EXPECT_EQ(64, mul->dim(1).value());
EXPECT_EQ(51, mul->dim(2).value());
EXPECT_EQ(1, mul->dim(3).value());
}

TEST_F(ForwardBothTransposeToMulGraphTest, forward_mul_NEG)
{
_graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1});

_graph.transpose()->perm(_graph.input());

luci::ForwardTransposeOpPass pass;
EXPECT_FALSE(pass.run(_graph.g()));
}

TEST_F(ForwardTransposeToMulGraphTest, forward_mul_xy)
{
_graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
Expand Down