Skip to content

Commit

Permalink
[luci/pass] Fix ExpandBroadcastConstPass (Samsung#14560)
Browse files Browse the repository at this point in the history
* [luci/pass] Fix ExpandBroadcastConstPass

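Fix a bug in ExpandBroadcastConstPass: when a constant was broadcast along its last dimension, the expanded data was written in the wrong element order.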

ONE-DCO-1.0-Signed-off-by: Dayoung Lee <[email protected]>

* Return nullptr for unsupported case
  • Loading branch information
dayo09 authored and chenyx113 committed Jan 18, 2025
1 parent 3db3feb commit 15d27ea
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 29 deletions.
39 changes: 34 additions & 5 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,18 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
return nullptr;
}

if (successor->rank() == 1 || successor->rank() > 4)
{
WARN(l) << "NYI: Only 2D/3D/4D tensor broadcast removal is supported";
return nullptr;
}

auto constant = node->graph()->nodes()->create<luci::CircleConst>();
constant->name(node->name());
constant->dtype(node->dtype());
constant->rank(node->rank());
constant->shape_status(luci::ShapeStatus::VALID);

uint32_t node_size = node->size<loco::DataType::FLOAT32>();
uint32_t constant_size = 1;
for (uint32_t i = 0; i < successor->rank(); ++i)
{
Expand All @@ -65,10 +70,34 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
auto const node_data = &node->at<loco::DataType::FLOAT32>(0);
auto const constant_data = &constant->at<loco::DataType::FLOAT32>(0);

auto const successor_depth = successor->dim(successor->rank() - 1).value();
for (uint32_t d = 0; d < successor_depth; ++d)
std::copy(node_data, node_data + node_size, constant_data + d * node_size);

if (successor->rank() == 2)
{
auto const N = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t n = 0; n < N; ++n)
std::fill_n(constant_data + n * D, D, node_data[n]);
}
else if (successor->rank() == 3)
{
auto const H = successor->dim(successor->rank() - 3).value();
auto const W = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
std::fill_n(constant_data + h * W * D + w * D, D, node_data[h * W + w]);
}
else if (successor->rank() == 4)
{
auto const N = successor->dim(successor->rank() - 4).value();
auto const H = successor->dim(successor->rank() - 3).value();
auto const W = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t n = 0; n < N; ++n)
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
std::fill_n(constant_data + n * H * W * D + h * W * D + w * D, D,
node_data[n * H * W + h * W + w]);
}
return constant;
}

Expand Down
152 changes: 128 additions & 24 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@

#include <gtest/gtest.h>

/****************************************************************************
* TESTS FOR RANK 2
****************************************************************************/
namespace
{

class ExpandBroadcastConstTest : public ::testing::Test
class ExpandBroadcastConstRank2Graph
{
public:
ExpandBroadcastConstTest()
ExpandBroadcastConstRank2Graph()
{
_x = _g.nodes()->create<luci::CircleInput>();
_y = _g.nodes()->create<luci::CircleConst>();
Expand All @@ -38,27 +41,114 @@ class ExpandBroadcastConstTest : public ::testing::Test

auto graph_input = _g.inputs()->create();
graph_input->dtype(loco::DataType::FLOAT32);
graph_input->shape({1, H, W, D});
graph_input->shape({N, D});
_x->index(graph_input->index());
_x->dtype(graph_input->dtype());
_x->shape({1, H, W, D});
_x->shape({N, D});

_y->dtype(loco::DataType::FLOAT32);
_y->shape({N, 1});
_y->size<loco::DataType::FLOAT32>(N);

auto graph_output = _g.outputs()->create();
graph_output->dtype(loco::DataType::FLOAT32);
graph_output->shape({1, H, W, D});
graph_output->shape({N, D});
_output->index(graph_output->index());
_output->dtype(graph_output->dtype());
_output->shape({1, H, W, D});
_output->shape({N, D});

_y->dtype(loco::DataType::FLOAT32);
_y->shape({1, H, W, 1});
_y->size<loco::DataType::FLOAT32>(16);
_add->dtype(loco::DataType::FLOAT32);
_add->fusedActivationFunction(luci::FusedActFunc::NONE);
_add->x(_x);
_add->y(_y);
_add->shape({N, D});

_output->from(_add);

_x->name("input");
_output->name("output");
}

protected:
uint32_t const N = 4;
uint32_t const D = 3;

protected:
loco::Graph _g;
luci::CircleAdd *_add = nullptr;
luci::CircleInput *_x = nullptr;
luci::CircleConst *_y = nullptr;
luci::CircleOutput *_output = nullptr;
};

// gtest fixture for the rank-2 cases: it reuses the Add graph assembled by
// ExpandBroadcastConstRank2Graph and adds no state of its own.
class ExpandBroadcastRank2ConstTest : public ExpandBroadcastConstRank2Graph, public ::testing::Test
{
public:
  ExpandBroadcastRank2ConstTest() = default;
};
} // namespace

// The fixture feeds a {N, 1} constant and a {N, D} input into an Add. After the
// pass runs, the Add's second operand must be a {N, D} FLOAT32 constant in which
// each source element appears D times in a row.
TEST_F(ExpandBroadcastRank2ConstTest, remove_broadcast)
{
  // Seed the constant with 0, 1, ..., N - 1 so every row is distinguishable.
  for (uint32_t row = 0; row < N; ++row)
  {
    _y->at<loco::DataType::FLOAT32>(row) = static_cast<float>(row);
  }

  luci::ExpandBroadcastConstPass pass;
  ASSERT_TRUE(pass.run(&_g));

  // The Add must still be fed by a constant node after the pass.
  auto *expanded = dynamic_cast<luci::CircleConst *>(_add->y());
  ASSERT_NE(expanded, nullptr);

  EXPECT_EQ(expanded->dtype(), loco::DataType::FLOAT32);
  EXPECT_EQ(expanded->dim(0).value(), N);
  EXPECT_EQ(expanded->dim(1).value(), D);
  EXPECT_EQ(expanded->size<loco::DataType::FLOAT32>(), N * D);

  // Element (row, col) lives at flat index row * D + col and must equal the
  // value originally stored for that row.
  auto const tolerance = std::numeric_limits<float>::min();
  for (uint32_t row = 0; row < N; ++row)
  {
    auto const expected = static_cast<float>(row);
    for (uint32_t col = 0; col < D; ++col)
    {
      EXPECT_NEAR(expanded->at<loco::DataType::FLOAT32>(row * D + col), expected, tolerance);
    }
  }
}

/****************************************************************************
* TESTS FOR RANK 4
****************************************************************************/

namespace
{
class ExpandBroadcastConstRank4Graph
{
public:
ExpandBroadcastConstRank4Graph()
{
_x = _g.nodes()->create<luci::CircleInput>();
_y = _g.nodes()->create<luci::CircleConst>();
_add = _g.nodes()->create<luci::CircleAdd>();
_output = _g.nodes()->create<luci::CircleOutput>();

auto graph_input = _g.inputs()->create();
graph_input->dtype(loco::DataType::FLOAT32);
graph_input->shape({N, H, W, D});
_x->index(graph_input->index());
_x->dtype(graph_input->dtype());
_x->shape({N, H, W, D});

auto graph_output = _g.outputs()->create();
graph_output->dtype(loco::DataType::FLOAT32);
graph_output->shape({N, H, W, D});
_output->index(graph_output->index());
_output->dtype(graph_output->dtype());
_output->shape({N, H, W, D});

_add->dtype(loco::DataType::FLOAT32);
_add->fusedActivationFunction(luci::FusedActFunc::NONE);
_add->x(_x);
_add->y(_y);
_add->shape({1, H, W, D});
_add->shape({N, H, W, D});

_output->from(_add);

Expand All @@ -67,6 +157,7 @@ class ExpandBroadcastConstTest : public ::testing::Test
}

protected:
uint32_t const N = 2;
uint32_t const H = 4;
uint32_t const W = 4;
uint32_t const D = 3;
Expand All @@ -79,18 +170,30 @@ class ExpandBroadcastConstTest : public ::testing::Test
luci::CircleOutput *_output = nullptr;
};

// Rank-4 fixture whose constant operand has shape {N, H, W, 1}: only its last
// dimension differs from the {N, H, W, D} shape used for the Add.
class ExpandBroadcastRank4ConstTest1 : public ExpandBroadcastConstRank4Graph, public ::testing::Test
{
public:
  ExpandBroadcastRank4ConstTest1()
  {
    // One stored element per (n, h, w) position; the trailing dimension is 1.
    uint32_t const num_elements = N * H * W * 1;

    _y->dtype(loco::DataType::FLOAT32);
    _y->shape({N, H, W, 1});
    _y->size<loco::DataType::FLOAT32>(num_elements);
  }
};

// TODO: Add more tests for Rank4 with different broadcasting dimensions
} // namespace

TEST_F(ExpandBroadcastConstTest, name)
TEST_F(ExpandBroadcastRank4ConstTest1, name)
{
luci::ExpandBroadcastConstPass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

TEST_F(ExpandBroadcastConstTest, remove_broadcast)
TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast)
{
for (uint32_t i = 0; i < H * W; ++i)
for (uint32_t i = 0; i < N * H * W; ++i)
_y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);

luci::ExpandBroadcastConstPass pass;
Expand All @@ -100,26 +203,27 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast)
ASSERT_NE(broadcasted_const, nullptr);

EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(0).value(), N);
EXPECT_EQ(broadcasted_const->dim(1).value(), H);
EXPECT_EQ(broadcasted_const->dim(2).value(), W);
EXPECT_EQ(broadcasted_const->dim(3).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

for (uint32_t i = 0; i < H * W; ++i)
for (uint32_t i = 0; i < N * H * W; ++i)
{
for (uint32_t d = 0; d < D; ++d)
{
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i + H * W * d),
static_cast<float>(i), std::numeric_limits<float>::min());
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(i * D + d), static_cast<float>(i),
std::numeric_limits<float>::min());
}
}
}

TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast_multiple_successors)
{
auto const circle_sqrt = _g.nodes()->create<luci::CircleSqrt>();
circle_sqrt->dtype(loco::DataType::FLOAT32);
circle_sqrt->shape({1, H, W, 1});
circle_sqrt->shape({N, H, W, 1});
circle_sqrt->x(_y);

luci::ExpandBroadcastConstPass pass;
Expand All @@ -131,19 +235,19 @@ TEST_F(ExpandBroadcastConstTest, remove_broadcast_multiple_successors)
ASSERT_NE(broadcasted_const, nullptr);
EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(3).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

// Check if another successor's node was left intact
ASSERT_NE(original_const, nullptr);
EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(original_const->dim(3).value(), 1);
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), H * W * 1);
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), N * H * W * 1);
}

TEST_F(ExpandBroadcastConstTest, broadcast_impossible_NEG)
TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_impossible_NEG)
{
_y->shape({1, H, W, 2});
_y->size<loco::DataType::FLOAT32>(H * W * (D - 1));
_y->shape({N, H, W, D + 1});
_y->size<loco::DataType::FLOAT32>(N * H * W * (D + 1));

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
Expand Down

0 comments on commit 15d27ea

Please sign in to comment.