From ad2feea7c4be15fa3d5365de02105c815ec7e8be Mon Sep 17 00:00:00 2001 From: SaeHie Park Date: Mon, 29 Jul 2024 06:23:16 +0000 Subject: [PATCH] [luci/pass] Canonicalize PadV2 paddings This will enable Canonicalize PadV2 paddings dtype. ONE-DCO-1.0-Signed-off-by: SaeHie Park --- compiler/luci/pass/src/CanonicalizePass.cpp | 50 +++++++++++++ .../luci/pass/src/CanonicalizePass.test.cpp | 74 +++++++++++++++++++ 2 files changed, 124 insertions(+) diff --git a/compiler/luci/pass/src/CanonicalizePass.cpp b/compiler/luci/pass/src/CanonicalizePass.cpp index ba774656782..d4959da7e08 100644 --- a/compiler/luci/pass/src/CanonicalizePass.cpp +++ b/compiler/luci/pass/src/CanonicalizePass.cpp @@ -73,6 +73,51 @@ bool paddings_to_s32(luci::CirclePad *pad) return true; } +/** + * Convert S64 CircleConst paddings to S32 + */ +bool paddings_to_s32(luci::CirclePadV2 *padv2) +{ + // check conditions + auto paddings = dynamic_cast(padv2->paddings()); + CHECK_OR_FALSE(paddings); + CHECK_OR_FALSE(paddings->dtype() == loco::DataType::S64); + + // TODO relocate to helpers/CreateCircleConst.h when necessary + auto num_elements = paddings->size(); + auto hval = static_cast(std::numeric_limits::max()); + auto lval = static_cast(std::numeric_limits::lowest()); + for (uint32_t i = 0; i < num_elements; i++) + { + auto v64 = paddings->at(i); + CHECK_OR_FALSE(v64 <= hval); + CHECK_OR_FALSE(v64 >= lval); + } + + auto paddings_s32 = padv2->graph()->nodes()->create(); + paddings_s32->name(paddings->name() + "_S32"); + paddings_s32->dtype(loco::DataType::S32); + paddings_s32->rank(paddings->rank()); + for (uint32_t i = 0; i < paddings->rank(); i++) + paddings_s32->dim(i).set(paddings->dim(i).value()); + paddings_s32->shape_status(luci::ShapeStatus::VALID); + luci::add_origin(paddings_s32, luci::get_origin(paddings)); + + paddings_s32->size(num_elements); + for (uint32_t i = 0; i < num_elements; i++) + { + auto v64 = paddings->at(i); + paddings_s32->at(i) = static_cast(v64); + } + + // replace paddings with S32 dtype + padv2->paddings(paddings_s32); + + return true; +} + +// TODO merge both paddings_to_s32 with template + } // namespace namespace luci @@ -91,6 +136,11 @@ bool CanonicalizePass::run(loco::Graph *g) if (paddings_to_s32(pad)) changed = true; } + else if (auto padv2 = dynamic_cast(node)) + { + if (paddings_to_s32(padv2)) + changed = true; + } // TODO add more canonicalization } diff --git a/compiler/luci/pass/src/CanonicalizePass.test.cpp b/compiler/luci/pass/src/CanonicalizePass.test.cpp index 3e87935fbdb..8640bbb6d97 100644 --- a/compiler/luci/pass/src/CanonicalizePass.test.cpp +++ b/compiler/luci/pass/src/CanonicalizePass.test.cpp @@ -43,10 +43,13 @@ struct PadGraphlet void init(loco::Graph *g) { _pad = g->nodes()->create(); + _padv2 = g->nodes()->create(); _paddings_s32 = g->nodes()->create(); _paddings_s64 = g->nodes()->create(); + // NOTE PadV2.constant_values is not set as test doesn't use this _pad->name("pad"); + _padv2->name("padv2"); _paddings_s32->name("paddings_s32"); _paddings_s64->name("paddings_s64"); @@ -75,6 +78,7 @@ struct PadGraphlet } luci::CirclePad *_pad = nullptr; + luci::CirclePadV2 *_padv2 = nullptr; luci::CircleConst *_paddings_s32 = nullptr; luci::CircleConst *_paddings_s64 = nullptr; }; @@ -96,6 +100,23 @@ class CanonicalizePadTestGraph : public TestIOGraph, public PadGraphlet } }; +class CanonicalizePadV2TestGraph : public TestIOGraph, public PadGraphlet +{ +public: + CanonicalizePadV2TestGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 3, 3, 2}, {1, 5, 5, 2}); + PadGraphlet::init(g()); + + _padv2->input(input()); + _padv2->paddings(_paddings_s64); + + output()->from(_padv2); + } +}; + } // namespace TEST(CanonicalizePassPadTest, paddings_64_to_32) @@ -150,3 +171,56 @@ TEST(CanonicalizePassPadTest, paddings_32_over_NEG) EXPECT_NE(nullptr, paddings); EXPECT_EQ(paddings->dtype(), loco::DataType::S64); } + +TEST(CanonicalizePassPadV2Test, paddings_64_to_32) +{ + CanonicalizePadV2TestGraph g; + luci::CanonicalizePass pass; + + g.init(); + + luci::CircleConst *paddings = dynamic_cast(g._padv2->paddings()); + EXPECT_NE(nullptr, paddings); + EXPECT_EQ(paddings->dtype(), loco::DataType::S64); + + EXPECT_TRUE(pass.run(g.g())); + + paddings = dynamic_cast(g._padv2->paddings()); + EXPECT_NE(nullptr, paddings); + EXPECT_EQ(paddings->dtype(), loco::DataType::S32); +} + +TEST(CanonicalizePassPadV2Test, paddings_32_NEG) +{ + CanonicalizePadV2TestGraph g; + luci::CanonicalizePass pass; + + g.init(); + g._padv2->paddings(g._paddings_s32); + + luci::CircleConst *paddings = dynamic_cast(g._padv2->paddings()); + EXPECT_NE(nullptr, paddings); + EXPECT_EQ(paddings->dtype(), loco::DataType::S32); + + EXPECT_FALSE(pass.run(g.g())); + + paddings = dynamic_cast(g._padv2->paddings()); + EXPECT_NE(nullptr, paddings); + EXPECT_EQ(paddings->dtype(), loco::DataType::S32); +} + +TEST(CanonicalizePassPadV2Test, paddings_32_over_NEG) +{ + CanonicalizePadV2TestGraph g; + luci::CanonicalizePass pass; + + g.init(); + g._paddings_s64->at(2) = + static_cast(std::numeric_limits::max()) + 100; + + EXPECT_FALSE(pass.run(g.g())); + + luci::CircleConst *paddings = dynamic_cast(g._padv2->paddings()); + EXPECT_NE(nullptr, paddings); + EXPECT_EQ(paddings->dtype(), loco::DataType::S64); +}