Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Support Div in ConvertNCHWToNHWC #13525

Merged
merged 2 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,40 @@ bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_nod
return true;
}

// TODO Merge this function with other elementwise Ops
bool is_NCHW_with_const(const luci::CircleDiv *node, luci::CircleNode *&pred_node,
luci::CircleConst *&const_node)
{
auto x = dynamic_cast<luci::CircleConst *>(node->x());
auto y = dynamic_cast<luci::CircleConst *>(node->y());

if (x != nullptr && y == nullptr)
{
pred_node = loco::must_cast<luci::CircleNode *>(node->y());
const_node = x;
}
else if (x == nullptr && y != nullptr)
{
pred_node = loco::must_cast<luci::CircleNode *>(node->x());
const_node = y;
}
else
{
// Ignore if MUL does not have a const_node input.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Ignore if MUL does not have a const_node input.
// Ignore if DIV does not have a const_node input.

return false;
}

if (pred_node->rank() != 4)
return false;

if (not broadcastable(const_node, node))
return false;

const_node = expand_to_rank_4(const_node);

return true;
}

// We assume ADD with const input is NCHW if,
// Input shape: (N, C, H, W)
// Output shape: (N, C, H, W)
Expand Down Expand Up @@ -872,6 +906,58 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}

bool visit(luci::CircleDiv *node)
{
LOGGER(l);

luci::CircleNode *pred_node = nullptr;
luci::CircleConst *constant = nullptr;

if (is_NCHW_with_const(node, pred_node, constant))
{
assert(constant->rank() == 4); // FIX is_NCHW_with_const unless
auto nhwc_const = create_NHWC_from_NCHW(constant);
if (nhwc_const == nullptr)
return false;
node->y(nhwc_const);

auto pre_trans = create_pre_transpose(node);
pre_trans->a(pred_node);
node->x(pre_trans);
}
else if (constant == nullptr)
{
// Only support for input rank 4
auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
if (input_x->rank() != 4)
return false;
auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
if (input_y->rank() != 4)
return false;

auto pre_trans_x = create_pre_transpose(node);
pre_trans_x->a(input_x);
node->x(pre_trans_x);

auto pre_trans_y = create_pre_transpose(node);
pre_trans_y->a(input_y);
node->y(pre_trans_y);
}
else
{
return false;
}

// Do shape inference for this node again.
node->shape_status(luci::ShapeStatus::UNDEFINED);

auto post_trans = create_post_transpose(node);
loco::replace(node).with(post_trans);

post_trans->a(node);
return true;
}

bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }

bool visit(luci::CircleGelu *node) { return convert_unary_features<luci::CircleGelu>(node); }
Expand Down Expand Up @@ -1578,6 +1664,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
// tflite/circle assumes the last channel is always axis
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
case luci::CircleOpcode::DIV:
case luci::CircleOpcode::ELU:
case luci::CircleOpcode::GELU:
case luci::CircleOpcode::LEAKY_RELU:
Expand Down
77 changes: 77 additions & 0 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,54 @@ class ConcatenationGraph final : public SimpleGraph
luci::CircleConst *input2 = nullptr;
};

class DivGraph final : public SimpleGraph
{
protected:
loco::Node *insertGraphBody(loco::Node *input) override
{
div = g.nodes()->create<luci::CircleDiv>();
constant = g.nodes()->create<luci::CircleConst>();

div->dtype(loco::DataType::FLOAT32);
constant->dtype(loco::DataType::FLOAT32);

uint32_t channel_size = 16;
div->shape({1, channel_size, 4, 4});
constant->shape({1, channel_size, 1, 1});

constant->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
{
constant->at<loco::DataType::FLOAT32>(i) = i;
}

div->x(input);
div->y(constant);

div->name("div");
constant->name("constant");

return div;
}

public:
void update_const_shape_to_nchw(void)
{
uint32_t channel_size = 16;
constant->shape({1, channel_size, 4, 4});

constant->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
for (uint32_t i = 0; i < channel_size; i++)
{
constant->at<loco::DataType::FLOAT32>(i) = i;
}
}

public:
luci::CircleDiv *div = nullptr;
luci::CircleConst *constant = nullptr;
};

class EluGraph final : public SimpleGraph
{
protected:
Expand Down Expand Up @@ -1382,6 +1430,35 @@ TEST(ConvertNCHWToNHWC, Concatenation)
EXPECT_EQ(3, g.concat->axis());
}

TEST(ConvertNCHWToNHWC, Div)
{
DivGraph g;
g.init();

run_phase(&g.g, false, false);

auto input_succs = loco::succs(g.input);
EXPECT_EQ(1, input_succs.size());
check_post_trans(*input_succs.begin());

check_pre_trans(g.div->x());

auto div_succs = loco::succs(g.div);
EXPECT_EQ(1, div_succs.size());
check_post_trans(*div_succs.begin());

uint32_t channel_size = 16;
auto new_constant = dynamic_cast<luci::CircleConst *>(g.div->y());
EXPECT_NE(nullptr, new_constant);
EXPECT_EQ(4, new_constant->rank());
EXPECT_EQ(1, new_constant->dim(0).value());
EXPECT_EQ(1, new_constant->dim(1).value());
EXPECT_EQ(1, new_constant->dim(2).value());
EXPECT_EQ(channel_size, new_constant->dim(3).value());

check_pre_trans(g.output->from());
}

TEST(ConvertNCHWToNHWC, Elu)
{
EluGraph g;
Expand Down