From ffd74160426849582c2cefbdb682ab55070e5531 Mon Sep 17 00:00:00 2001 From: Hyukjin Jeong Date: Thu, 21 Nov 2024 07:11:42 +0900 Subject: [PATCH] [luci] Forward transpose over binary Op (#14323) This forwards transpose Ops over binary Op (add, mul). ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong --- .../luci/pass/src/ForwardTransposeOpPass.cpp | 120 +++++++++++++ .../pass/src/ForwardTransposeOpPass.test.cpp | 168 +++++++++++++++++- 2 files changed, 287 insertions(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp index b92c061450d..10665bbb5cc 100644 --- a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp @@ -57,6 +57,61 @@ bool check_perm(const CircleTranspose *t) return true; } +// Return vector of int32_t from CircleConst node +// Return empty vector if not supported +std::vector get_perm_data(const CircleConst *node) +{ + assert(node); // FIX_CALLER_UNLESS + std::vector perm_data; + switch (node->dtype()) + { + case loco::DataType::S32: + for (uint32_t i = 0; i < node->size(); i++) + { + auto data = node->at(i); + + // Unsupported + if (data < 0 or data >= static_cast(node->size())) + return {}; + + perm_data.emplace_back(data); + } + break; + // TODO Support S64 data type + default: + break; + } + + return perm_data; +} + +// Return true if below conditions are met +// 1. lhs->perm() and rhs->perm() are CircleConst +// 2. Both perm's values are the same +bool check_same_perm(const CircleTranspose *lhs, const CircleTranspose *rhs) +{ + auto lhs_perm = dynamic_cast(lhs->perm()); + if (not lhs_perm) + return false; + + auto rhs_perm = dynamic_cast(rhs->perm()); + if (not rhs_perm) + return false; + + std::vector lhs_perm_data = get_perm_data(lhs_perm); + if (lhs_perm_data.empty()) + return false; + + std::vector rhs_perm_data = get_perm_data(rhs_perm); + if (rhs_perm_data.empty()) + return false; + + if (lhs_perm_data != rhs_perm_data) + return false; + + return true; +} + // Create new Transpose Op including perm // Never return nullptr CircleTranspose *create_cloned_transpose(CircleTranspose *transpose) @@ -289,6 +344,46 @@ class EBOWithConstPattern final : public CircleNodeMutableVisitor } }; +// Elementwise Binary Operator (no const input) +class EBOPattern final : public CircleNodeMutableVisitor +{ +private: + template bool has_transpose_xy(CIRCLE_OP_PTR node) + { + luci::CircleTranspose *lhs = nullptr; + luci::CircleTranspose *rhs = nullptr; + + RETURN_FALSE_UNLESS(luci::fill(&lhs, &rhs).with_args_of(node)); + + // Check lhs's perm == rhs's perm + RETURN_FALSE_UNLESS(check_same_perm(lhs, rhs)); + + // Create cloned transpose + auto new_transpose = create_cloned_transpose(lhs); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + node->x(lhs->a()); + node->y(rhs->a()); + + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + +public: + // Default + bool visit(luci::CircleNode *) { return false; } + + bool visit(luci::CircleAdd *node) { return has_transpose_xy(node); } + + bool visit(luci::CircleMul *node) { return has_transpose_xy(node); } +}; + // Elementwise Unary Operator class EwUnaryPattern final : public CircleNodeMutableVisitor { @@ -333,6 +428,16 @@ namespace luci /** * BEFORE + * + * [CircleNode] [CircleNode] + * | | + * [CircleTranspose] [CircleTranspose] + * | / + * [(BinaryOp)] + * | + * + * BinaryOp: CircleAdd, CircleMul, ... + * * | * [CircleNode] [CircleConst] * | / @@ -358,6 +463,16 @@ namespace luci * UnaryOp: CircleAbs, ... * * AFTER + * + * [CircleNode] [CircleNode] + * | / + * [(BinaryOp)] + * | + * [CircleTranspose] + * | + * + * BinaryOp: CircleAdd, CircleMul, ... + * * | * [CircleConst] [CircleNode] [CircleConst(updated)] * | / | / @@ -383,6 +498,9 @@ namespace luci bool ForwardTransposeOpPass::run(loco::Graph *g) { bool changed = false; + + // TODO Revisit pattern interface + EBOPattern ebo; EBOWithConstPattern eboc; EwUnaryPattern ewu; for (auto node : loco::active_nodes(loco::output_nodes(g))) @@ -392,6 +510,8 @@ bool ForwardTransposeOpPass::run(loco::Graph *g) changed = true; else if (circle_node->accept(&ewu)) changed = true; + else if (circle_node->accept(&ebo)) + changed = true; } return changed; } diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp index c3c502c98f0..00cffba8098 100644 --- a/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp @@ -30,6 +30,88 @@ namespace using namespace luci::test; +template class BothTransposeBinaryOpGraphlet +{ +public: + BothTransposeBinaryOpGraphlet() = default; + +public: + virtual ~BothTransposeBinaryOpGraphlet() = default; + +public: + // TODO Rename shape_in to shape_const + void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeI32 perm) + { + std::vector shape_in_v = shape_in; + std::vector perm_v = perm; + + _perm = g->nodes()->create(); + _transpose = g->nodes()->create(); + _binary = g->nodes()->create(); + + _perm->dtype(loco::DataType::S32); + _perm->rank(1); + _perm->dim(0).set(perm_v.size()); + _perm->shape_status(luci::ShapeStatus::VALID); + + // values + const auto size = perm_v.size(); + _perm->size(size); + for (uint32_t i = 0; i < size; i++) + _perm->at(i) = perm_v[i]; + + _perm->name("transpose_perm"); + _transpose->name("transpose"); + _binary->name("binary"); + } + + luci::CircleTranspose *transpose(void) { return _transpose; } + +protected: + luci::CircleTranspose *_transpose = nullptr; + T *_binary = nullptr; + luci::CircleConst *_perm = nullptr; +}; + +using BothTransposeAddGraphlet = BothTransposeBinaryOpGraphlet; +using BothTransposeMulGraphlet = BothTransposeBinaryOpGraphlet; + +class ForwardBothTransposeToAddGraph : public TestIOGraph, public BothTransposeAddGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out, const ShapeI32 shape_perm) + { + TestIOGraph::init(shape_in, shape_out); + BothTransposeAddGraphlet::init(g(), shape_in, shape_perm); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _binary->x(_transpose); + _binary->y(_transpose); + + output()->from(_binary); + } +}; + +class ForwardBothTransposeToMulGraph : public TestIOGraph, public BothTransposeMulGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out, const ShapeI32 shape_perm) + { + TestIOGraph::init(shape_in, shape_out); + BothTransposeMulGraphlet::init(g(), shape_in, shape_perm); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _binary->x(_transpose); + _binary->y(_transpose); + + output()->from(_binary); + } +}; + template class TransposeBinaryOpGraphlet { public: @@ -130,7 +212,7 @@ class ForwardTransposeToAddInvalidGraph : public TestIOGraph, public TransposeAd _transpose->a(input()); _transpose->perm(_perm); _binary->x(_transpose); - _binary->y(_transpose); + _binary->y(input()); output()->from(_binary); } @@ -204,6 +286,24 @@ void run_phase(loco::Graph *g) phase_runner.run(phase); } +class ForwardBothTransposeToAddGraphTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardBothTransposeToAddGraph _graph; +}; + +class ForwardBothTransposeToMulGraphTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardBothTransposeToMulGraph _graph; +}; + class ForwardTransposeToAddGraphTest : public ::testing::Test { public: @@ -314,6 +414,72 @@ TEST_F(ForwardTransposeToAddGraphTest, forward_add_yx) EXPECT_EQ(1, mul_const->dim(3).value()); } +TEST_F(ForwardBothTransposeToAddGraphTest, forward_add) +{ + _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1}); + + run_pass(); + + auto transpose = dynamic_cast(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto add = dynamic_cast(transpose->a()); + EXPECT_NE(nullptr, add); + EXPECT_EQ(4, add->rank()); + EXPECT_EQ(1, add->dim(0).value()); + EXPECT_EQ(64, add->dim(1).value()); + EXPECT_EQ(51, add->dim(2).value()); + EXPECT_EQ(1, add->dim(3).value()); +} + +TEST_F(ForwardBothTransposeToAddGraphTest, forward_add_NEG) +{ + _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1}); + + _graph.transpose()->perm(_graph.input()); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} + +TEST_F(ForwardBothTransposeToMulGraphTest, forward_mul) +{ + _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1}); + + run_pass(); + + auto transpose = dynamic_cast(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto mul = dynamic_cast(transpose->a()); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(4, mul->rank()); + EXPECT_EQ(1, mul->dim(0).value()); + EXPECT_EQ(64, mul->dim(1).value()); + EXPECT_EQ(51, mul->dim(2).value()); + EXPECT_EQ(1, mul->dim(3).value()); +} + +TEST_F(ForwardBothTransposeToMulGraphTest, forward_mul_NEG) +{ + _graph.init({1, 64, 51, 1}, {1, 1, 51, 64}, {0, 3, 2, 1}); + + _graph.transpose()->perm(_graph.input()); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} + TEST_F(ForwardTransposeToMulGraphTest, forward_mul_xy) { _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});