From 9c49d99d4a165d702b1a46ccd8076ebfdd2594a8 Mon Sep 17 00:00:00 2001 From: seongwoo chae Date: Thu, 24 Oct 2024 12:44:04 +0900 Subject: [PATCH] [fme-apply] Consider no bias FullyConnected (#14250) This commit considers FullyConnected with no bias. ONE-DCO-1.0-Signed-off-by: seongwoo --- .../fme-apply/src/pass/FusePostScalePass.cpp | 46 ++++++++++++------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/compiler/fme-apply/src/pass/FusePostScalePass.cpp b/compiler/fme-apply/src/pass/FusePostScalePass.cpp index 6097eda32e2..a8519c10edd 100644 --- a/compiler/fme-apply/src/pass/FusePostScalePass.cpp +++ b/compiler/fme-apply/src/pass/FusePostScalePass.cpp @@ -248,7 +248,7 @@ struct FusePostScale final : public luci::CircleNodeMutableVisitor auto param = loco::must_cast(post_scale->inputs(1)); // FIX_PostScale_UNLESS auto filter = loco::must_cast(node->weights()); - auto bias = loco::must_cast(node->bias()); + luci::CircleConst *bias = dynamic_cast(node->bias()); uint32_t filter_o = filter->dim(0).value(); uint32_t filter_i = filter->dim(1).value(); @@ -259,26 +259,34 @@ struct FusePostScale final : public luci::CircleNodeMutableVisitor throw std::runtime_error("Mismatch between scale size and filter output channel size: " + std::to_string(filter_o) + " != " + std::to_string(param_size)); } - const auto bias_size = bias->size(); - if (bias_size != param_size) + if (bias) { - throw std::runtime_error("Mismatch between scale size and bias size: " + - std::to_string(bias_size) + " != " + std::to_string(param_size)); + const auto bias_size = bias->size(); + if (bias_size != param_size) + { + throw std::runtime_error("Mismatch between scale size and bias size: " + + std::to_string(bias_size) + " != " + std::to_string(param_size)); + } } auto cloned_fc = luci::clone_node(node, node->graph()); assert(cloned_fc != nullptr); // FIX_CALLER_UNLESS auto fused_fc = loco::must_cast(cloned_fc); auto fused_filter = luci::clone(filter); - auto fused_bias = luci::clone(bias); fused_fc->name(node->name() + "_fused_" + random_str()); fused_filter->name(filter->name() + "_fused_" + random_str()); - fused_bias->name(bias->name() + "_fused_" + random_str()); add_origin(fused_fc, luci::get_origin(node)); add_origin(fused_filter, luci::get_origin(filter)); - add_origin(fused_bias, luci::get_origin(bias)); + + luci::CircleConst *fused_bias = nullptr; + if (bias) + { + fused_bias = luci::clone(bias); + fused_bias->name(bias->name() + "_fused_" + random_str()); + add_origin(fused_bias, luci::get_origin(bias)); + } // Multiply param to weights for (uint32_t o = 0; o < filter_o; o++) @@ -294,17 +302,21 @@ struct FusePostScale final : public luci::CircleNodeMutableVisitor } } - // Multiply param to bias - for (uint32_t c = 0; c < filter_o; ++c) - { - float scale = param->at(c); - fused_bias->at(c) = - fused_bias->at(c) * scale; - } - fused_fc->input(node->input()); fused_fc->weights(fused_filter); - fused_fc->bias(fused_bias); + fused_fc->bias(node->bias()); + + if (bias) + { + // Multiply param to bias + for (uint32_t c = 0; c < filter_o; ++c) + { + float scale = param->at(c); + fused_bias->at(c) = + fused_bias->at(c) * scale; + } + fused_fc->bias(fused_bias); + } loco::replace(post_scale).with(fused_fc);