diff --git a/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.cpp b/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.cpp index fc88b189481..6bf2a7c6f60 100644 --- a/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.cpp @@ -67,6 +67,25 @@ int32_t unknown_dim_count(luci::CircleNode *node) */ int32_t value_from_circle_const(const luci::CircleConst *node, uint32_t idx) { + // Scalar case: rank 0, only one element in CircleConst + if (node->rank() == 0) + { + if (node->dtype() == loco::DataType::S64) + { + assert(node->size() == 1); // FIX_ME_UNLESS + return static_cast(node->at(0)); + } + else if (node->dtype() == loco::DataType::S32) + { + assert(node->size() == 1); // FIX_ME_UNLESS + return node->at(0); + } + else + { + throw std::runtime_error("Unsupported dtype"); + } + } + assert(node->rank() == 1 && node->dim(0).value() > idx); assert(node->dtype() == loco::DataType::S64 || node->dtype() == loco::DataType::S32); diff --git a/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.test.cpp b/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.test.cpp index 046fa82caf8..41f1bb6b498 100644 --- a/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.test.cpp +++ b/compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.test.cpp @@ -132,6 +132,12 @@ class ExpandDimsWithConstAxisGraph : public PassTestGraph _output->from(_expand_dims); } + void scalar_axis(int32_t axis) + { + _axis = luci::create_const_node(g(), loco::DataType::S32, {}, axis); + _axis->name("axis"); + } + protected: luci::CircleConst *_axis = nullptr; @@ -247,6 +253,26 @@ TEST_F(SubstituteExpandDimsToReshapeTest, simple_with_expand_dims_2) ASSERT_EQ(1, reshape_shape->at(3)); } +TEST_F(SubstituteExpandDimsToReshapeTest, scalar_axis) +{ + _graph.init({16, 3, 1}, {16, 3, 1, 1}, 2); + + _graph.scalar_axis(2); + + run_pass(); + + auto reshape = dynamic_cast(_graph.output()->from()); + auto expand_dims = dynamic_cast(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, expand_dims); + auto reshape_shape = loco::must_cast(reshape->shape()); + ASSERT_EQ(4, reshape_shape->size()); + ASSERT_EQ(16, reshape_shape->at(0)); + ASSERT_EQ(3, reshape_shape->at(1)); + ASSERT_EQ(1, reshape_shape->at(2)); + ASSERT_EQ(1, reshape_shape->at(3)); +} + TEST_F(SubstituteExpandDimsToReshapeTest, invalid_axis_NEG) { _graph.init({1, 2, 3}, {1, 2, 3}, 5);