Skip to content

Commit

Permalink
[fme-apply] Consider no bias FullyConnected (#14250)
Browse files Browse the repository at this point in the history
This commit considers FullyConnected with no bias.

ONE-DCO-1.0-Signed-off-by: seongwoo <[email protected]>
  • Loading branch information
mhs4670go authored Oct 24, 2024
1 parent e24c111 commit 9c49d99
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions compiler/fme-apply/src/pass/FusePostScalePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ struct FusePostScale final : public luci::CircleNodeMutableVisitor<bool>
auto param =
loco::must_cast<luci::CircleConst *>(post_scale->inputs(1)); // FIX_PostScale_UNLESS
auto filter = loco::must_cast<luci::CircleConst *>(node->weights());
auto bias = loco::must_cast<luci::CircleConst *>(node->bias());
luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());

uint32_t filter_o = filter->dim(0).value();
uint32_t filter_i = filter->dim(1).value();
Expand All @@ -259,26 +259,34 @@ struct FusePostScale final : public luci::CircleNodeMutableVisitor<bool>
throw std::runtime_error("Mismatch between scale size and filter output channel size: " +
std::to_string(filter_o) + " != " + std::to_string(param_size));
}
const auto bias_size = bias->size<loco::DataType::FLOAT32>();
if (bias_size != param_size)
if (bias)
{
throw std::runtime_error("Mismatch between scale size and bias size: " +
std::to_string(bias_size) + " != " + std::to_string(param_size));
const auto bias_size = bias->size<loco::DataType::FLOAT32>();
if (bias_size != param_size)
{
throw std::runtime_error("Mismatch between scale size and bias size: " +
std::to_string(bias_size) + " != " + std::to_string(param_size));
}
}

auto cloned_fc = luci::clone_node(node, node->graph());
assert(cloned_fc != nullptr); // FIX_CALLER_UNLESS
auto fused_fc = loco::must_cast<luci::CircleFullyConnected *>(cloned_fc);
auto fused_filter = luci::clone(filter);
auto fused_bias = luci::clone(bias);

fused_fc->name(node->name() + "_fused_" + random_str());
fused_filter->name(filter->name() + "_fused_" + random_str());
fused_bias->name(bias->name() + "_fused_" + random_str());

add_origin(fused_fc, luci::get_origin(node));
add_origin(fused_filter, luci::get_origin(filter));
add_origin(fused_bias, luci::get_origin(bias));

luci::CircleConst *fused_bias = nullptr;
if (bias)
{
fused_bias = luci::clone(bias);
fused_bias->name(bias->name() + "_fused_" + random_str());
add_origin(fused_bias, luci::get_origin(bias));
}

// Multiply param to weights
for (uint32_t o = 0; o < filter_o; o++)
Expand All @@ -294,17 +302,21 @@ struct FusePostScale final : public luci::CircleNodeMutableVisitor<bool>
}
}

// Multiply param to bias
for (uint32_t c = 0; c < filter_o; ++c)
{
float scale = param->at<loco::DataType::FLOAT32>(c);
fused_bias->at<loco::DataType::FLOAT32>(c) =
fused_bias->at<loco::DataType::FLOAT32>(c) * scale;
}

fused_fc->input(node->input());
fused_fc->weights(fused_filter);
fused_fc->bias(fused_bias);
fused_fc->bias(node->bias());

if (bias)
{
// Multiply param to bias
for (uint32_t c = 0; c < filter_o; ++c)
{
float scale = param->at<loco::DataType::FLOAT32>(c);
fused_bias->at<loco::DataType::FLOAT32>(c) =
fused_bias->at<loco::DataType::FLOAT32>(c) * scale;
}
fused_fc->bias(fused_bias);
}

loco::replace(post_scale).with(fused_fc);

Expand Down

0 comments on commit 9c49d99

Please sign in to comment.