Skip to content

Commit

Permalink
[luci/pass] Add unittest to ExpandBroadcastConstPass (#14581)
Browse files Browse the repository at this point in the history
Let's add more unittest to ExpandBroadcastConstPass with axis=1 (non-lastdim).

ONE-DCO-Signed-off-by: Dayoung Lee <[email protected]>
  • Loading branch information
dayo09 authored Jan 23, 2025
1 parent bce28b1 commit 2b2395a
Showing 1 changed file with 110 additions and 1 deletion.
111 changes: 110 additions & 1 deletion compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "luci/Pass/ExpandBroadcastConstPass.h"
#include "PassTestGraphs.h"
#include "helpers/ArrayIndex.h"

#include <luci/IR/CircleNodes.h>

Expand Down Expand Up @@ -114,6 +115,24 @@ TEST_F(ExpandBroadcastRank2ConstTest, remove_broadcast)
}
}

TEST_F(ExpandBroadcastRank2ConstTest, broadcast_impossible_NEG)
{
_y->shape({N, D + 1});
_y->size<loco::DataType::FLOAT32>(N * (D + 1));

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
}

TEST_F(ExpandBroadcastRank2ConstTest, broadcast_diff_rank_NEG)
{
_y->shape({N});
_y->size<loco::DataType::FLOAT32>(N);

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
}

/****************************************************************************
* TESTS FOR RANK 4
****************************************************************************/
Expand Down Expand Up @@ -181,7 +200,17 @@ class ExpandBroadcastRank4ConstTest1 : public ExpandBroadcastConstRank4Graph, pu
}
};

// TODO: Add more tests for Rank4 with different broadcasting dimensions
class ExpandBroadcastRank4ConstTest2 : public ExpandBroadcastConstRank4Graph, public ::testing::Test
{
public:
ExpandBroadcastRank4ConstTest2()
{
_y->dtype(loco::DataType::FLOAT32);
_y->shape({N, 1, W, D});
_y->size<loco::DataType::FLOAT32>(N * 1 * W * D);
}
};

} // namespace

TEST_F(ExpandBroadcastRank4ConstTest1, name)
Expand Down Expand Up @@ -252,3 +281,83 @@ TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_impossible_NEG)
luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
}

TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_diff_rank_NEG)
{
_y->shape({N, H, W});
_y->size<loco::DataType::FLOAT32>(N * H * W);

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
}

TEST_F(ExpandBroadcastRank4ConstTest2, remove_broadcast)
{
for (uint32_t i = 0; i < N * W * D; ++i)
_y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);

luci::ExpandBroadcastConstPass pass;
ASSERT_TRUE(pass.run(&_g));

auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
ASSERT_NE(broadcasted_const, nullptr);

EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(0).value(), N);
EXPECT_EQ(broadcasted_const->dim(1).value(), H);
EXPECT_EQ(broadcasted_const->dim(2).value(), W);
EXPECT_EQ(broadcasted_const->dim(3).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

auto const idx = luci::Array4DIndex(N, H, W, D);

for (uint32_t n = 0; n < N; ++n)
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
for (uint32_t d = 0; d < D; ++d)
EXPECT_NEAR(broadcasted_const->at<loco::DataType::FLOAT32>(idx(n, h, w, d)),
static_cast<float>(n * W * D + w * D + d), std::numeric_limits<float>::min());
}

TEST_F(ExpandBroadcastRank4ConstTest2, remove_broadcast_multiple_successors)
{
auto const circle_sqrt = _g.nodes()->create<luci::CircleSqrt>();
circle_sqrt->dtype(loco::DataType::FLOAT32);
circle_sqrt->shape({N, 1, W, D});
circle_sqrt->x(_y);

luci::ExpandBroadcastConstPass pass;
ASSERT_TRUE(pass.run(&_g));

auto broadcasted_const = dynamic_cast<luci::CircleConst *>(_add->y());
auto original_const = dynamic_cast<luci::CircleConst *>(circle_sqrt->x());

ASSERT_NE(broadcasted_const, nullptr);
EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(1).value(), H);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

// Check if another successor's node was left intact
ASSERT_NE(original_const, nullptr);
EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(original_const->dim(1).value(), 1);
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), N * 1 * W * D);
}

TEST_F(ExpandBroadcastRank4ConstTest2, broadcast_impossible_NEG)
{
_y->shape({N, H, W + 1, D + 1});
_y->size<loco::DataType::FLOAT32>(N * H * (W + 1) * (D + 1));

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
}

TEST_F(ExpandBroadcastRank4ConstTest2, broadcast_diff_rank_NEG)
{
_y->shape({N, H, W + 4});
_y->size<loco::DataType::FLOAT32>(N * H * (W + 4));

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
}

0 comments on commit 2b2395a

Please sign in to comment.