Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Canonicalize PadV2 paddings #13542

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions compiler/luci/pass/src/CanonicalizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,51 @@ bool paddings_to_s32(luci::CirclePad *pad)
return true;
}

/**
* Convert S64 CircleConst paddings to S32
*/
bool paddings_to_s32(luci::CirclePadV2 *padv2)
{
// check conditions
auto paddings = dynamic_cast<luci::CircleConst *>(padv2->paddings());
CHECK_OR_FALSE(paddings);
CHECK_OR_FALSE(paddings->dtype() == loco::DataType::S64);

// TODO relocate to helpers/CreateCircleConst.h when necessary
auto num_elements = paddings->size<loco::DataType::S64>();
auto hval = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
auto lval = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
for (uint32_t i = 0; i < num_elements; i++)
{
auto v64 = paddings->at<loco::DataType::S64>(i);
CHECK_OR_FALSE(v64 <= hval);
CHECK_OR_FALSE(v64 >= lval);
}

auto paddings_s32 = padv2->graph()->nodes()->create<luci::CircleConst>();
paddings_s32->name(paddings->name() + "_S32");
paddings_s32->dtype(loco::DataType::S32);
paddings_s32->rank(paddings->rank());
for (uint32_t i = 0; i < paddings->rank(); i++)
paddings_s32->dim(i).set(paddings->dim(i).value());
paddings_s32->shape_status(luci::ShapeStatus::VALID);
luci::add_origin(paddings_s32, luci::get_origin(paddings));

paddings_s32->size<loco::DataType::S32>(num_elements);
for (uint32_t i = 0; i < num_elements; i++)
{
auto v64 = paddings->at<loco::DataType::S64>(i);
paddings_s32->at<loco::DataType::S32>(i) = static_cast<int32_t>(v64);
}

// replace paddings with S32 dtype
padv2->paddings(paddings_s32);

return true;
}

// TODO merge both paddings_to_s32 with template
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask what the meaning of both paddings in this comment is?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are two paddings_to_s32(), one for Pad and another for PadV2.
changes in this PR for PadV2 is copied from Pad.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!


} // namespace

namespace luci
Expand All @@ -91,6 +136,11 @@ bool CanonicalizePass::run(loco::Graph *g)
if (paddings_to_s32(pad))
changed = true;
}
else if (auto padv2 = dynamic_cast<luci::CirclePadV2 *>(node))
{
if (paddings_to_s32(padv2))
changed = true;
}

// TODO add more canonicalization
}
Expand Down
74 changes: 74 additions & 0 deletions compiler/luci/pass/src/CanonicalizePass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ struct PadGraphlet
void init(loco::Graph *g)
{
_pad = g->nodes()->create<luci::CirclePad>();
_padv2 = g->nodes()->create<luci::CirclePadV2>();
_paddings_s32 = g->nodes()->create<luci::CircleConst>();
_paddings_s64 = g->nodes()->create<luci::CircleConst>();
// NOTE PadV2.constant_values is not set as test doesn't use this

_pad->name("pad");
_padv2->name("padv2");
_paddings_s32->name("paddings_s32");
_paddings_s64->name("paddings_s64");

Expand Down Expand Up @@ -75,6 +78,7 @@ struct PadGraphlet
}

luci::CirclePad *_pad = nullptr;
luci::CirclePadV2 *_padv2 = nullptr;
luci::CircleConst *_paddings_s32 = nullptr;
luci::CircleConst *_paddings_s64 = nullptr;
};
Expand All @@ -96,6 +100,23 @@ class CanonicalizePadTestGraph : public TestIOGraph, public PadGraphlet
}
};

class CanonicalizePadV2TestGraph : public TestIOGraph, public PadGraphlet
{
public:
CanonicalizePadV2TestGraph() = default;

void init(void)
{
TestIOGraph::init({1, 3, 3, 2}, {1, 5, 5, 2});
PadGraphlet::init(g());

_padv2->input(input());
_padv2->paddings(_paddings_s64);

output()->from(_padv2);
}
};

} // namespace

TEST(CanonicalizePassPadTest, paddings_64_to_32)
Expand Down Expand Up @@ -150,3 +171,56 @@ TEST(CanonicalizePassPadTest, paddings_32_over_NEG)
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);
}

TEST(CanonicalizePassPadV2Test, paddings_64_to_32)
{
CanonicalizePadV2TestGraph g;
luci::CanonicalizePass pass;

g.init();

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);

EXPECT_TRUE(pass.run(g.g()));

paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);
}

TEST(CanonicalizePassPadV2Test, paddings_32_NEG)
{
CanonicalizePadV2TestGraph g;
luci::CanonicalizePass pass;

g.init();
g._padv2->paddings(g._paddings_s32);

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);

EXPECT_FALSE(pass.run(g.g()));

paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S32);
}

TEST(CanonicalizePassPadV2Test, paddings_32_over_NEG)
{
CanonicalizePadV2TestGraph g;
luci::CanonicalizePass pass;

g.init();
g._paddings_s64->at<loco::DataType::S64>(2) =
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 100;

EXPECT_FALSE(pass.run(g.g()));

luci::CircleConst *paddings = dynamic_cast<luci::CircleConst *>(g._padv2->paddings());
EXPECT_NE(nullptr, paddings);
EXPECT_EQ(paddings->dtype(), loco::DataType::S64);
}