Skip to content

Commit

Permalink
[luci/pass] Support Div in ConvertNCHWToNHWC (#13525)
Browse files Browse the repository at this point in the history
* [luci/pass] Support Div in ConvertNCHWToNHWC

This supports Div in ConvertNCHWToNHWC.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>

* Update comment
  • Loading branch information
jinevening authored Jul 29, 2024
1 parent 3a8110a commit 237bf23
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 0 deletions.
87 changes: 87 additions & 0 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,40 @@ bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_nod
return true;
}

// TODO Merge this function with other elementwise Ops
bool is_NCHW_with_const(const luci::CircleDiv *node, luci::CircleNode *&pred_node,
luci::CircleConst *&const_node)
{
auto x = dynamic_cast<luci::CircleConst *>(node->x());
auto y = dynamic_cast<luci::CircleConst *>(node->y());

if (x != nullptr && y == nullptr)
{
pred_node = loco::must_cast<luci::CircleNode *>(node->y());
const_node = x;
}
else if (x == nullptr && y != nullptr)
{
pred_node = loco::must_cast<luci::CircleNode *>(node->x());
const_node = y;
}
else
{
// Ignore if DIV does not have a const_node input.
return false;
}

if (pred_node->rank() != 4)
return false;

if (not broadcastable(const_node, node))
return false;

const_node = expand_to_rank_4(const_node);

return true;
}

// We assume ADD with const input is NCHW if,
// Input shape: (N, C, H, W)
// Output shape: (N, C, H, W)
Expand Down Expand Up @@ -832,6 +866,58 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}

bool visit(luci::CircleDiv *node)
{
LOGGER(l);

luci::CircleNode *pred_node = nullptr;
luci::CircleConst *constant = nullptr;

if (is_NCHW_with_const(node, pred_node, constant))
{
assert(constant->rank() == 4); // FIX is_NCHW_with_const unless
auto nhwc_const = create_NHWC_from_NCHW(constant);
if (nhwc_const == nullptr)
return false;
node->y(nhwc_const);

auto pre_trans = create_pre_transpose(node);
pre_trans->a(pred_node);
node->x(pre_trans);
}
else if (constant == nullptr)
{
// Only support for input rank 4
auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
if (input_x->rank() != 4)
return false;
auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
if (input_y->rank() != 4)
return false;

auto pre_trans_x = create_pre_transpose(node);
pre_trans_x->a(input_x);
node->x(pre_trans_x);

auto pre_trans_y = create_pre_transpose(node);
pre_trans_y->a(input_y);
node->y(pre_trans_y);
}
else
{
return false;
}

// Do shape inference for this node again.
node->shape_status(luci::ShapeStatus::UNDEFINED);

auto post_trans = create_post_transpose(node);
loco::replace(node).with(post_trans);

post_trans->a(node);
return true;
}

bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }

bool visit(luci::CircleGelu *node) { return convert_unary_features<luci::CircleGelu>(node); }
Expand Down Expand Up @@ -1538,6 +1624,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
// tflite/circle assumes the last channel is always axis
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
case luci::CircleOpcode::DIV:
case luci::CircleOpcode::ELU:
case luci::CircleOpcode::GELU:
case luci::CircleOpcode::LEAKY_RELU:
Expand Down
77 changes: 77 additions & 0 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,54 @@ class ConcatenationGraph final : public SimpleGraph
luci::CircleConst *input2 = nullptr;
};

class DivGraph final : public SimpleGraph
{
protected:
loco::Node *insertGraphBody(loco::Node *input) override
{
div = g.nodes()->create<luci::CircleDiv>();
constant = g.nodes()->create<luci::CircleConst>();

div->dtype(loco::DataType::FLOAT32);
constant->dtype(loco::DataType::FLOAT32);

uint32_t channel_size = 16;
div->shape({1, channel_size, 4, 4});
constant->shape({1, channel_size, 1, 1});

constant->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
{
constant->at<loco::DataType::FLOAT32>(i) = i;
}

div->x(input);
div->y(constant);

div->name("div");
constant->name("constant");

return div;
}

public:
void update_const_shape_to_nchw(void)
{
uint32_t channel_size = 16;
constant->shape({1, channel_size, 4, 4});

constant->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
for (uint32_t i = 0; i < channel_size; i++)
{
constant->at<loco::DataType::FLOAT32>(i) = i;
}
}

public:
luci::CircleDiv *div = nullptr;
luci::CircleConst *constant = nullptr;
};

class EluGraph final : public SimpleGraph
{
protected:
Expand Down Expand Up @@ -1382,6 +1430,35 @@ TEST(ConvertNCHWToNHWC, Concatenation)
EXPECT_EQ(3, g.concat->axis());
}

TEST(ConvertNCHWToNHWC, Div)
{
DivGraph g;
g.init();

run_phase(&g.g, false, false);

auto input_succs = loco::succs(g.input);
EXPECT_EQ(1, input_succs.size());
check_post_trans(*input_succs.begin());

check_pre_trans(g.div->x());

auto div_succs = loco::succs(g.div);
EXPECT_EQ(1, div_succs.size());
check_post_trans(*div_succs.begin());

uint32_t channel_size = 16;
auto new_constant = dynamic_cast<luci::CircleConst *>(g.div->y());
EXPECT_NE(nullptr, new_constant);
EXPECT_EQ(4, new_constant->rank());
EXPECT_EQ(1, new_constant->dim(0).value());
EXPECT_EQ(1, new_constant->dim(1).value());
EXPECT_EQ(1, new_constant->dim(2).value());
EXPECT_EQ(channel_size, new_constant->dim(3).value());

check_pre_trans(g.output->from());
}

TEST(ConvertNCHWToNHWC, Elu)
{
EluGraph g;
Expand Down

0 comments on commit 237bf23

Please sign in to comment.