Skip to content

Commit

Permalink
Revisit tet
Browse files Browse the repository at this point in the history
  • Loading branch information
dayo09 committed Jan 16, 2025
1 parent ae45a32 commit 7ccdc5e
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class ExpandBroadcastConstRank4Graph
}

protected:
uint32_t const N = 1;
uint32_t const N = 2;
uint32_t const H = 4;
uint32_t const W = 4;
uint32_t const D = 3;
Expand All @@ -177,7 +177,7 @@ class ExpandBroadcastRank4ConstTest1 : public ExpandBroadcastConstRank4Graph, pu
{
_y->dtype(loco::DataType::FLOAT32);
_y->shape({N, H, W, 1});
_y->size<loco::DataType::FLOAT32>(16);
_y->size<loco::DataType::FLOAT32>(N * H * W * 1);
}
};

Expand All @@ -193,7 +193,7 @@ TEST_F(ExpandBroadcastRank4ConstTest1, name)

TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast)
{
for (uint32_t i = 0; i < H * W; ++i)
for (uint32_t i = 0; i < N * H * W; ++i)
_y->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i);

luci::ExpandBroadcastConstPass pass;
Expand All @@ -203,12 +203,13 @@ TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast)
ASSERT_NE(broadcasted_const, nullptr);

EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(0).value(), N);
EXPECT_EQ(broadcasted_const->dim(1).value(), H);
EXPECT_EQ(broadcasted_const->dim(2).value(), W);
EXPECT_EQ(broadcasted_const->dim(3).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

for (uint32_t i = 0; i < H * W; ++i)
for (uint32_t i = 0; i < N * H * W; ++i)
{
for (uint32_t d = 0; d < D; ++d)
{
Expand All @@ -222,7 +223,7 @@ TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast_multiple_successors)
{
auto const circle_sqrt = _g.nodes()->create<luci::CircleSqrt>();
circle_sqrt->dtype(loco::DataType::FLOAT32);
circle_sqrt->shape({1, H, W, 1});
circle_sqrt->shape({N, H, W, 1});
circle_sqrt->x(_y);

luci::ExpandBroadcastConstPass pass;
Expand All @@ -234,19 +235,19 @@ TEST_F(ExpandBroadcastRank4ConstTest1, remove_broadcast_multiple_successors)
ASSERT_NE(broadcasted_const, nullptr);
EXPECT_EQ(broadcasted_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(broadcasted_const->dim(3).value(), D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), H * W * D);
EXPECT_EQ(broadcasted_const->size<loco::DataType::FLOAT32>(), N * H * W * D);

// Check if another successor's node was left intact
ASSERT_NE(original_const, nullptr);
EXPECT_EQ(original_const->dtype(), loco::DataType::FLOAT32);
EXPECT_EQ(original_const->dim(3).value(), 1);
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), H * W * 1);
EXPECT_EQ(original_const->size<loco::DataType::FLOAT32>(), N * H * W * 1);
}

TEST_F(ExpandBroadcastRank4ConstTest1, broadcast_impossible_NEG)
{
_y->shape({1, H, W, 2});
_y->size<loco::DataType::FLOAT32>(H * W * (D - 1));
_y->shape({N, H, W, D + 1});
_y->size<loco::DataType::FLOAT32>(N * H * W * (D + 1));

luci::ExpandBroadcastConstPass pass;
ASSERT_FALSE(pass.run(&_g));
Expand Down

0 comments on commit 7ccdc5e

Please sign in to comment.