diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp index e1be9dca9c8..1c13280d0f4 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp @@ -173,6 +173,7 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node) CIRCLE_NODE(GELU, CircleGeluSummaryBuilder) CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder) CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder) + CIRCLE_NODE(GRU, CircleGRUSummaryBuilder) CIRCLE_NODE(HARD_SWISH, CircleHardSwishSummaryBuilder) CIRCLE_NODE(IF, CircleIfSummaryBuilder) CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder) diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp index ffe93e63693..b4810f94002 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp @@ -343,6 +343,20 @@ void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci: s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs())); } +std::vector<std::string> CircleGRUSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "hidden_hidden", "hidden_hidden_bias", + "hidden_input", "hidden_input_bias", "state"}; +} + +void CircleGRUSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s) +{ + auto gru = loco::must_cast<const luci::CircleGRU *>(node); + s.args().append("fused_act_function", to_str(gru->fusedActivationFunction())); + s.args().append("return_sequence", to_str(gru->returnSequences())); + s.args().append("time_major", to_str(gru->timeMajor())); +} + std::vector<std::string> CircleBroadcastToSummaryBuilder::get_input_names(const luci::CircleNode *) { return {"input", "shape"}; diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h index d53ee96234f..17e5dcc7d34 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h @@ -307,6 +307,13 @@ class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBui { }; +class CircleGRUSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + class CircleHardSwishSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder { };