diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index fb50a4c7724..b320d5f66f6 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -121,7 +121,7 @@ int entry(int argc, char **argv) "This will fuse BatchNorm operators of pre-activations to Convolution operator"); add_switch(arser, "--fuse_prelu", "This will fuse operators to PReLU operator"); add_switch(arser, "--fuse_gelu", "This will fuse operators to GeLU operator"); - add_switch(arser, "--fuse_gru", "This will fuse operators to GRU operator"); + add_switch(arser, "--fuse_gru", "This will fuse operators to CirGru operator"); add_switch(arser, "--remove_duplicate_const", "This will remove all duplicate constant nodes"); add_switch(arser, "--remove_fakequant", "This will remove FakeQuant operators"); add_switch(arser, "--remove_quantdequant", "This will remove Quantize-Dequantize sequence"); @@ -305,7 +305,7 @@ int entry(int argc, char **argv) if (arser.get("--fuse_gelu")) options->enable(Algorithms::FuseGelu); if (arser.get("--fuse_gru")) - options->enable(Algorithms::FuseGRU); + options->enable(Algorithms::FuseCirGru); if (arser.get("--fuse_transpose_with_mean")) options->enable(Algorithms::FuseTransposeWithMean); if (arser.get("--remove_duplicate_const")) diff --git a/compiler/circlechef/circle/src/CircleOpChefs.h b/compiler/circlechef/circle/src/CircleOpChefs.h index cc7e21957b5..238b6fd3d29 100644 --- a/compiler/circlechef/circle/src/CircleOpChefs.h +++ b/compiler/circlechef/circle/src/CircleOpChefs.h @@ -21,7 +21,7 @@ #include "Op/BatchMatMul.h" #include "Op/BCQFullyConnected.h" #include "Op/BCQGather.h" -#include "Op/CircleGRU.h" +#include "Op/CirGru.h" #include "Op/InstanceNorm.h" #endif // __CIRCLE_OP_CHEFS_H__ diff --git a/compiler/circlechef/circle/src/CircleOpRegistry.h b/compiler/circlechef/circle/src/CircleOpRegistry.h index 34d0c64ba28..dcb4cfa9653 100644 --- a/compiler/circlechef/circle/src/CircleOpRegistry.h +++ b/compiler/circlechef/circle/src/CircleOpRegistry.h @@ -58,7 +58,7 @@ class CircleOpRegistry REG_TFL_OP(BATCH_MATMUL, CircleOpBatchMatMul); REG_TFL_OP(BCQ_FULLY_CONNECTED, CircleOpBCQFullyConnected); REG_TFL_OP(BCQ_GATHER, CircleOpBCQGather); - REG_TFL_OP(CIR_GRU, CircleOpCircleGRU); + REG_TFL_OP(CIR_GRU, CircleOpCirGru); REG_TFL_OP(INSTANCE_NORM, CircleOpInstanceNorm); #undef REG_TFL_OP } diff --git a/compiler/circlechef/circle/src/Convert.cpp b/compiler/circlechef/circle/src/Convert.cpp index a5eb8e9ba13..248687fedc7 100644 --- a/compiler/circlechef/circle/src/Convert.cpp +++ b/compiler/circlechef/circle/src/Convert.cpp @@ -54,10 +54,9 @@ circlechef::Activation as_circlechef_activation(const circle::ActivationFunction return circlechef::RELU; case circle::ActivationFunctionType_RELU6: return circlechef::RELU6; - case circle::ActivationFunctionType_TANH: - return circlechef::TANH; // TODO handle other types // ActivationFunctionType_RELU_N1_TO_1 + // ActivationFunctionType_TANH // ActivationFunctionType_SIGN_BIT default: throw std::runtime_error{"unsupported activation type"}; diff --git a/compiler/circlechef/circle/src/Op/CircleGRU.cpp b/compiler/circlechef/circle/src/Op/CirGru.cpp similarity index 75% rename from compiler/circlechef/circle/src/Op/CircleGRU.cpp rename to compiler/circlechef/circle/src/Op/CirGru.cpp index e03867b4370..9fe88f24dd3 100644 --- a/compiler/circlechef/circle/src/Op/CircleGRU.cpp +++ b/compiler/circlechef/circle/src/Op/CirGru.cpp @@ -14,15 +14,15 @@ * limitations under the License. */ -#include "CircleGRU.h" +#include "CirGru.h" #include "Convert.h" namespace circlechef { -void CircleOpCircleGRU::filler(const circle::Operator *op, CircleImport *import, - circlechef::ModelRecipe *model_recipe) const +void CircleOpCirGru::filler(const circle::Operator *op, CircleImport *import, + circlechef::ModelRecipe *model_recipe) const { // index 1, 2, 3, 4, 5 maybe constant const std::vector &inputs = as_index_vector(op->inputs()); @@ -35,15 +35,15 @@ void CircleOpCircleGRU::filler(const circle::Operator *op, CircleImport *import, import->set_tensor_filler(inputs[5]); } -circlechef::Operation *CircleOpCircleGRU::build(const circle::Operator *op, CircleImport *import, - circlechef::ModelRecipe *model_recipe) const +circlechef::Operation *CircleOpCirGru::build(const circle::Operator *op, CircleImport *import, + circlechef::ModelRecipe *model_recipe) const { - auto op_params = op->builtin_options_as_CircleGRUOptions(); + auto op_params = op->builtin_options_as_CirGruOptions(); assert(op_params != nullptr); auto operation = model_recipe->add_operation(); - operation->set_type("CircleGRU"); + operation->set_type("CirGru"); auto op_options = operation->mutable_circle_gru_options(); diff --git a/compiler/circlechef/circle/src/Op/CircleGRU.h b/compiler/circlechef/circle/src/Op/CirGru.h similarity index 92% rename from compiler/circlechef/circle/src/Op/CircleGRU.h rename to compiler/circlechef/circle/src/Op/CirGru.h index fb117256a7e..5a4f8834fa5 100644 --- a/compiler/circlechef/circle/src/Op/CircleGRU.h +++ b/compiler/circlechef/circle/src/Op/CirGru.h @@ -23,9 +23,9 @@ namespace circlechef { /** - * @brief circlechef operator builder for CircleGRU + * @brief circlechef operator builder for CirGru */ -class CircleOpCircleGRU : public CircleOpChef +class CircleOpCirGru : public CircleOpChef { public: void filler(const circle::Operator *op, CircleImport *import, diff --git a/compiler/circlechef/core/src/Convert.cpp b/compiler/circlechef/core/src/Convert.cpp index 9df0852ac54..6066324b0f9 100644 --- a/compiler/circlechef/core/src/Convert.cpp +++ b/compiler/circlechef/core/src/Convert.cpp @@ -43,8 +43,6 @@ circle::ActivationFunctionType as_circle_activation(const circlechef::Activation return circle::ActivationFunctionType_RELU; case circlechef::RELU6: return circle::ActivationFunctionType_RELU6; - case circlechef::TANH: - return circle::ActivationFunctionType_TANH; default: break; } diff --git a/compiler/circlechef/core/src/Op/CircleGRU.cpp b/compiler/circlechef/core/src/Op/CirGru.cpp similarity index 78% rename from compiler/circlechef/core/src/Op/CircleGRU.cpp rename to compiler/circlechef/core/src/Op/CirGru.cpp index fb41b99e64b..f3af0927e8c 100644 --- a/compiler/circlechef/core/src/Op/CircleGRU.cpp +++ b/compiler/circlechef/core/src/Op/CirGru.cpp @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "CircleGRU.h" +#include "CirGru.h" #include "Convert.h" -flatbuffers::Offset CircleGRUChef::value(flatbuffers::FlatBufferBuilder &fbb) const +flatbuffers::Offset CirGruChef::value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); @@ -27,7 +27,7 @@ flatbuffers::Offset CircleGRUChef::value(flatbuffers::FlatBufferBuilder &f auto return_sequences = operation.circle_gru_options().return_sequences(); auto time_major = operation.circle_gru_options().time_major(); - circle::CircleGRUOptionsBuilder options_builder{fbb}; + circle::CirGruOptionsBuilder options_builder{fbb}; options_builder.add_fused_activation_function(circle_activation); options_builder.add_return_sequences(return_sequences); options_builder.add_time_major(time_major); @@ -35,7 +35,7 @@ flatbuffers::Offset CircleGRUChef::value(flatbuffers::FlatBufferBuilder &f return options_builder.Finish().Union(); } -std::unique_ptr CircleGRUChefFactory::create(const circlechef::Operation *operation) const +std::unique_ptr CirGruChefFactory::create(const circlechef::Operation *operation) const { - return std::unique_ptr{new CircleGRUChef{operation}}; + return std::unique_ptr{new CirGruChef{operation}}; } diff --git a/compiler/circlechef/core/src/Op/CircleGRU.h b/compiler/circlechef/core/src/Op/CirGru.h similarity index 79% rename from compiler/circlechef/core/src/Op/CircleGRU.h rename to compiler/circlechef/core/src/Op/CirGru.h index 9eb12c4f866..cbe845407da 100644 --- a/compiler/circlechef/core/src/Op/CircleGRU.h +++ b/compiler/circlechef/core/src/Op/CirGru.h @@ -19,10 +19,10 @@ #include "OpChef.h" -class CircleGRUChef final : public OpChef +class CirGruChef final : public OpChef { public: - explicit CircleGRUChef(const circlechef::Operation *operation) : _operation{operation} + explicit CirGruChef(const circlechef::Operation *operation) : _operation{operation} { // DO NOTHING } @@ -30,10 +30,7 @@ class CircleGRUChef final : public OpChef public: circle::BuiltinOperator code(void) const override { return circle::BuiltinOperator_CIR_GRU; } - circle::BuiltinOptions type(void) const override - { - return circle::BuiltinOptions_CircleGRUOptions; - } + circle::BuiltinOptions type(void) const override { return circle::BuiltinOptions_CirGruOptions; } flatbuffers::Offset value(flatbuffers::FlatBufferBuilder &fbb) const override; @@ -41,7 +38,7 @@ class CircleGRUChef final : public OpChef const circlechef::Operation *_operation; }; -struct CircleGRUChefFactory final : public OpChefFactory +struct CirGruChefFactory final : public OpChefFactory { std::unique_ptr create(const circlechef::Operation *operation) const override; }; diff --git a/compiler/circlechef/core/src/OpChef.def b/compiler/circlechef/core/src/OpChef.def index adf2c49cfa2..942a957360e 100644 --- a/compiler/circlechef/core/src/OpChef.def +++ b/compiler/circlechef/core/src/OpChef.def @@ -7,5 +7,5 @@ OP_CHEF(BatchMatMul, BatchMatMulChefFactory) OP_CHEF(BCQFullyConnected, BCQFullyConnectedChefFactory) OP_CHEF(BCQGather, BCQGatherChefFactory) -OP_CHEF(CircleGRU, CircleGRUChefFactory) +OP_CHEF(CirGru, CirGruChefFactory) OP_CHEF(InstanceNorm, InstanceNormChefFactory) diff --git a/compiler/circlechef/core/src/OpChefs.h b/compiler/circlechef/core/src/OpChefs.h index 8ec574d2013..6a44114baf6 100644 --- a/compiler/circlechef/core/src/OpChefs.h +++ b/compiler/circlechef/core/src/OpChefs.h @@ -20,7 +20,7 @@ #include "Op/BatchMatMul.h" #include "Op/BCQFullyConnected.h" #include "Op/BCQGather.h" -#include "Op/CircleGRU.h" +#include "Op/CirGru.h" #include "Op/InstanceNorm.h" #endif // __OP_CHEFS_H__ diff --git a/compiler/circlechef/proto/circlechef.proto b/compiler/circlechef/proto/circlechef.proto index 58b464856f1..6927f822cac 100644 --- a/compiler/circlechef/proto/circlechef.proto +++ b/compiler/circlechef/proto/circlechef.proto @@ -64,7 +64,6 @@ enum Activation { NONE = 0; RELU = 1; RELU6 = 3; - TANH = 4; } message BatchMatMulOptions { @@ -77,7 +76,7 @@ message InstanceNormOptions { optional Activation activation = 2 [default = NONE]; } -message CircleGRUOptions { +message CirGruOptions { optional Activation activation = 1 [default = NONE]; optional bool return_sequences = 2 [default = false]; optional bool time_major = 3 [default = false]; @@ -104,7 +103,7 @@ message Operation { optional InstanceNormOptions instance_norm_options = 101; optional BCQFullyConnectedOptions bcq_fully_connected_options = 102; optional BCQGatherOptions bcq_gather_options = 103; - optional CircleGRUOptions circle_gru_options = 104; + optional CirGruOptions circle_gru_options = 104; } // For additional subgraphs diff --git a/compiler/circledump/src/OpPrinter.cpp b/compiler/circledump/src/OpPrinter.cpp index 08d45bbdc88..3f91d68e270 100644 --- a/compiler/circledump/src/OpPrinter.cpp +++ b/compiler/circledump/src/OpPrinter.cpp @@ -807,12 +807,12 @@ class InstanceNormPrinter : public OpPrinter } }; -class CircleGRUPrinter : public OpPrinter +class CirGruPrinter : public OpPrinter { public: void options(const circle::Operator *op, std::ostream &os) const override { - if (auto *params = op->builtin_options_as_CircleGRUOptions()) + if (auto *params = op->builtin_options_as_CirGruOptions()) { os << " "; os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function()) @@ -911,7 +911,7 @@ OpPrinterRegistry::OpPrinterRegistry() _op_map[circle::BuiltinOperator_BCQ_FULLY_CONNECTED] = make_unique(); _op_map[circle::BuiltinOperator_BCQ_GATHER] = make_unique(); _op_map[circle::BuiltinOperator_INSTANCE_NORM] = make_unique(); - _op_map[circle::BuiltinOperator_CIR_GRU] = make_unique(); + _op_map[circle::BuiltinOperator_CIR_GRU] = make_unique(); } } // namespace circledump diff --git a/compiler/common-artifacts/exclude.lst b/compiler/common-artifacts/exclude.lst index cd5f131cd89..e6722466def 100644 --- a/compiler/common-artifacts/exclude.lst +++ b/compiler/common-artifacts/exclude.lst @@ -165,6 +165,6 @@ tcgenerate(ZerosLike_000) tcgenerate(BCQFullyConnected_000) tcgenerate(BCQFullyConnected_001) tcgenerate(BCQGather_000) -tcgenerate(CircleGRU_000) # luci-interpreter does not support custom CircleGRU +tcgenerate(CirGru_000) # luci-interpreter does not support custom CirGru tcgenerate(InstanceNorm_000) tcgenerate(InstanceNorm_001) diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h index c19a03547c4..0f876e1cb42 100644 --- a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h +++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h @@ -541,11 +541,10 @@ class BuiltinOptionsExtractor final to_circle_actfunc(node->fusedActivationFunction())) .Union(); } - flatbuffers::Offset visit(luci::CircleGRU *node) + flatbuffers::Offset visit(luci::CircleCirGru *node) { - return circle::CreateCircleGRUOptions(_builder, - to_circle_actfunc(node->fusedActivationFunction()), - node->returnSequences(), node->timeMajor()) + return circle::CreateCirGruOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()), + node->returnSequences(), node->timeMajor()) .Union(); } diff --git a/compiler/luci/export/src/CircleOps.lst b/compiler/luci/export/src/CircleOps.lst index ba454d8bfa8..d590b630077 100644 --- a/compiler/luci/export/src/CircleOps.lst +++ b/compiler/luci/export/src/CircleOps.lst @@ -139,7 +139,7 @@ CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLik CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions) CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions) CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions) -CIRCLE_NODE(CircleGRU, BuiltinOperator_CIR_GRU, BuiltinOptions_CircleGRUOptions) +CIRCLE_NODE(CircleCirGru, BuiltinOperator_CIR_GRU, BuiltinOptions_CirGruOptions) // Virtual node(s) CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut) CIRCLE_VNODE(CircleConst) diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h index 3c54183871a..609c193dea2 100644 --- a/compiler/luci/import/include/luci/Import/Nodes.h +++ b/compiler/luci/import/include/luci/Import/Nodes.h @@ -57,7 +57,7 @@ #include "Nodes/CircleGelu.h" #include "Nodes/CircleGreater.h" #include "Nodes/CircleGreaterEqual.h" -#include "Nodes/CircleGRU.h" +#include "Nodes/CircleCirGru.h" #include "Nodes/CircleHardSwish.h" #include "Nodes/CircleIf.h" #include "Nodes/CircleInstanceNorm.h" diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h b/compiler/luci/import/include/luci/Import/Nodes/CircleCirGru.h similarity index 95% rename from compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h rename to compiler/luci/import/include/luci/Import/Nodes/CircleCirGru.h index 920e53c1cc4..f4af6899ddf 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleCirGru.h @@ -22,7 +22,7 @@ namespace luci { -class CircleGRUGraphBuilder : public GraphBuilder +class CircleCirGruGraphBuilder : public GraphBuilder { public: bool validate(const ValidateArgs &args) const final; diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp index 511a90196bb..4f923e740d5 100644 --- a/compiler/luci/import/src/GraphBuilderRegistry.cpp +++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp @@ -67,7 +67,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(GREATER, CircleGreaterGraphBuilder); // 61 CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualGraphBuilder); // 62 CIRCLE_NODE(HARD_SWISH, CircleHardSwishGraphBuilder); // 117 - CIRCLE_NODE(CIR_GRU, CircleGRUGraphBuilder); // 255 + CIRCLE_NODE(CIR_GRU, CircleCirGruGraphBuilder); // 251 CIRCLE_NODE(IF, CircleIfGraphBuilder); // 118 CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormGraphBuilder); // 254 CIRCLE_NODE(L2_NORMALIZATION, CircleL2NormalizeGraphBuilder); // 11 diff --git a/compiler/luci/import/src/Nodes/CircleGRU.cpp b/compiler/luci/import/src/Nodes/CircleCirGru.cpp similarity index 71% rename from compiler/luci/import/src/Nodes/CircleGRU.cpp rename to compiler/luci/import/src/Nodes/CircleCirGru.cpp index 7dc27e68eb2..40df802c247 100644 --- a/compiler/luci/import/src/Nodes/CircleGRU.cpp +++ b/compiler/luci/import/src/Nodes/CircleCirGru.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "luci/Import/Nodes/CircleGRU.h" +#include "luci/Import/Nodes/CircleCirGru.h" #include @@ -23,16 +23,16 @@ namespace luci { -bool CircleGRUGraphBuilder::validate(const ValidateArgs &args) const +bool CircleCirGruGraphBuilder::validate(const ValidateArgs &args) const { return GraphBuilder::validate(args, 6); } -CircleNode *CircleGRUGraphBuilder::build_node(const circle::OperatorT &, - const std::vector &inputs, - loco::Graph *graph) const +CircleNode *CircleCirGruGraphBuilder::build_node(const circle::OperatorT &, + const std::vector &inputs, + loco::Graph *graph) const { - auto *node = graph->nodes()->create(); + auto *node = graph->nodes()->create(); node->input(inputs.at(0)); node->hidden_hidden(inputs.at(1)); node->hidden_hidden_bias(inputs.at(2)); diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h index d8c18702d07..94f51e12627 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h @@ -139,7 +139,7 @@ #include "Nodes/CircleBCQFullyConnected.h" #include "Nodes/CircleBCQGather.h" #include "Nodes/CircleInstanceNorm.h" -#include "Nodes/CircleGRU.h" +#include "Nodes/CircleCirGru.h" // Virtual nodes #include "Nodes/CircleConst.h" #include "Nodes/CircleInput.h" diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst index 4579e4bd5af..e5a29863d20 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst @@ -52,7 +52,7 @@ CIRCLE_NODE(GATHER_ND, CircleGatherNd) CIRCLE_NODE(GELU, CircleGelu) CIRCLE_NODE(GREATER, CircleGreater) CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqual) -CIRCLE_NODE(CIR_GRU, CircleGRU) +CIRCLE_NODE(CIR_GRU, CircleCirGru) CIRCLE_NODE(HARD_SWISH, CircleHardSwish) CIRCLE_NODE(IF, CircleIf) CIRCLE_NODE(L2_NORMALIZATION, CircleL2Normalize) diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCirGru.h similarity index 91% rename from compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h rename to compiler/luci/lang/include/luci/IR/Nodes/CircleCirGru.h index bfaa0f0f4a0..f41febfbfca 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCirGru.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef __LUCI_IR_CIRCLEGRU_H__ -#define __LUCI_IR_CIRCLEGRU_H__ +#ifndef __LUCI_IR_CIRCLE_CIR_GRU_H__ +#define __LUCI_IR_CIRCLE_CIR_GRU_H__ #include "luci/IR/CircleNodeDecl.h" #include "luci/IR/CircleOpcode.h" @@ -28,7 +28,7 @@ namespace luci /** * @brief GRU in Circle */ -class CircleGRU final : public FixedArityNode<6, CircleNodeImpl> +class CircleCirGru final : public FixedArityNode<6, CircleNodeImpl> { public: loco::Node *input(void) const { return at(0)->node(); } @@ -67,4 +67,4 @@ class CircleGRU final : public FixedArityNode<6, CircleNodeImpl -TEST(CircleGRUTest, constructor_P) +TEST(CircleCirGruTest, constructor_P) { - luci::CircleGRU gru_node; + luci::CircleCirGru gru_node; ASSERT_EQ(luci::CircleDialect::get(), gru_node.dialect()); ASSERT_EQ(luci::CircleOpcode::CIR_GRU, gru_node.opcode()); @@ -36,10 +36,10 @@ TEST(CircleGRUTest, constructor_P) ASSERT_EQ(nullptr, gru_node.state()); } -TEST(CircleGRUTest, input_NEG) +TEST(CircleCirGruTest, input_NEG) { - luci::CircleGRU gru_node; - luci::CircleGRU node; + luci::CircleCirGru gru_node; + luci::CircleCirGru node; gru_node.input(&node); ASSERT_NE(nullptr, gru_node.input()); @@ -48,9 +48,9 @@ TEST(CircleGRUTest, input_NEG) ASSERT_EQ(nullptr, gru_node.input()); } -TEST(CircleGRUTest, arity_NEG) +TEST(CircleCirGruTest, arity_NEG) { - luci::CircleGRU gru_node; + luci::CircleCirGru gru_node; ASSERT_NO_THROW(gru_node.arg(0)); ASSERT_NO_THROW(gru_node.arg(1)); @@ -61,25 +61,25 @@ TEST(CircleGRUTest, arity_NEG) ASSERT_THROW(gru_node.arg(6), std::out_of_range); } -TEST(CircleGRUTest, visit_mutable_NEG) +TEST(CircleCirGruTest, visit_mutable_NEG) { struct TestVisitor final : public luci::CircleNodeMutableVisitor { }; - luci::CircleGRU gru_node; + luci::CircleCirGru gru_node; TestVisitor tv; ASSERT_THROW(gru_node.accept(&tv), std::exception); } -TEST(CircleGRUTest, visit_NEG) +TEST(CircleCirGruTest, visit_NEG) { struct TestVisitor final : public luci::CircleNodeVisitor { }; - luci::CircleGRU gru_node; + luci::CircleCirGru gru_node; TestVisitor tv; ASSERT_THROW(gru_node.accept(&tv), std::exception); diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp index 88be4bdcc27..0fec97c8146 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp @@ -173,7 +173,7 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node) CIRCLE_NODE(GELU, CircleGeluSummaryBuilder) CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder) CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder) - CIRCLE_NODE(CIR_GRU, CircleGRUSummaryBuilder) + CIRCLE_NODE(CIR_GRU, CircleCirGruSummaryBuilder) CIRCLE_NODE(HARD_SWISH, CircleHardSwishSummaryBuilder) CIRCLE_NODE(IF, CircleIfSummaryBuilder) CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder) diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp index b4810f94002..e6e867eb8c1 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp @@ -343,15 +343,16 @@ void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci: s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs())); } -std::vector CircleGRUSummaryBuilder::get_input_names(const luci::CircleNode *) +std::vector CircleCirGruSummaryBuilder::get_input_names(const luci::CircleNode *) { return {"input", "hidden_hidden", "hidden_hidden_bias", "hidden_input", "hidden_input_bias", "state"}; } -void CircleGRUSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s) +void CircleCirGruSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) { - auto gru = loco::must_cast(node); + auto gru = loco::must_cast(node); s.args().append("fused_act_function", to_str(gru->fusedActivationFunction())); s.args().append("return_sequence", to_str(gru->returnSequences())); s.args().append("time_major", to_str(gru->timeMajor())); diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h index 17e5dcc7d34..372244e7a7c 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h @@ -307,7 +307,7 @@ class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBui { }; -class CircleGRUSummaryBuilder final : public CircleNodeSummaryBuilder +class CircleCirGruSummaryBuilder final : public CircleNodeSummaryBuilder { private: std::vector get_input_names(const luci::CircleNode *); diff --git a/compiler/luci/partition/include/luci/ConnectNode.h b/compiler/luci/partition/include/luci/ConnectNode.h index 65324be7188..de928f7df40 100644 --- a/compiler/luci/partition/include/luci/ConnectNode.h +++ b/compiler/luci/partition/include/luci/ConnectNode.h @@ -185,7 +185,7 @@ class ConnectNode final : public luci::CircleNodeVisitor void visit(const luci::CircleBCQFullyConnected *) final; void visit(const luci::CircleBCQGather *) final; void visit(const luci::CircleInstanceNorm *) final; - void visit(const luci::CircleGRU *) final; + void visit(const luci::CircleCirGru *) final; // NOTE CircleInput and CircleOutput are not handled here as these need // link with graph I/O diff --git a/compiler/luci/partition/src/Nodes/CircleGRU.cpp b/compiler/luci/partition/src/Nodes/CircleCirGru.cpp similarity index 87% rename from compiler/luci/partition/src/Nodes/CircleGRU.cpp rename to compiler/luci/partition/src/Nodes/CircleCirGru.cpp index fe1bb536b0b..03c038d2b21 100644 --- a/compiler/luci/partition/src/Nodes/CircleGRU.cpp +++ b/compiler/luci/partition/src/Nodes/CircleCirGru.cpp @@ -19,9 +19,9 @@ namespace { -void connect(luci::ConnectNode *cn, const luci::CircleGRU *node) +void connect(luci::ConnectNode *cn, const luci::CircleCirGru *node) { - auto *cloned = loco::must_cast(cn->find_clone(node)); + auto *cloned = loco::must_cast(cn->find_clone(node)); luci::CircleNode *input = loco::must_cast(node->input()); luci::CircleNode *hidden_input = loco::must_cast(node->hidden_input()); @@ -45,6 +45,6 @@ void connect(luci::ConnectNode *cn, const luci::CircleGRU *node) namespace luci { -void ConnectNode::visit(const luci::CircleGRU *node) { connect(this, node); } +void ConnectNode::visit(const luci::CircleCirGru *node) { connect(this, node); } } // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp b/compiler/luci/partition/src/Nodes/CircleCirGru.test.cpp similarity index 83% rename from compiler/luci/partition/src/Nodes/CircleGRU.test.cpp rename to compiler/luci/partition/src/Nodes/CircleCirGru.test.cpp index d17f39df9c2..fac598d85d4 100644 --- a/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp +++ b/compiler/luci/partition/src/Nodes/CircleCirGru.test.cpp @@ -27,7 +27,7 @@ namespace using namespace luci::test; -class NodeGraphlet : public NodeGraphletT +class NodeGraphlet : public NodeGraphletT { public: NodeGraphlet() = default; @@ -35,9 +35,9 @@ class NodeGraphlet : public NodeGraphletT public: void init(loco::Graph *g) override { - NodeGraphletT::init(g); + NodeGraphletT::init(g); - _node->fusedActivationFunction(luci::FusedActFunc::TANH); + _node->fusedActivationFunction(luci::FusedActFunc::NONE); } }; @@ -74,10 +74,10 @@ TEST(ConnectNodeTest, connect_CIRCLE_GRU) cth.prepare_inputs(&tng); auto *node = tng.node(); - ASSERT_NO_THROW(loco::must_cast(node)); + ASSERT_NO_THROW(loco::must_cast(node)); auto *clone = luci::clone_node(node, cth.graph_clone()); - ASSERT_NO_THROW(loco::must_cast(clone)); + ASSERT_NO_THROW(loco::must_cast(clone)); cth.clone_connect(node, clone); @@ -96,10 +96,10 @@ TEST(ConnectNodeTest, connect_CIRCLE_GRU_NEG) cth.prepare_inputs_miss(&tng); auto *node = tng.node(); - ASSERT_NO_THROW(loco::must_cast(node)); + ASSERT_NO_THROW(loco::must_cast(node)); auto *clone = luci::clone_node(node, cth.graph_clone()); - ASSERT_NO_THROW(loco::must_cast(clone)); + ASSERT_NO_THROW(loco::must_cast(clone)); EXPECT_ANY_THROW(cth.clone_connect(node, clone)); } diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index aa2c6087672..34b811d92ac 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -69,7 +69,7 @@ class CircleOptimizer final FuseActivationFunction, FusePRelu, FuseGelu, - FuseGRU, + FuseCirGru, ShuffleWeightTo16x1Float32, RemoveRedundantTranspose, ReplaceMulAddWithDepthwiseConv, diff --git a/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h b/compiler/luci/pass/include/luci/Pass/FuseCirGruPass.h similarity index 78% rename from compiler/luci/pass/include/luci/Pass/FuseGRUPass.h rename to compiler/luci/pass/include/luci/Pass/FuseCirGruPass.h index bb7a165e110..76c939c60a0 100644 --- a/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h +++ b/compiler/luci/pass/include/luci/Pass/FuseCirGruPass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef __LUCI_FUSE_GRU_PASS_H__ -#define __LUCI_FUSE_GRU_PASS_H__ +#ifndef __LUCI_FUSE_CIR_GRU_PASS_H__ +#define __LUCI_FUSE_CIR_GRU_PASS_H__ #include @@ -23,17 +23,17 @@ namespace luci { /** - * @brief Class to fuse certain pattern of subgraph into CircleGRU + * @brief Class to fuse certain pattern of subgraph into CircleCirGru * * For detailed subgraph pattern to be fused, please check its implementation. */ -struct FuseGRUPass final : public logo::Pass +struct FuseCirGruPass final : public logo::Pass { - const char *name(void) const final { return "luci::FuseGRUPass"; } + const char *name(void) const final { return "luci::FuseCirGruPass"; } bool run(loco::Graph *g) final; }; } // namespace luci -#endif // __LUCI_FUSE_GRU_PASS_H__ +#endif // __LUCI_FUSE_CIR_GRU_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 41ace7cda6c..380ddda6fc4 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -43,7 +43,7 @@ #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" -#include "luci/Pass/FuseGRUPass.h" +#include "luci/Pass/FuseCirGruPass.h" #include "luci/Pass/FuseSliceWithTConvPass.h" #include "luci/Pass/FuseHorizontalFullyConnectedPass.h" #include "luci/Pass/FuseTransposeWithMeanPass.h" @@ -330,9 +330,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } - if (_options->query(Options::Algorithm::FuseGRU)) + if (_options->query(Options::Algorithm::FuseCirGru)) { - phase.emplace_back(std::make_unique()); + phase.emplace_back(std::make_unique()); } if (_options->query(Options::Algorithm::FuseHorizontalFullyConnected)) { diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseCirGruPass.cpp similarity index 90% rename from compiler/luci/pass/src/FuseGRUPass.cpp rename to compiler/luci/pass/src/FuseCirGruPass.cpp index 12358a6efb1..23d809061c9 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseCirGruPass.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "luci/Pass/FuseGRUPass.h" +#include "luci/Pass/FuseCirGruPass.h" #include "helpers/NodeFiller.h" #include @@ -281,16 +281,16 @@ bool GRUPattern1::matched() return true; } -class FuseGRU final +class FuseCirGru final { public: - FuseGRU(const GRUPatternBase *p) : _p(p) {} + FuseCirGru(const GRUPatternBase *p) : _p(p) {} public: void apply(void); private: - luci::CircleGRU *create_circle_gru(loco::Graph *graph); + luci::CircleCirGru *create_circle_gru(loco::Graph *graph); private: const GRUPatternBase *_p; @@ -318,46 +318,16 @@ luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph cloned->dtype(node->dtype()); cloned->rank(node->rank()); - // values - switch (node->dtype()) - { - case loco::DataType::FLOAT32: - copy_values(node, cloned); - break; - - case loco::DataType::U8: - copy_values(node, cloned); - break; - - case loco::DataType::S8: - copy_values(node, cloned); - break; - - case loco::DataType::S16: - copy_values(node, cloned); - break; + assert(node->dtype() == loco::DataType::FLOAT32); - case loco::DataType::S32: - copy_values(node, cloned); - break; - - case loco::DataType::S64: - copy_values(node, cloned); - break; - - case loco::DataType::BOOL: - copy_values(node, cloned); - break; - - default: - assert(false); - } + // values + copy_values(node, cloned); } return cloned; } -luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) +luci::CircleCirGru *FuseCirGru::create_circle_gru(loco::Graph *graph) { assert(graph); @@ -392,8 +362,8 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) auto hidden_input_cloned = clone_circleconst(_p->_hidden_input, graph); luci::copy_common_attributes(_p->_hidden_input, hidden_input_cloned); - // Create and configure new CircleGRU - auto circle_gru = _p->_while_node->graph()->nodes()->create(); + // Create and configure new CircleCirGru + auto circle_gru = _p->_while_node->graph()->nodes()->create(); circle_gru->input(_p->_ifm); circle_gru->hidden_input(weight_ih_cloned); @@ -406,7 +376,7 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) circle_gru->returnSequences(_p->_return_sequences); circle_gru->timeMajor(_p->_time_major); - circle_gru->name("CircleGRU"); + circle_gru->name("CircleCirGru"); circle_gru->shape_status(luci::ShapeStatus::UNDEFINED); @@ -415,7 +385,7 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) return circle_gru; } -void FuseGRU::apply() +void FuseCirGru::apply() { auto graph = _p->_pattern_last_node->graph(); @@ -444,7 +414,7 @@ bool fuse_gru(luci::CircleWhileOut *while_out_node) GRUPattern1 pattern(while_out_node); if (pattern.matched()) { - FuseGRU fuse(&pattern); + FuseCirGru fuse(&pattern); fuse.apply(); return true; } @@ -457,7 +427,7 @@ bool fuse_gru(luci::CircleWhileOut *while_out_node) namespace luci { -bool FuseGRUPass::run(loco::Graph *g) +bool FuseCirGruPass::run(loco::Graph *g) { bool changed = false; diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index 887234cde5a..8e0e09e8854 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -164,7 +164,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final; // loco::TensorShape visit(const luci::CircleBCQGather *node) final; // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final; - // loco::TensorShape visit(const luci::CircleGRU *node) final; + // loco::TensorShape visit(const luci::CircleCirGru *node) final; // Virtual // loco::TensorShape visit(const luci::CircleCustomOut *node) final; diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h index f839a935d91..d01a37f2037 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h @@ -163,7 +163,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::DataType visit(const luci::CircleBCQFullyConnected *node) final; // loco::DataType visit(const luci::CircleBCQGather *node) final; // loco::DataType visit(const luci::CircleInstanceNorm *node) final; - // loco::DataType visit(const luci::CircleGRU *node) final; + // loco::DataType visit(const luci::CircleCirGru *node) final; // Virtual // loco::DataType visit(const luci::CircleInput *node) final; diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h index c5b058e0565..248796732eb 100644 --- a/compiler/luci/service/src/CircleCloneNode.h +++ b/compiler/luci/service/src/CircleCloneNode.h @@ -257,7 +257,7 @@ class CloneNode final : public luci::CircleNodeVisitor luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final; luci::CircleNode *visit(const luci::CircleBCQGather *) final; luci::CircleNode *visit(const luci::CircleInstanceNorm *) final; - luci::CircleNode *visit(const luci::CircleGRU *) final; + luci::CircleNode *visit(const luci::CircleCirGru *) final; // NOTE CircleInput and CircleOutput are not handled here as these need // link with graph I/O diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 3195bf25dd6..d929f02e69e 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -1744,7 +1744,7 @@ loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node) return loco::NodeShape{output_shape}; } -loco::NodeShape infer_circle_gru(const luci::CircleGRU *node) +loco::NodeShape infer_circle_gru(const luci::CircleCirGru *node) { loco::TensorShape output_shape; @@ -2501,7 +2501,7 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitorfeatures()); } - loco::DataType visit(const luci::CircleGRU *node) final { return luci::dtype_get(node->input()); } + loco::DataType visit(const luci::CircleCirGru *node) final + { + return luci::dtype_get(node->input()); + } loco::DataType visit(const luci::CircleIf *node) final { diff --git a/compiler/luci/service/src/Nodes/CircleGRU.cpp b/compiler/luci/service/src/Nodes/CircleCirGru.cpp similarity index 88% rename from compiler/luci/service/src/Nodes/CircleGRU.cpp rename to compiler/luci/service/src/Nodes/CircleCirGru.cpp index c72e1852868..6fe71b221fd 100644 --- a/compiler/luci/service/src/Nodes/CircleGRU.cpp +++ b/compiler/luci/service/src/Nodes/CircleCirGru.cpp @@ -19,12 +19,12 @@ namespace luci { -luci::CircleNode *CloneNode::visit(const luci::CircleGRU *node) +luci::CircleNode *CloneNode::visit(const luci::CircleCirGru *node) { if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) return nullptr; - auto *cloned = _graph->nodes()->create(); + auto *cloned = _graph->nodes()->create(); if (cloned != nullptr) { cloned->fusedActivationFunction(node->fusedActivationFunction()); diff --git a/compiler/luci/service/src/Nodes/CircleGRU.test.cpp b/compiler/luci/service/src/Nodes/CircleCirGru.test.cpp similarity index 92% rename from compiler/luci/service/src/Nodes/CircleGRU.test.cpp rename to compiler/luci/service/src/Nodes/CircleCirGru.test.cpp index d5ffe4d6916..2d023a4add6 100644 --- a/compiler/luci/service/src/Nodes/CircleGRU.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleCirGru.test.cpp @@ -32,7 +32,7 @@ TEST(ShapeRuleTest, simple_circle_gru) luci::CircleConst hidden_input; luci::CircleConst hidden_input_bias; luci::CircleConst state; - luci::CircleGRU circle_gru; + luci::CircleCirGru circle_gru; input.shape({10, 1, 4}); input.shape_status(luci::ShapeStatus::VALID); @@ -77,7 +77,7 @@ TEST(DataTypeRuleTest, simple_circle_gru) luci::CircleConst hidden_input; luci::CircleConst hidden_input_bias; luci::CircleConst state; - luci::CircleGRU circle_gru; + luci::CircleCirGru circle_gru; input.dtype(loco::DataType::FLOAT32); hidden_hidden.dtype(loco::DataType::FLOAT32); @@ -103,14 +103,14 @@ TEST(DataTypeRuleTest, simple_circle_gru) TEST(CloneNodeTest, clone_circel_gru) { auto g = loco::make_graph(); - auto node_circle_gru = g->nodes()->create(); - node_circle_gru->fusedActivationFunction(luci::FusedActFunc::TANH); + auto node_circle_gru = g->nodes()->create(); + node_circle_gru->fusedActivationFunction(luci::FusedActFunc::NONE); auto gc = loco::make_graph(); auto cloned = luci::clone_node(node_circle_gru, gc.get()); ASSERT_NE(nullptr, cloned); ASSERT_EQ(gc.get(), cloned->graph()); - auto cloned_circle_gru = dynamic_cast(cloned); + auto cloned_circle_gru = dynamic_cast(cloned); ASSERT_NE(nullptr, cloned_circle_gru); } diff --git a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst index f041a74f603..8709f5ec091 100644 --- a/compiler/luci/tests/test.lst +++ b/compiler/luci/tests/test.lst @@ -229,7 +229,7 @@ addread(BCQFullyConnected_001) addread(BCQGather_000) addread(CircleBatchMatMul_000) addread(InstanceNorm_000) -addread(CircleGRU_000) +addread(CirGru_000) addwrite(Abs_000) addwrite(Add_000) @@ -461,4 +461,4 @@ addwrite(BCQFullyConnected_001) addwrite(BCQGather_000) addwrite(CircleBatchMatMul_000) addwrite(InstanceNorm_000) -addwrite(CircleGRU_000) +addwrite(CirGru_000) diff --git a/res/CircleRecipes/CircleGRU_000/test.recipe b/res/CircleRecipes/CirGru_000/test.recipe similarity index 96% rename from res/CircleRecipes/CircleGRU_000/test.recipe rename to res/CircleRecipes/CirGru_000/test.recipe index 468ee4e3d8c..ea9363ff9e5 100644 --- a/res/CircleRecipes/CircleGRU_000/test.recipe +++ b/res/CircleRecipes/CirGru_000/test.recipe @@ -59,10 +59,10 @@ operand { shape { dim: 1 dim: 1 dim: 4 } } operation { - type: "CircleGRU" + type: "CirGru" circle_gru_options { activation: NONE - return_sequences: true + return_sequences: false time_major: false } input: "ifm" diff --git a/res/CircleRecipes/CirGru_000/test.reverse b/res/CircleRecipes/CirGru_000/test.reverse new file mode 100644 index 00000000000..e69de29bb2d diff --git a/res/CircleRecipes/CircleGRU_000/test.reverse b/res/CircleRecipes/CircleGRU_000/test.reverse deleted file mode 100644 index 9dab36e2094..00000000000 --- a/res/CircleRecipes/CircleGRU_000/test.reverse +++ /dev/null @@ -1,17 +0,0 @@ -operand { - name: "ifm" - type: FLOAT32 - shape { dim: 1 dim: 3 dim: 3 dim: 2 } -} -operand { - name: "ofm" - type: FLOAT32 - shape { dim: 1 dim: 3 dim: 3 dim: 2 } -} -operation { - type: "HardSwish" - input: "ifm" - output: "ofm" -} -input: "ifm" -output: "ofm" diff --git a/res/CircleSchema/0.7/circle_schema.fbs b/res/CircleSchema/0.7/circle_schema.fbs index 51c79a226da..9113e5e1348 100644 --- a/res/CircleSchema/0.7/circle_schema.fbs +++ b/res/CircleSchema/0.7/circle_schema.fbs @@ -617,7 +617,7 @@ union BuiltinOptions { BitcastOptions, BitwiseXorOptions, RightShiftOptions, - CircleGRUOptions = 251, + CirGruOptions = 251, BCQGatherOptions = 252, BCQFullyConnectedOptions = 253, InstanceNormOptions = 254, @@ -1433,7 +1433,7 @@ table GeluOptions { approximate: bool; } -table CircleGRUOptions { +table CirGruOptions { fused_activation_function:ActivationFunctionType; return_sequences : bool; time_major : bool;