Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Fix ExpandBroadcastConstPass #14560

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
constant->rank(node->rank());
constant->shape_status(luci::ShapeStatus::VALID);

uint32_t node_size = node->size<loco::DataType::FLOAT32>();
uint32_t constant_size = 1;
for (uint32_t i = 0; i < successor->rank(); ++i)
{
Expand All @@ -65,10 +64,35 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
auto const node_data = &node->at<loco::DataType::FLOAT32>(0);
auto const constant_data = &constant->at<loco::DataType::FLOAT32>(0);

auto const successor_depth = successor->dim(successor->rank() - 1).value();
for (uint32_t d = 0; d < successor_depth; ++d)
std::copy(node_data, node_data + node_size, constant_data + d * node_size);

assert(successor->rank() >= 2 && successor->rank() <= 4);
if (successor->rank() == 2)
{
auto const N = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t n = 0; n < N; ++n)
std::fill_n(constant_data + n * D, D, node_data[n]);
}
else if (successor->rank() == 3)
{
auto const H = successor->dim(successor->rank() - 3).value();
auto const W = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
std::fill_n(constant_data + h * W * D + w * D, D, node_data[h * W + w]);
}
else if (successor->rank() == 4)
{
auto const N = successor->dim(successor->rank() - 4).value();
auto const H = successor->dim(successor->rank() - 3).value();
auto const W = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t n = 0; n < N; ++n)
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
std::fill_n(constant_data + n * H * W * D + h * W * D + w * D, D,
node_data[n * H * W + h * W + w]);
}
return constant;
}

Expand Down
152 changes: 128 additions & 24 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@

#include <gtest/gtest.h>

/****************************************************************************
* TESTS FOR RANK 2
****************************************************************************/
namespace
{

class ExpandBroadcastConstTest : public ::testing::Test
class ExpandBroadcastConstRank2Graph
{
public:
ExpandBroadcastConstTest()
ExpandBroadcastConstRank2Graph()
{
_x = _g.nodes()->create<luci::CircleInput>();
_y = _g.nodes()->create<luci::CircleConst>();
Expand All @@ -38,27 +41,114 @@ class ExpandBroadcastConstTest : public ::testing::Test

auto graph_input = _g.inputs()->create();
graph_input->dtype(loco::DataType::FLOAT32);
graph_input->shape({1, H, W, D});
graph_input->shape({N, D});
_x->index(graph_input->index());
_x->dtype(graph_input->dtype());
_x->shape({1, H, W, D});
_x->shape({N, D});

_y->dtype(loco::DataType::FLOAT32);
_y->shape({N, 1});
_y->size<loco::DataType::FLOAT32>(N);

auto graph_output = _g.outputs()->create();
graph_output->dtype(loco::DataType::FLOAT32);
graph_output->shape({1, H, W, D});
graph_output->shape({N, D});
_output->index(graph_output->index());
_output->dtype(graph_output->dtype());
_output->shape({1, H, W, D});
_output->shape({N, D});

_y->dtype(loco::DataType::FLOAT32);
_y->shape({1, H, W, 1});
_y->size<loco::DataType::FLOAT32>(16);
_add->dtype(loco::DataType::FLOAT32);
_add->fusedActivationFunction(luci::FusedActFunc::NONE);
_add->x(_x);
_add->y(_y);
_add->shape({N, D});

_output->from(_add);

_x->name("input");
_output->name("output");
}

protected:
uint32_t const N = 4;
uint32_t const D = 3;

protected:
loco::Graph _g;
luci::CircleAdd *_add = nullptr;
luci::CircleInput *_x = nullptr;
luci::CircleConst *_y = nullptr;
luci::CircleOutput *_output = nullptr;
};

// gtest fixture for the rank-2 cases: exposes the members of
// ExpandBroadcastConstRank2Graph to TEST_F bodies by also deriving from
// ::testing::Test. It adds no state or setup of its own.
class ExpandBroadcastRank2ConstTest : public ExpandBroadcastConstRank2Graph, public ::testing::Test
{
public:
  ExpandBroadcastRank2ConstTest() = default;
};
} // namespace

// Runs ExpandBroadcastConstPass on the rank-2 graph and checks that the Add's
// constant operand becomes an {N, D} constant in which every element of row r
// equals the r-th value of the original constant.
TEST_F(ExpandBroadcastRank2ConstTest, remove_broadcast)
{
  // Give each of the N source elements a distinct value (its own index).
  for (uint32_t row = 0; row < N; ++row)
    _y->at<loco::DataType::FLOAT32>(row) = static_cast<float>(row);

  luci::ExpandBroadcastConstPass pass;
  ASSERT_TRUE(pass.run(&_g));

  // After the pass, the Add's second operand must still be a constant node.
  auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
  ASSERT_NE(broadcasted_const, nullptr);

  EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
  EXPECT_EQ(broadcasted_const->dim(0).value(), N);
  EXPECT_EQ(broadcasted_const->dim(1).value(), D);
  EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * D);

  // Row-major layout: element (row, col) lives at flat index row * D + col.
  for (uint32_t row = 0; row < N; ++row)
  {
    auto const expected = static_cast<float>(row);
    for (uint32_t col = 0; col < D; ++col)
    {
      auto const flat_index = row * D + col;
      EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(flat_index), expected,
                  std::numeric_limits<float>::min());
    }
  }
}

