Skip to content

Commit

Permalink
Extend ExpandBroadcastConstPass
Browse files: browse the repository at this point in the history
  • Loading branch information
dayo09 committed Jan 21, 2025
1 parent 2743c76 commit f7f5d8a
Showing 1 changed file with 59 additions and 8 deletions.
67 changes: 59 additions & 8 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
broadcast_dims.push_back(dim);
}

if (broadcast_dims.size() != 1 || broadcast_dims.back() != node->rank() - 1)
if (broadcast_dims.size() != 1)
{
WARN(l) << "NYI: Only depth broadcast removal is supported";
WARN(l) << "NYI: Only single dimension broadcast is supported";
return nullptr;
}

if (!(broadcast_dims.back() == 0 || broadcast_dims.back() == node->rank() - 1))
{
WARN(l) << "NYI: Only batch or depth broadcast removal is supported";
return nullptr;
}

Expand All @@ -71,17 +77,62 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
auto const node_data = &node->at<loco::DataType::FLOAT32>(0);
auto const constant_data = &constant->at<loco::DataType::FLOAT32>(0);

assert(successor->rank() >= 2 && successor->rank() <= 4);

// Virtually extend the constant node to 4D to support all cases
// (Only for index calculation)
// Example. (2, 4) -> (1, 1, 2, 4)
auto const D0 = (successor->rank() < 4) ? 1 : successor->dim(successor->rank() - 4).value();
auto const D1 = (successor->rank() < 3) ? 1 : successor->dim(successor->rank() - 3).value();
auto const D2 = (successor->rank() < 2) ? 1 : successor->dim(successor->rank() - 2).value();
auto const D3 = successor->dim(successor->rank() - 1).value();

auto tic = luci::Array4DIndex(D0, D1, D2, D3);
auto tic_orig = luci::Array4DIndex(D0, D1, D2, 1);
for (uint32_t n = 0; n < D0; ++n)
for (uint32_t h = 0; h < D1; ++h)
for (uint32_t w = 0; w < D2; ++w)
std::fill_n(constant_data + tic(n, h, w, 0), D3, node_data[tic_orig(n, h, w, 0)]);
if (broadcast_dims.back() == 0)
{
auto const D0 = (successor->rank() < 4) ? 1 : successor->dim(successor->rank() - 4).value();
auto const D1 = (successor->rank() < 3) ? 1 : successor->dim(successor->rank() - 3).value();
auto const D2 = (successor->rank() < 2) ? 1 : successor->dim(successor->rank() - 2).value();
auto const D3 = successor->dim(successor->rank() - 1).value();

auto idx = luci::Array4DIndex(D0, D1, D2, D3);

auto const D0_orig = (successor->rank() == 4) ? 1 : D0;
auto const D1_orig = (successor->rank() == 3) ? 1 : D1;
auto const D2_orig = (successor->rank() == 2) ? 1 : D2;
auto const D3_orig = D3;

auto idx_orig = luci::Array4DIndex(D0_orig, D1_orig, D2_orig, D3_orig);

for (uint32_t d0 = 0; d0 < D0; ++d0)
for (uint32_t d1 = 0; d1 < D1; ++d1)
for (uint32_t d2 = 0; d2 < D2; ++d2)
for (uint32_t d3 = 0; d3 < D3; ++d3)
{
auto const step = (D0 * D1 * D2 * D3) / (D0_orig * D1_orig * D2_orig * D3_orig);
auto const d0_orig = (D0_orig == 1) ? 0 : d0;
auto const d1_orig = (D1_orig == 1) ? 0 : d1;
auto const d2_orig = (D2_orig == 1) ? 0 : d2;
auto const d3_orig = (D3_orig == 1) ? 0 : d3;
constant_data[idx(d0, d1, d2, d3)] =
node_data[idx_orig(d0_orig, d1_orig, d2_orig, d3_orig)];
}
}
else if (broadcast_dims.back() == node->rank() - 1)
{
auto const D0 = (successor->rank() < 4) ? 1 : successor->dim(successor->rank() - 4).value();
auto const D1 = (successor->rank() < 3) ? 1 : successor->dim(successor->rank() - 3).value();
auto const D2 = (successor->rank() < 2) ? 1 : successor->dim(successor->rank() - 2).value();
auto const D3 = successor->dim(successor->rank() - 1).value();

auto idx = luci::Array4DIndex(D0, D1, D2, D3);

auto idx_orig = luci::Array4DIndex(D0, D1, D2, 1);

for (uint32_t d0 = 0; d0 < D0; ++d0)
for (uint32_t d1 = 0; d1 < D1; ++d1)
for (uint32_t d2 = 0; d2 < D2; ++d2)
std::fill_n(constant_data + idx(d0, d1, d2, 0), D3, node_data[idx_orig(d0, d1, d2, 0)]);
}

return constant;
}
Expand Down

0 comments on commit f7f5d8a

Please sign in to comment.