diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
index b44a0a46aee..01c8ba8380d 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -609,6 +609,47 @@ bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_nod
   return true;
 }
 
+// Find a DIV that has exactly one constant input and a rank-4 non-constant input.
+//
+// On success pred_node is the non-constant input and const_node is the constant
+// input after expand_to_rank_4(). The side (x or y) the constant was on is not
+// reported; DIV is not commutative, so the caller has to keep the operand order.
+// NOTE pred_node and const_node may already be overwritten when false is returned.
+//
+// TODO Merge this function with other elementwise Ops
+bool is_NCHW_with_const(const luci::CircleDiv *node, luci::CircleNode *&pred_node,
+                        luci::CircleConst *&const_node)
+{
+  auto x = dynamic_cast<luci::CircleConst *>(node->x());
+  auto y = dynamic_cast<luci::CircleConst *>(node->y());
+
+  if (x != nullptr && y == nullptr)
+  {
+    pred_node = loco::must_cast<luci::CircleNode *>(node->y());
+    const_node = x;
+  }
+  else if (x == nullptr && y != nullptr)
+  {
+    pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+    const_node = y;
+  }
+  else
+  {
+    // Ignore if DIV does not have a const_node input.
+    return false;
+  }
+
+  if (pred_node->rank() != 4)
+    return false;
+
+  if (not broadcastable(const_node, node))
+    return false;
+
+  const_node = expand_to_rank_4(const_node);
+
+  return true;
+}
+
 // We assume ADD with const input is NCHW if,
 // Input shape: (N, C, H, W)
 // Output shape: (N, C, H, W)
@@ -872,6 +913,76 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
     return true;
   }
 
+  // Convert DIV to NHWC: non-constant inputs get a pre-transpose, a constant input
+  // is rebuilt with create_NHWC_from_NCHW(), and the output gets a post-transpose.
+  // Every "return false" below happens before the inputs of the node are rewired.
+  bool visit(luci::CircleDiv *node)
+  {
+    LOGGER(l);
+
+    luci::CircleNode *pred_node = nullptr;
+    luci::CircleConst *constant = nullptr;
+
+    if (is_NCHW_with_const(node, pred_node, constant))
+    {
+      assert(constant->rank() == 4); // FIX is_NCHW_with_const unless
+      auto nhwc_const = create_NHWC_from_NCHW(constant);
+      if (nhwc_const == nullptr)
+        return false;
+
+      // DIV is not commutative, so each operand must stay on its original side.
+      // Record the side before any input of the node is rewired.
+      const bool pred_is_x = (node->x() == pred_node);
+
+      auto pre_trans = create_pre_transpose(node);
+      pre_trans->a(pred_node);
+
+      if (pred_is_x)
+      {
+        node->x(pre_trans);
+        node->y(nhwc_const);
+      }
+      else
+      {
+        node->x(nhwc_const);
+        node->y(pre_trans);
+      }
+    }
+    else if (constant == nullptr)
+    {
+      // No single constant input was found; handle both inputs as features.
+      // Only support for input rank 4
+      auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
+      if (input_x->rank() != 4)
+        return false;
+      auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
+      if (input_y->rank() != 4)
+        return false;
+
+      auto pre_trans_x = create_pre_transpose(node);
+      pre_trans_x->a(input_x);
+      node->x(pre_trans_x);
+
+      auto pre_trans_y = create_pre_transpose(node);
+      pre_trans_y->a(input_y);
+      node->y(pre_trans_y);
+    }
+    else
+    {
+      // A constant input exists but is_NCHW_with_const() rejected the pattern.
+      return false;
+    }
+
+    // Do shape inference for this node again.
+    node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+    auto post_trans = create_post_transpose(node);
+    loco::replace(node).with(post_trans);
+
+    post_trans->a(node);
+    return true;
+  }
+
   bool visit(luci::CircleElu *node) { return convert_unary_features(node); }
 
   bool visit(luci::CircleGelu *node) { return convert_unary_features(node); }
@@ -1578,6 +1689,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
       // tflite/circle assumes the last channel is always axis
       case luci::CircleOpcode::ADD:
       case luci::CircleOpcode::CONCATENATION:
+      case luci::CircleOpcode::DIV:
       case luci::CircleOpcode::ELU:
       case luci::CircleOpcode::GELU:
       case luci::CircleOpcode::LEAKY_RELU:
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index 82f3e1f3d94..4254f7fe42d 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -435,6 +435,66 @@ class ConcatenationGraph final : public SimpleGraph
   luci::CircleConst *input2 = nullptr;
 };
 