/****************************************************************************
* TESTS FOR RANK 4
****************************************************************************/

namespace
{
class ExpandBroadcastConstRank4Graph
{
public:
ExpandBroadcastConstRank4Graph()
{
_x = _g.nodes()->create<luci::CircleInput>();
_y = _g.nodes()->create<luci::CircleConst>();
_add = _g.nodes()->create<luci::CircleAdd>();
_output = _g.nodes()->create<luci::CircleOutput>();

auto graph_input = _g.inputs()->create();
graph_input->dtype(loco::DataType::FLOAT32);
graph_input->shape({N, H, W, D});
_x->index(graph_input->index());
_x->dtype(graph_input->dtype());
_x->shape({N, H, W, D});

auto graph_output = _g.outputs()->create();
graph_output->dtype(loco::DataType::FLOAT32);
graph_output->shape({N, H, W, D});
_output->index(graph_output->index());
_output->dtype(graph_output->dtype());
_output->shape({N, H, W, D});

_add->dtype(loco::DataType::FLOAT32);
_add->fusedActivationFunction(luci::FusedActFunc::NONE);
_add->x(_x);
_add->y(_y);
_add->shape({1, H, W, D});
_add->shape({N, H, W, D});

_output->from(_add);

Expand All @@ -67,6 +157,7 @@ class ExpandBroadcastConstTest : public ::testing::Test
}

protected:
uint32_t const N = 2;
uint32_t const H = 4;
uint32_t const W = 4;
uint32_t const D = 3;
Expand All @@ -79,18 +170,30 @@ class ExpandBroadcastConstTest : public ::testing::Test
luci::CircleOutput *_output = nullptr;
};

// gtest fixture for the rank-4 cases: on top of ExpandBroadcastConstRank4Graph
// it configures the constant operand _y as a FLOAT32 tensor of shape
// {N, H, W, 1}, i.e. with a size-1 last dimension.
class ExpandBroadcastRank4ConstTest1 : public ExpandBroadcastConstRank4Graph, public ::testing::Test
{
public:
  ExpandBroadcastRank4ConstTest1()
  {
    // Number of elements held by a {N, H, W, 1} constant.
    auto const element_count = N * H * W * 1;

    _y->dtype(loco::DataType::FLOAT32);
    _y->shape({N, H, W, 1});
    _y->size<loco::DataType::FLOAT32>(element_count);
  }
};

// TODO: Add more tests for Rank4 with different broadcasting dimensions
} // namespace

TEST_F(ExpandBroadcastConstTest, name)
TEST_F(ExpandBroadcastRank4ConstTest1, name)
{
luci::ExpandBroadcastConstPass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

TEST_F(ExpandBroadcastConstTest, remove_broadcast)
TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast)
{
for (uint32_t i = 0; i < H * W; ++i)
for (uint32_t i = 0; i < N * H * W; ++i)
_y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);

luci::ExpandBroadcastConstPass pass;
Expand All @@ -100,26 +203,27 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast)
ASSERT_NE(broadcasted_const, nullptr);

EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(0).value(), N);
EXPECT_EQ(broadcasted_const->dim(1).value(), H);
EXPECT_EQ(broadcasted_const->dim(2).value(), W);
EXPECT_EQ(broadcasted_const->dim(3).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

for (uint32_t i = 0; i < H * W; ++i)
for (uint32_t i = 0; i < N * H * W; ++i)
{
for (uint32_t d = 0; d < D; ++d)
{
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i + H * W * d),
static_cast<float>(i), std::numeric_limits<float>::min());
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i * D + d), static_cast<float>(i),
std::numeric_limits<float>::min());
}
}
}

TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast_multiple_successors)
{
auto const circle_sqrt = _g.nodes()->create<luci::CircleSqrt>();
circle_sqrt->dtype(loco::DataType::FLOAT32);
circle_sqrt->shape({1, H, W, 1});
circle_sqrt->shape({N, H, W, 1});
circle_sqrt->x(_y);

luci::ExpandBroadcastConstPass pass;
Expand All @@ -131,19 +235,19 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
ASSERT_NE(broadcasted_const, nullptr);
EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(3).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

// Check if another successor's node was left intact
ASSERT_NE(original_const, nullptr);
EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(original_const->dim(3).value(), 1);
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), H * W * 1);
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), N * H * W * 1);
}

TEST_F(ExpandBroadcastConstTest, broadcast_impossible_NEG)
TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_impossible_NEG)
{
_y->shape({1, H, W, 2});
_y->size<loco::DataType::FLOAT32>(H * W * (D - 1));
_y->shape({N, H, W, D + 1});
_y->size<loco::DataType::FLOAT32>(N * H * W * (D + 1));

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
Expand Down