Skip to content

Commit

Permalink
[luci] Support scalar axis in SubstituteExpandDimsToReshapePass (#14464)
Browse files Browse the repository at this point in the history
This supports scalar axis in SubstituteExpandDimsToReshapePass.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Dec 17, 2024
1 parent 10dcfe7 commit a2726bf
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
19 changes: 19 additions & 0 deletions compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,25 @@ int32_t unknown_dim_count(luci::CircleNode *node)
*/
int32_t value_from_circle_const(const luci::CircleConst *node, uint32_t idx)
{
// Scalar case: rank 0, only one element in CircleConst
if (node->rank() == 0)
{
if (node->dtype() == loco::DataType::S64)
{
assert(node->size<loco::DataType::S64>() == 1); // FIX_ME_UNLESS
return static_cast<int32_t>(node->at<loco::DataType::S64>(0));
}
else if (node->dtype() == loco::DataType::S32)
{
assert(node->size<loco::DataType::S32>() == 1); // FIX_ME_UNLESS
return node->at<loco::DataType::S32>(0);
}
else
{
throw std::runtime_error("Unsupported dtype");
}
}

assert(node->rank() == 1 && node->dim(0).value() > idx);
assert(node->dtype() == loco::DataType::S64 || node->dtype() == loco::DataType::S32);

Expand Down
26 changes: 26 additions & 0 deletions compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ class ExpandDimsWithConstAxisGraph : public PassTestGraph
_output->from(_expand_dims);
}

void scalar_axis(int32_t axis)
{
_axis = luci::create_const_node<int32_t>(g(), loco::DataType::S32, {}, axis);
_axis->name("axis");
}

protected:
luci::CircleConst *_axis = nullptr;

Expand Down Expand Up @@ -247,6 +253,26 @@ TEST_F(SubstituteExpandDimsToReshapeTest, simple_with_expand_dims_2)
ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(3));
}

TEST_F(SubstituteExpandDimsToReshapeTest, scalar_axis)
{
_graph.init({16, 3, 1}, {16, 3, 1, 1}, 2);

_graph.scalar_axis(2);

run_pass();

auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
auto expand_dims = dynamic_cast<luci::CircleExpandDims *>(_graph.output()->from());
ASSERT_NE(nullptr, reshape);
ASSERT_EQ(nullptr, expand_dims);
auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape());
ASSERT_EQ(4, reshape_shape->size<loco::DataType::S32>());
ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(0));
ASSERT_EQ(3, reshape_shape->at<loco::DataType::S32>(1));
ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(2));
ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(3));
}

TEST_F(SubstituteExpandDimsToReshapeTest, invalid_axis_NEG)
{
_graph.init({1, 2, 3}, {1, 2, 3}, 5);
Expand Down

0 comments on commit a2726bf

Please sign in to comment.