Skip to content

Commit

Permalink
[luci/pass] Support MirrorPad in ConvertNCHWToNHWC (#13511)
Browse files Browse the repository at this point in the history
This supports MirrorPad in ConvertNCHWToNHWC.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Jul 26, 2024
1 parent 09b3fff commit cdd5a91
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
54 changes: 54 additions & 0 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,34 @@ bool is_NCHW(const luci::CirclePadV2 *node)
return true;
}

// NOTE Copied from is_NCHW(CirclePad)
bool is_NCHW(const luci::CircleMirrorPad *node)
{
const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
// Non-const paddings is not supported
if (paddings == nullptr)
return false;

if (paddings->rank() != 2)
return false;

if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
return false;

// Only check the first two dimensions
for (uint32_t dim = 0; dim < 2; dim++)
{
for (uint32_t i = 0; i < 2; i++)
{
auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
if (data != 0)
return false;
}
}

return true;
}

bool is_const(const loco::Node *node)
{
if (not dynamic_cast<const luci::CircleConst *>(node))
Expand Down Expand Up @@ -1000,6 +1028,31 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}

bool visit(luci::CircleMirrorPad *node)
{
if (!is_NCHW(node))
return false;

const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
auto pre_trans = create_pre_transpose(node);
pre_trans->a(pred_node);
node->input(pre_trans);

auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
node->paddings(nhwc_paddings);

// Do shape inference for this node again.
node->shape_status(luci::ShapeStatus::UNDEFINED);

auto post_trans = create_post_transpose(node);
loco::replace(node).with(post_trans);

post_trans->a(node);

return true;
}

bool visit(luci::CircleMul *node)
{
LOGGER(l);
Expand Down Expand Up @@ -1532,6 +1585,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
case luci::CircleOpcode::MAXIMUM:
case luci::CircleOpcode::MEAN:
case luci::CircleOpcode::MINIMUM:
case luci::CircleOpcode::MIRROR_PAD:
case luci::CircleOpcode::MUL:
case luci::CircleOpcode::NEG:
case luci::CircleOpcode::PAD:
Expand Down
83 changes: 83 additions & 0 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,55 @@ class MinimumGraph final : public SimpleGraph
luci::CircleConst *limit = nullptr;
};

class MirrorPadGraph final : public SimpleGraph
{
protected:
loco::Node *insertGraphBody(loco::Node *input) override
{
pad = g.nodes()->create<luci::CircleMirrorPad>();
paddings = g.nodes()->create<luci::CircleConst>();

pad->dtype(loco::DataType::FLOAT32);
paddings->dtype(loco::DataType::S32);

uint32_t channel_size = 16;
pad->shape({1, channel_size, 4, 4});
paddings->shape({4, 2});

pad->mode(luci::MirrorPadMode::REFLECT);

// paddings data (NCHW)
// [[0,0], [0,0], [1,1], [2,2]]
paddings->size<loco::DataType::S32>(8);
for (uint32_t dim = 0; dim < 4; dim++)
{
for (uint32_t i = 0; i < 2; i++)
{
int32_t data = 0;

if (dim == 2)
data = 1;
else if (dim == 3)
data = 2;

paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
}
}

pad->input(input);
pad->paddings(paddings);

pad->name("pad");
paddings->name("paddings");

return pad;
}

public:
luci::CircleMirrorPad *pad = nullptr;
luci::CircleConst *paddings = nullptr;
};

class MulGraph final : public SimpleGraph
{
protected:
Expand Down Expand Up @@ -1606,6 +1655,40 @@ TEST(ConvertNCHWToNHWC, Minimum_non_scalar_NEG)
EXPECT_FALSE(pass.run(&g.g));
}

TEST(ConvertNCHWToNHWC, MirrorPad)
{
MirrorPadGraph g;
g.init();

run_phase(&g.g, false, false);

auto input_succs = loco::succs(g.input);
EXPECT_EQ(1, input_succs.size());
check_post_trans(*input_succs.begin());

check_pre_trans(g.pad->input());

auto pad_succs = loco::succs(g.pad);
EXPECT_EQ(1, pad_succs.size());
check_post_trans(*pad_succs.begin());

auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
EXPECT_NE(nullptr, new_paddings);
EXPECT_EQ(2, new_paddings->rank());
EXPECT_EQ(4, new_paddings->dim(0).value());
EXPECT_EQ(2, new_paddings->dim(1).value());
EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));

check_pre_trans(g.output->from());
}

TEST(ConvertNCHWToNHWC, Mul)
{
MulGraph g;
Expand Down

0 comments on commit cdd5a91

Please sign in to comment.