Skip to content

Commit

Permalink
[luci/pass] Fix ExpandBroadcastConstPass
Browse files Browse the repository at this point in the history
Let's fix a bug in ExpandBroadcastConstPass.
  • Loading branch information
dayo09 committed Jan 14, 2025
1 parent b7d4e5e commit 8585293
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 20 deletions.
20 changes: 16 additions & 4 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,22 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
auto const node_data = &node->at<loco::DataType::FLOAT32>(0);
auto const constant_data = &constant->at<loco::DataType::FLOAT32>(0);

auto const successor_depth = successor->dim(successor->rank() - 1).value();
for (uint32_t d = 0; d < successor_depth; ++d)
std::copy(node_data, node_data + node_size, constant_data + d * node_size);

if (successor->rank() == 2)
{
auto const N = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t n = 0; n < N; ++n)
std::fill_n(constant_data + n * D, D, node_data[n]);
}
if (successor->rank() >= 3)
{
auto const H = successor->dim(successor->rank() - 3).value();
auto const W = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
std::fill_n(constant_data + h * W * D + w * D, D, node_data[h * W + w]);
}
return constant;
}

Expand Down
135 changes: 119 additions & 16 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@

#include <gtest/gtest.h>

/****************************************************************************
* TESTS FOR RANK 2
****************************************************************************/
namespace
{

class ExpandBroadcastConstTest : public ::testing::Test
class ExpandBroadcastConstRank2Graph
{
public:
ExpandBroadcastConstTest()
ExpandBroadcastConstRank2Graph()
{
_x = _g.nodes()->create<luci::CircleInput>();
_y = _g.nodes()->create<luci::CircleConst>();
Expand All @@ -38,27 +41,114 @@ class ExpandBroadcastConstTest : public ::testing::Test

auto graph_input = _g.inputs()->create();
graph_input->dtype(loco::DataType::FLOAT32);
graph_input->shape({1, H, W, D});
graph_input->shape({N, D});
_x->index(graph_input->index());
_x->dtype(graph_input->dtype());
_x->shape({1, H, W, D});
_x->shape({N, D});

_y->dtype(loco::DataType::FLOAT32);
_y->shape({N, 1});
_y->size<loco::DataType::FLOAT32>(N);

auto graph_output = _g.outputs()->create();
graph_output->dtype(loco::DataType::FLOAT32);
graph_output->shape({1, H, W, D});
graph_output->shape({N, D});
_output->index(graph_output->index());
_output->dtype(graph_output->dtype());
_output->shape({1, H, W, D});
_output->shape({N, D});

_y->dtype(loco::DataType::FLOAT32);
_y->shape({1, H, W, 1});
_y->size<loco::DataType::FLOAT32>(16);
_add->dtype(loco::DataType::FLOAT32);
_add->fusedActivationFunction(luci::FusedActFunc::NONE);
_add->x(_x);
_add->y(_y);
_add->shape({N, D});

_output->from(_add);

_x->name("input");
_output->name("output");
}

protected:
uint32_t const N = 4;
uint32_t const D = 3;

protected:
loco::Graph _g;
luci::CircleAdd *_add = nullptr;
luci::CircleInput *_x = nullptr;
luci::CircleConst *_y = nullptr;
luci::CircleOutput *_output = nullptr;
};

class ExpandBroadcastRank2ConstTest : public ExpandBroadcastConstRank2Graph, public ::testing::Test
{
public:
ExpandBroadcastRank2ConstTest() {}
};
} // namespace

TEST_F(ExpandBroadcastRank2ConstTest, remove_broadcast)
{
for (uint32_t i = 0; i < N; ++i)
_y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);

luci::ExpandBroadcastConstPass pass;
ASSERT_TRUE(pass.run(&_g));

auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
ASSERT_NE(broadcasted_const, nullptr);

EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(0).value(), N);
EXPECT_EQ(broadcasted_const->dim(1).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * D);

for (uint32_t i = 0; i < N; ++i)
{
for (uint32_t d = 0; d < D; ++d)
{
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i * D + d), static_cast<float>(i),
std::numeric_limits<float>::min());
}
}
}

/****************************************************************************
* TESTS FOR RANK 4
****************************************************************************/

namespace
{
class ExpandBroadcastConstRank4Graph
{
public:
ExpandBroadcastConstRank4Graph()
{
_x = _g.nodes()->create<luci::CircleInput>();
_y = _g.nodes()->create<luci::CircleConst>();
_add = _g.nodes()->create<luci::CircleAdd>();
_output = _g.nodes()->create<luci::CircleOutput>();

auto graph_input = _g.inputs()->create();
graph_input->dtype(loco::DataType::FLOAT32);
graph_input->shape({N, H, W, D});
_x->index(graph_input->index());
_x->dtype(graph_input->dtype());
_x->shape({N, H, W, D});

auto graph_output = _g.outputs()->create();
graph_output->dtype(loco::DataType::FLOAT32);
graph_output->shape({N, H, W, D});
_output->index(graph_output->index());
_output->dtype(graph_output->dtype());
_output->shape({N, H, W, D});

_add->dtype(loco::DataType::FLOAT32);
_add->fusedActivationFunction(luci::FusedActFunc::NONE);
_add->x(_x);
_add->y(_y);
_add->shape({1, H, W, D});
_add->shape({N, H, W, D});

_output->from(_add);

Expand All @@ -67,6 +157,7 @@ class ExpandBroadcastConstTest : public ::testing::Test
}

protected:
uint32_t const N = 1;
uint32_t const H = 4;
uint32_t const W = 4;
uint32_t const D = 3;
Expand All @@ -79,16 +170,28 @@ class ExpandBroadcastConstTest : public ::testing::Test
luci::CircleOutput *_output = nullptr;
};

class ExpandBroadcastRank4ConstTest1 : public ExpandBroadcastConstRank4Graph, public ::testing::Test
{
public:
ExpandBroadcastRank4ConstTest1()
{
_y->dtype(loco::DataType::FLOAT32);
_y->shape({N, H, W, 1});
_y->size<loco::DataType::FLOAT32>(16);
}
};

// TODO: Add more tests for Rank4 with different broadcasting dimensions
} // namespace

TEST_F(ExpandBroadcastConstTest, name)
TEST_F(ExpandBroadcastRank4ConstTest1, name)
{
luci::ExpandBroadcastConstPass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

TEST_F(ExpandBroadcastConstTest, remove_broadcast)
TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast)
{
for (uint32_t i = 0; i < H * W; ++i)
_y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);
Expand All @@ -109,13 +212,13 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast)
{
for (uint32_t d = 0; d < D; ++d)
{
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i + H * W * d),
static_cast<float>(i), std::numeric_limits<float>::min());
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i * D + d), static_cast<float>(i),
std::numeric_limits<float>::min());
}
}
}

TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast_multiple_successors)
{
auto const circle_sqrt = _g.nodes()->create<luci::CircleSqrt>();
circle_sqrt->dtype(loco::DataType::FLOAT32);
Expand All @@ -140,7 +243,7 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), H * W * 1);
}

TEST_F(ExpandBroadcastConstTest, broadcast_impossible_NEG)
TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_impossible_NEG)
{
_y->shape({1, H, W, 2});
_y->size<loco::DataType::FLOAT32>(H * W * (D - 1));
Expand Down

0 comments on commit 8585293

Please sign in to comment.