+// Graph body: a single DIV whose x is the node passed to insertGraphBody() and
+// whose y is a FLOAT32 constant of shape (1, C, 1, 1) with C = 16.
+class DivGraph final : public SimpleGraph
+{
+protected:
+  loco::Node *insertGraphBody(loco::Node *input) override
+  {
+    div = g.nodes()->create<luci::CircleDiv>();
+    constant = g.nodes()->create<luci::CircleConst>();
+
+    div->dtype(loco::DataType::FLOAT32);
+    constant->dtype(loco::DataType::FLOAT32);
+
+    uint32_t channel_size = 16;
+    div->shape({1, channel_size, 4, 4});
+    constant->shape({1, channel_size, 1, 1});
+
+    constant->size<loco::DataType::FLOAT32>(channel_size);
+    for (uint32_t i = 0; i < channel_size; i++)
+    {
+      constant->at<loco::DataType::FLOAT32>(i) = i;
+    }
+
+    div->x(input);
+    div->y(constant);
+
+    div->name("div");
+    constant->name("constant");
+
+    return div;
+  }
+
+public:
+  // Reshape the constant to (1, C, 4, 4).
+  // NOTE Only the first channel_size elements are assigned here.
+  void update_const_shape_to_nchw(void)
+  {
+    uint32_t channel_size = 16;
+    constant->shape({1, channel_size, 4, 4});
+
+    constant->size<loco::DataType::FLOAT32>(channel_size * 4 * 4);
+    for (uint32_t i = 0; i < channel_size; i++)
+    {
+      constant->at<loco::DataType::FLOAT32>(i) = i;
+    }
+  }
+
+  // Exchange the x and y inputs of DIV, so that the constant becomes the numerator.
+  void swap_operands(void)
+  {
+    auto x = div->x();
+    div->x(div->y());
+    div->y(x);
+  }
+
+public:
+  luci::CircleDiv *div = nullptr;
+  luci::CircleConst *constant = nullptr;
+};
+
 class EluGraph final : public SimpleGraph
 {
 protected:
@@ -1382,6 +1442,62 @@ TEST(ConvertNCHWToNHWC, Concatenation)
   EXPECT_EQ(3, g.concat->axis());
 }
 
+TEST(ConvertNCHWToNHWC, Div)
+{
+  DivGraph g;
+  g.init();
+
+  run_phase(&g.g, false, false);
+
+  auto input_succs = loco::succs(g.input);
+  EXPECT_EQ(1, input_succs.size());
+  check_post_trans(*input_succs.begin());
+
+  check_pre_trans(g.div->x());
+
+  auto div_succs = loco::succs(g.div);
+  EXPECT_EQ(1, div_succs.size());
+  check_post_trans(*div_succs.begin());
+
+  uint32_t channel_size = 16;
+  auto new_constant = dynamic_cast<luci::CircleConst *>(g.div->y());
+  ASSERT_NE(nullptr, new_constant);
+  EXPECT_EQ(4, new_constant->rank());
+  EXPECT_EQ(1, new_constant->dim(0).value());
+  EXPECT_EQ(1, new_constant->dim(1).value());
+  EXPECT_EQ(1, new_constant->dim(2).value());
+  EXPECT_EQ(channel_size, new_constant->dim(3).value());
+
+  check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, Div_const_numerator)
+{
+  DivGraph g;
+  g.init();
+  g.swap_operands();
+
+  run_phase(&g.g, false, false);
+
+  // DIV is not commutative: the constant has to stay on x and the input on y.
+  check_pre_trans(g.div->y());
+
+  auto div_succs = loco::succs(g.div);
+  EXPECT_EQ(1, div_succs.size());
+  check_post_trans(*div_succs.begin());
+
+  uint32_t channel_size = 16;
+  auto new_constant = dynamic_cast<luci::CircleConst *>(g.div->x());
+  ASSERT_NE(nullptr, new_constant);
+  EXPECT_EQ(4, new_constant->rank());
+  EXPECT_EQ(1, new_constant->dim(0).value());
+  EXPECT_EQ(1, new_constant->dim(1).value());
+  EXPECT_EQ(1, new_constant->dim(2).value());
+  EXPECT_EQ(channel_size, new_constant->dim(3).value());
+
+  check_pre_trans(g.output->from());
+}
+
 TEST(ConvertNCHWToNHWC, Elu)
 {
   EluGraph g;