diff --git a/compiler/luci/pass/src/ExpandBroadcastConstPass.cpp b/compiler/luci/pass/src/ExpandBroadcastConstPass.cpp index 25fb9f171eb..ceaacc52f1a 100644 --- a/compiler/luci/pass/src/ExpandBroadcastConstPass.cpp +++ b/compiler/luci/pass/src/ExpandBroadcastConstPass.cpp @@ -47,13 +47,18 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl return nullptr; } + if (successor->rank() == 1 || successor->rank() > 4) + { + WARN(l) << "NYI: Only 2D/3D/4D tensor broadcast removal is supported"; + return nullptr; + } + auto constant = node->graph()->nodes()->create(); constant->name(node->name()); constant->dtype(node->dtype()); constant->rank(node->rank()); constant->shape_status(luci::ShapeStatus::VALID); - uint32_t node_size = node->size(); uint32_t constant_size = 1; for (uint32_t i = 0; i < successor->rank(); ++i) { @@ -65,10 +70,34 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl auto const node_data = &node->at(0); auto const constant_data = &constant->at(0); - auto const successor_depth = successor->dim(successor->rank() - 1).value(); - for (uint32_t d = 0; d < successor_depth; ++d) - std::copy(node_data, node_data + node_size, constant_data + d * node_size); - + if (successor->rank() == 2) + { + auto const N = successor->dim(successor->rank() - 2).value(); + auto const D = successor->dim(successor->rank() - 1).value(); + for (uint32_t n = 0; n < N; ++n) + std::fill_n(constant_data + n * D, D, node_data[n]); + } + else if (successor->rank() == 3) + { + auto const H = successor->dim(successor->rank() - 3).value(); + auto const W = successor->dim(successor->rank() - 2).value(); + auto const D = successor->dim(successor->rank() - 1).value(); + for (uint32_t h = 0; h < H; ++h) + for (uint32_t w = 0; w < W; ++w) + std::fill_n(constant_data + h * W * D + w * D, D, node_data[h * W + w]); + } + else if (successor->rank() == 4) + { + auto const N = successor->dim(successor->rank() - 4).value(); + auto const H = successor->dim(successor->rank() - 3).value(); + auto const W = successor->dim(successor->rank() - 2).value(); + auto const D = successor->dim(successor->rank() - 1).value(); + for (uint32_t n = 0; n < N; ++n) + for (uint32_t h = 0; h < H; ++h) + for (uint32_t w = 0; w < W; ++w) + std::fill_n(constant_data + n * H * W * D + h * W * D + w * D, D, + node_data[n * H * W + h * W + w]); + } return constant; } diff --git a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp index 5df1b72dcd1..1eb985f44e8 100644 --- a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp +++ b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp @@ -23,13 +23,16 @@ #include +/**************************************************************************** + * TESTS FOR RANK 2 + ****************************************************************************/ namespace { -class ExpandBroadcastConstTest : public ::testing::Test +class ExpandBroadcastConstRank2Graph { public: - ExpandBroadcastConstTest() + ExpandBroadcastConstRank2Graph() { _x = _g.nodes()->create(); _y = _g.nodes()->create(); @@ -38,27 +41,114 @@ class ExpandBroadcastConstTest : public ::testing::Test auto graph_input = _g.inputs()->create(); graph_input->dtype(loco::DataType::FLOAT32); - graph_input->shape({1, H, W, D}); + graph_input->shape({N, D}); _x->index(graph_input->index()); _x->dtype(graph_input->dtype()); - _x->shape({1, H, W, D}); + _x->shape({N, D}); + + _y->dtype(loco::DataType::FLOAT32); + _y->shape({N, 1}); + _y->size(N); auto graph_output = _g.outputs()->create(); graph_output->dtype(loco::DataType::FLOAT32); - graph_output->shape({1, H, W, D}); + graph_output->shape({N, D}); _output->index(graph_output->index()); _output->dtype(graph_output->dtype()); - _output->shape({1, H, W, D}); + _output->shape({N, D}); - _y->dtype(loco::DataType::FLOAT32); - _y->shape({1, H, W, 1}); - _y->size(16); + _add->dtype(loco::DataType::FLOAT32); + _add->fusedActivationFunction(luci::FusedActFunc::NONE); + _add->x(_x); + _add->y(_y); + _add->shape({N, D}); + + _output->from(_add); + + _x->name("input"); + _output->name("output"); + } + +protected: + uint32_t const N = 4; + uint32_t const D = 3; + +protected: + loco::Graph _g; + luci::CircleAdd *_add = nullptr; + luci::CircleInput *_x = nullptr; + luci::CircleConst *_y = nullptr; + luci::CircleOutput *_output = nullptr; +}; + +class ExpandBroadcastRank2ConstTest : public ExpandBroadcastConstRank2Graph, public ::testing::Test +{ +public: + ExpandBroadcastRank2ConstTest() {} +}; +} // namespace + +TEST_F(ExpandBroadcastRank2ConstTest, remove_broadcast) +{ + for (uint32_t i = 0; i < N; ++i) + _y->at(i) = static_cast(i); + + luci::ExpandBroadcastConstPass pass; + ASSERT_TRUE(pass.run(&_g)); + + auto broadcasted_const = dynamic_cast(_add->y()); + ASSERT_NE(broadcasted_const, nullptr); + + EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32); + EXPECT_EQ(broadcasted_const->dim(0).value(), N); + EXPECT_EQ(broadcasted_const->dim(1).value(), D); + EXPECT_EQ(broadcasted_const->size(), N * D); + + for (uint32_t i = 0; i < N; ++i) + { + for (uint32_t d = 0; d < D; ++d) + { + EXPECT_NEAR(broadcasted_const->at(i * D + d), static_cast(i), + std::numeric_limits::min()); + } + } +} + +/**************************************************************************** + * TESTS FOR RANK 4 + ****************************************************************************/ + +namespace +{ +class ExpandBroadcastConstRank4Graph +{ +public: + ExpandBroadcastConstRank4Graph() + { + _x = _g.nodes()->create(); + _y = _g.nodes()->create(); + _add = _g.nodes()->create(); + _output = _g.nodes()->create(); + + auto graph_input = _g.inputs()->create(); + graph_input->dtype(loco::DataType::FLOAT32); + graph_input->shape({N, H, W, D}); + _x->index(graph_input->index()); + _x->dtype(graph_input->dtype()); + _x->shape({N, H, W, D}); + + auto graph_output = _g.outputs()->create(); + graph_output->dtype(loco::DataType::FLOAT32); + graph_output->shape({N, H, W, D}); + _output->index(graph_output->index()); + _output->dtype(graph_output->dtype()); + _output->shape({N, H, W, D}); _add->dtype(loco::DataType::FLOAT32); _add->fusedActivationFunction(luci::FusedActFunc::NONE); _add->x(_x); _add->y(_y); - _add->shape({1, H, W, D}); + _add->shape({N, H, W, D}); _output->from(_add); @@ -67,6 +157,7 @@ class ExpandBroadcastConstTest : public ::testing::Test } protected: + uint32_t const N = 2; uint32_t const H = 4; uint32_t const W = 4; uint32_t const D = 3; @@ -79,18 +170,30 @@ class ExpandBroadcastConstTest : public ::testing::Test luci::CircleOutput *_output = nullptr; }; +class ExpandBroadcastRank4ConstTest1 : public ExpandBroadcastConstRank4Graph, public ::testing::Test +{ +public: + ExpandBroadcastRank4ConstTest1() + { + _y->dtype(loco::DataType::FLOAT32); + _y->shape({N, H, W, 1}); + _y->size(N * H * W * 1); + } +}; + +// TODO: Add more tests for Rank4 with different broadcasting dimensions } // namespace -TEST_F(ExpandBroadcastConstTest, name) +TEST_F(ExpandBroadcastRank4ConstTest1, name) { luci::ExpandBroadcastConstPass pass; auto const name = pass.name(); ASSERT_NE(nullptr, name); } -TEST_F(ExpandBroadcastConstTest, remove_broadcast) +TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast) { - for (uint32_t i = 0; i < H * W; ++i) + for (uint32_t i = 0; i < N * H * W; ++i) _y->at(i) = static_cast(i); luci::ExpandBroadcastConstPass pass; @@ -100,26 +203,27 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast) ASSERT_NE(broadcasted_const, nullptr); EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32); + EXPECT_EQ(broadcasted_const->dim(0).value(), N); EXPECT_EQ(broadcasted_const->dim(1).value(), H); EXPECT_EQ(broadcasted_const->dim(2).value(), W); EXPECT_EQ(broadcasted_const->dim(3).value(), D); - EXPECT_EQ(broadcasted_const->size(), H * W * D); + EXPECT_EQ(broadcasted_const->size(), N * H * W * D); - for (uint32_t i = 0; i < H * W; ++i) + for (uint32_t i = 0; i < N * H * W; ++i) { for (uint32_t d = 0; d < D; ++d) { - EXPECT_NEAR(broadcasted_const->at(i + H * W * d), - static_cast(i), std::numeric_limits::min()); + EXPECT_NEAR(broadcasted_const->at(i * D + d), static_cast(i), + std::numeric_limits::min()); } } } -TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors) +TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast_multiple_successors) { auto const circle_sqrt = _g.nodes()->create(); circle_sqrt->dtype(loco::DataType::FLOAT32); - circle_sqrt->shape({1, H, W, 1}); + circle_sqrt->shape({N, H, W, 1}); circle_sqrt->x(_y); luci::ExpandBroadcastConstPass pass; @@ -131,19 +235,19 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors) ASSERT_NE(broadcasted_const, nullptr); EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32); EXPECT_EQ(broadcasted_const->dim(3).value(), D); - EXPECT_EQ(broadcasted_const->size(), H * W * D); + EXPECT_EQ(broadcasted_const->size(), N * H * W * D); // Check if another successor's node was left intact ASSERT_NE(original_const, nullptr); EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32); EXPECT_EQ(original_const->dim(3).value(), 1); - EXPECT_EQ(original_const->size(), H * W * 1); + EXPECT_EQ(original_const->size(), N * H * W * 1); } -TEST_F(ExpandBroadcastConstTest, broadcast_impossible_NEG) +TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_impossible_NEG) { - _y->shape({1, H, W, 2}); - _y->size(H * W * (D - 1)); + _y->shape({N, H, W, D + 1}); + _y->size(N * H * W * (D + 1)); luci::ExpandBroadcastConstPass pass; ASSERT_FALSE(pass.run(&_g));