diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h index d606c1b75aa..c19a03547c4 100644 --- a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h +++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h @@ -541,6 +541,13 @@ class BuiltinOptionsExtractor final to_circle_actfunc(node->fusedActivationFunction())) .Union(); } + flatbuffers::Offset visit(luci::CircleGRU *node) + { + return circle::CreateCircleGRUOptions(_builder, + to_circle_actfunc(node->fusedActivationFunction()), + node->returnSequences(), node->timeMajor()) + .Union(); + } protected: flatbuffers::FlatBufferBuilder &_builder; diff --git a/compiler/luci/export/src/CircleOps.lst b/compiler/luci/export/src/CircleOps.lst index 3133f880ff8..ba454d8bfa8 100644 --- a/compiler/luci/export/src/CircleOps.lst +++ b/compiler/luci/export/src/CircleOps.lst @@ -139,6 +139,7 @@ CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLik CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions) CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions) CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions) +CIRCLE_NODE(CircleGRU, BuiltinOperator_CIR_GRU, BuiltinOptions_CircleGRUOptions) // Virtual node(s) CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut) CIRCLE_VNODE(CircleConst) diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h index ede9c3c1b09..3c54183871a 100644 --- a/compiler/luci/import/include/luci/Import/Nodes.h +++ b/compiler/luci/import/include/luci/Import/Nodes.h @@ -57,6 +57,7 @@ #include "Nodes/CircleGelu.h" #include "Nodes/CircleGreater.h" #include "Nodes/CircleGreaterEqual.h" +#include "Nodes/CircleGRU.h" #include "Nodes/CircleHardSwish.h" #include "Nodes/CircleIf.h" #include "Nodes/CircleInstanceNorm.h" diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h b/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h new file mode 100644 index 00000000000..920e53c1cc4 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_GRU_H__ +#define __LUCI_IMPORT_OP_CIRCLE_GRU_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleGRUGraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_GRU_H__ diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp index de8fba9e4d1..511a90196bb 100644 --- a/compiler/luci/import/src/GraphBuilderRegistry.cpp +++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp @@ -67,6 +67,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(GREATER, CircleGreaterGraphBuilder); // 61 CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualGraphBuilder); // 62 CIRCLE_NODE(HARD_SWISH, CircleHardSwishGraphBuilder); // 117 + CIRCLE_NODE(CIR_GRU, CircleGRUGraphBuilder); // 255 CIRCLE_NODE(IF, CircleIfGraphBuilder); // 118 CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormGraphBuilder); // 254 CIRCLE_NODE(L2_NORMALIZATION, CircleL2NormalizeGraphBuilder); // 11 diff --git a/compiler/luci/import/src/Nodes/CircleGRU.cpp b/compiler/luci/import/src/Nodes/CircleGRU.cpp new file mode 100644 index 00000000000..6c7ba11baa9 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleGRU.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleGRU.h" + +#include + +#include + +namespace luci +{ + +bool CircleGRUGraphBuilder::validate(const ValidateArgs &args) const +{ + return GraphBuilder::validate(args, 4); +} + +CircleNode *CircleGRUGraphBuilder::build_node(const circle::OperatorT &, + const std::vector &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create(); + node->input(inputs.at(0)); + node->hidden_hidden(inputs.at(1)); + node->hidden_input(inputs.at(2)); + node->state(inputs.at(3)); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h index fe1e8ea6ba7..d8c18702d07 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h @@ -139,6 +139,7 @@ #include "Nodes/CircleBCQFullyConnected.h" #include "Nodes/CircleBCQGather.h" #include "Nodes/CircleInstanceNorm.h" +#include "Nodes/CircleGRU.h" // Virtual nodes #include "Nodes/CircleConst.h" #include "Nodes/CircleInput.h" diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst index 08d376cbd1c..4579e4bd5af 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst @@ -52,6 +52,7 @@ CIRCLE_NODE(GATHER_ND, CircleGatherNd) CIRCLE_NODE(GELU, CircleGelu) CIRCLE_NODE(GREATER, CircleGreater) CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqual) +CIRCLE_NODE(CIR_GRU, CircleGRU) CIRCLE_NODE(HARD_SWISH, CircleHardSwish) CIRCLE_NODE(IF, CircleIf) CIRCLE_NODE(L2_NORMALIZATION, CircleL2Normalize) diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h new file mode 100644 index 00000000000..3c1266f7869 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLEGRU_H__ +#define __LUCI_IR_CIRCLEGRU_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief GRU in Circle + */ +class CircleGRU final : public FixedArityNode<4, CircleNodeImpl> +{ +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *hidden_hidden(void) const { return at(1)->node(); } + void hidden_hidden(loco::Node *node) { at(1)->node(node); } + + loco::Node *hidden_input(void) const { return at(2)->node(); } + void hidden_input(loco::Node *node) { at(2)->node(node); } + + loco::Node *state(void) const { return at(3)->node(); } + void state(loco::Node *node) { at(3)->node(node); } + +public: + FusedActFunc fusedActivationFunction() const { return _fused_act_fun; } + void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; } + + bool returnSequences() const { return _return_sequences; } + void returnSequences(bool return_sequences) { _return_sequences = return_sequences; } + + bool timeMajor() const { return _time_major; } + void timeMajor(bool time_major) { _time_major = time_major; } + +private: + FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED; + bool _return_sequences = false; + bool _time_major = false; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLEGRU_H__ diff --git a/compiler/luci/lang/src/Nodes/CircleGRU.test.cpp b/compiler/luci/lang/src/Nodes/CircleGRU.test.cpp new file mode 100644 index 00000000000..4f0709c5ef8 --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleGRU.test.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/Nodes/CircleGRU.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include + +TEST(CircleGRUTest, constructor_P) +{ + luci::CircleGRU gru_node; + + ASSERT_EQ(luci::CircleDialect::get(), gru_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::CIR_GRU, gru_node.opcode()); + + ASSERT_EQ(nullptr, gru_node.input()); + ASSERT_EQ(nullptr, gru_node.hidden_hidden()); + ASSERT_EQ(nullptr, gru_node.hidden_input()); + ASSERT_EQ(nullptr, gru_node.state()); +} + +TEST(CircleGRUTest, input_NEG) +{ + luci::CircleGRU gru_node; + luci::CircleGRU node; + + gru_node.input(&node); + ASSERT_NE(nullptr, gru_node.input()); + + gru_node.input(nullptr); + ASSERT_EQ(nullptr, gru_node.input()); +} + +TEST(CircleGRUTest, arity_NEG) +{ + luci::CircleGRU gru_node; + + ASSERT_NO_THROW(gru_node.arg(0)); + ASSERT_NO_THROW(gru_node.arg(1)); + ASSERT_NO_THROW(gru_node.arg(2)); + ASSERT_NO_THROW(gru_node.arg(3)); + ASSERT_THROW(gru_node.arg(5), std::out_of_range); +} + +TEST(CircleGRUTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor + { + }; + + luci::CircleGRU gru_node; + + TestVisitor tv; + ASSERT_THROW(gru_node.accept(&tv), std::exception); +} + +TEST(CircleGRUTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor + { + }; + + luci::CircleGRU gru_node; + + TestVisitor tv; + ASSERT_THROW(gru_node.accept(&tv), std::exception); +} diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp index e1be9dca9c8..88be4bdcc27 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp @@ -173,6 +173,7 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node) CIRCLE_NODE(GELU, CircleGeluSummaryBuilder) CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder) CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder) + CIRCLE_NODE(CIR_GRU, CircleGRUSummaryBuilder) CIRCLE_NODE(HARD_SWISH, CircleHardSwishSummaryBuilder) CIRCLE_NODE(IF, CircleIfSummaryBuilder) CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder) diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp index ffe93e63693..06865404237 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp @@ -343,6 +343,19 @@ void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci: s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs())); } +std::vector CircleGRUSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "hidden_hidden", "hidden_input", "state"}; +} + +void CircleGRUSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s) +{ + auto gru = loco::must_cast(node); + s.args().append("fused_act_function", to_str(gru->fusedActivationFunction())); + s.args().append("return_sequence", to_str(gru->returnSequences())); + s.args().append("time_major", to_str(gru->timeMajor())); +} + std::vector CircleBroadcastToSummaryBuilder::get_input_names(const luci::CircleNode *) { return {"input", "shape"}; diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h index d53ee96234f..17e5dcc7d34 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h @@ -307,6 +307,13 @@ class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBui { }; +class CircleGRUSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + class CircleHardSwishSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder { }; diff --git a/compiler/luci/partition/include/luci/ConnectNode.h b/compiler/luci/partition/include/luci/ConnectNode.h index 5f14d9fc48b..65324be7188 100644 --- a/compiler/luci/partition/include/luci/ConnectNode.h +++ b/compiler/luci/partition/include/luci/ConnectNode.h @@ -185,6 +185,7 @@ class ConnectNode final : public luci::CircleNodeVisitor void visit(const luci::CircleBCQFullyConnected *) final; void visit(const luci::CircleBCQGather *) final; void visit(const luci::CircleInstanceNorm *) final; + void visit(const luci::CircleGRU *) final; // NOTE CircleInput and CircleOutput are not handled here as these need // link with graph I/O diff --git a/compiler/luci/partition/src/Nodes/CircleGRU.cpp b/compiler/luci/partition/src/Nodes/CircleGRU.cpp new file mode 100644 index 00000000000..6ce726e4f5f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGRU.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleGRU *node) +{ + auto *cloned = loco::must_cast(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast(node->input()); + luci::CircleNode *hidden_input = loco::must_cast(node->hidden_input()); + luci::CircleNode *hidden_hidden = loco::must_cast(node->hidden_hidden()); + luci::CircleNode *state = loco::must_cast(node->state()); + + cloned->input(cn->find_clone(input)); + cloned->hidden_input(cn->find_clone(hidden_input)); + cloned->hidden_hidden(cn->find_clone(hidden_hidden)); + cloned->state(cn->find_clone(state)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleGRU *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp b/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp new file mode 100644 index 00000000000..f9238d80f42 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGRU.test.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/ConnectNode.h" + +#include "ConnectNode.test.h" + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::TANH); + } +}; + +class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<4>::init({shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->hidden_hidden(input(1)); + node()->hidden_input(input(2)); + node()->state(input(3)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_CIRCLE_GRU) +{ + TestNodeGraph tng; + tng.init({10, 1, 4}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(4, clone->arity()); + // 24 separate checks is too much + for (uint32_t i = 0; i < 4; ++i) + ASSERT_EQ(cth.inputs(i), clone->arg(i)); +} + +TEST(ConnectNodeTest, connect_CIRCLE_GRU_NEG) +{ + TestNodeGraph tng; + tng.init({10, 1, 4}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index 92c5fb04cce..887234cde5a 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -164,6 +164,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final; // loco::TensorShape visit(const luci::CircleBCQGather *node) final; // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final; + // loco::TensorShape visit(const luci::CircleGRU *node) final; // Virtual // loco::TensorShape visit(const luci::CircleCustomOut *node) final; diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h index 4f4ab0f34ab..f839a935d91 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h @@ -163,6 +163,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::DataType visit(const luci::CircleBCQFullyConnected *node) final; // loco::DataType visit(const luci::CircleBCQGather *node) final; // loco::DataType visit(const luci::CircleInstanceNorm *node) final; + // loco::DataType visit(const luci::CircleGRU *node) final; // Virtual // loco::DataType visit(const luci::CircleInput *node) final; diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h index 66ebb2dd8a6..c5b058e0565 100644 --- a/compiler/luci/service/src/CircleCloneNode.h +++ b/compiler/luci/service/src/CircleCloneNode.h @@ -257,6 +257,7 @@ class CloneNode final : public luci::CircleNodeVisitor luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final; luci::CircleNode *visit(const luci::CircleBCQGather *) final; luci::CircleNode *visit(const luci::CircleInstanceNorm *) final; + luci::CircleNode *visit(const luci::CircleGRU *) final; // NOTE CircleInput and CircleOutput are not handled here as these need // link with graph I/O diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 18cf410ac06..d509a3c6fb3 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -1742,6 +1742,27 @@ loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node) return loco::NodeShape{output_shape}; } +loco::NodeShape infer_circle_gru(const luci::CircleGRU *node) +{ + loco::TensorShape output_shape; + + const auto input_shape = luci::shape_get(node->input()).as(); + const auto state_shape = luci::shape_get(node->state()).as(); + + auto rank = input_shape.rank(); + output_shape.rank(rank); + for (uint32_t i = 0; i < rank - 1; i++) + { + output_shape.dim(i) = input_shape.dim(i); + } + output_shape.dim(rank - 1) = state_shape.dim(1); + + if (not node->returnSequences()) + output_shape.dim(0) = 1; + + return loco::NodeShape{output_shape}; +} + // Virtual loco::NodeShape infer_input(const luci::CircleInput *node) { @@ -2478,6 +2499,8 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitorfeatures()); } + loco::DataType visit(const luci::CircleGRU *node) final { return luci::dtype_get(node->input()); } + loco::DataType visit(const luci::CircleIf *node) final { // Type of If is not used. Just use input 0 diff --git a/compiler/luci/service/src/Nodes/CircleGRU.cpp b/compiler/luci/service/src/Nodes/CircleGRU.cpp new file mode 100644 index 00000000000..c72e1852868 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGRU.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleGRU *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->returnSequences(node->returnSequences()); + cloned->timeMajor(node->timeMajor()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleGRU.test.cpp b/compiler/luci/service/src/Nodes/CircleGRU.test.cpp new file mode 100644 index 00000000000..5f0281c7721 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleGRU.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include +#include +#include + +#include + +#include + +TEST(ShapeRuleTest, simple_circle_gru) +{ + luci::CircleInput input; + luci::CircleConst hidden_hidden; + luci::CircleConst hidden_input; + luci::CircleConst state; + luci::CircleGRU circle_gru; + + input.shape({10, 1, 4}); + input.shape_status(luci::ShapeStatus::VALID); + + hidden_hidden.shape({7, 32}); + hidden_hidden.shape_status(luci::ShapeStatus::VALID); + + hidden_input.shape({7, 4}); + hidden_input.shape_status(luci::ShapeStatus::VALID); + + state.shape({1, 32}); + state.shape_status(luci::ShapeStatus::VALID); + + circle_gru.input(&input); + circle_gru.hidden_hidden(&hidden_hidden); + circle_gru.hidden_input(&hidden_input); + circle_gru.state(&state); + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(&circle_gru, shape)); + ASSERT_EQ(3, shape.rank()); + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(1, shape.dim(1).value()); + ASSERT_EQ(32, shape.dim(2).value()); +} + +TEST(DataTypeRuleTest, simple_circle_gru) +{ + luci::CircleInput input; + luci::CircleConst hidden_hidden; + luci::CircleConst hidden_input; + luci::CircleConst state; + luci::CircleGRU circle_gru; + + input.dtype(loco::DataType::FLOAT32); + hidden_hidden.dtype(loco::DataType::FLOAT32); + hidden_input.dtype(loco::DataType::FLOAT32); + state.dtype(loco::DataType::FLOAT32); + + circle_gru.input(&input); + circle_gru.hidden_hidden(&hidden_hidden); + circle_gru.hidden_input(&hidden_input); + circle_gru.state(&state); + + loco::DataType dtype; + luci::tinf::Rule type_inf_rule; + + ASSERT_TRUE(type_inf_rule.infer(&circle_gru, dtype)); + ASSERT_EQ(loco::DataType::FLOAT32, dtype); +} + +TEST(CloneNodeTest, clone_circel_gru) +{ + auto g = loco::make_graph(); + auto node_circle_gru = g->nodes()->create(); + node_circle_gru->fusedActivationFunction(luci::FusedActFunc::TANH); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_circle_gru, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_circle_gru = dynamic_cast(cloned); + ASSERT_NE(nullptr, cloned_circle_gru); +} diff --git a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst index cd656576791..f041a74f603 100644 --- a/compiler/luci/tests/test.lst +++ b/compiler/luci/tests/test.lst @@ -229,6 +229,7 @@ addread(BCQFullyConnected_001) addread(BCQGather_000) addread(CircleBatchMatMul_000) addread(InstanceNorm_000) +addread(CircleGRU_000) addwrite(Abs_000) addwrite(Add_000) @@ -460,3 +461,4 @@ addwrite(BCQFullyConnected_001) addwrite(BCQGather_000) addwrite(CircleBatchMatMul_000) addwrite(InstanceNorm_000) +addwrite(CircleGRU_000) diff --git a/res/CircleRecipes/CircleGRU_000/test.recipe b/res/CircleRecipes/CircleGRU_000/test.recipe new file mode 100644 index 00000000000..b84c3704129 --- /dev/null +++ b/res/CircleRecipes/CircleGRU_000/test.recipe @@ -0,0 +1,55 @@ +operand { + name: "ifm" + type: FLOAT32 + shape { dim: 3 dim: 1 dim: 2 } +} +operand { + name: "hidden_hidden" + type: FLOAT32 + shape { dim: 12 dim: 4 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "hidden_input" + type: FLOAT32 + shape { dim: 12 dim: 2 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "state" + type: FLOAT32 + shape { dim: 1 dim: 4 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "ofm" + type: FLOAT32 + shape { dim: 1 dim: 1 dim: 4 } +} +operation { + type: "CircleGRU" + circle_gru_options { + activation: TANH + return_sequences: false + time_major: false + } + input: "ifm" + input: "hidden_hidden" + input: "hidden_input" + input: "state" + output: "ofm" +} +input: "ifm" +output: "ofm" diff --git a/res/CircleRecipes/CircleGRU_000/test.reverse b/res/CircleRecipes/CircleGRU_000/test.reverse new file mode 100644 index 00000000000..9dab36e2094 --- /dev/null +++ b/res/CircleRecipes/CircleGRU_000/test.reverse @@ -0,0 +1,17 @@ +operand { + name: "ifm" + type: FLOAT32 + shape { dim: 1 dim: 3 dim: 3 dim: 2 } +} +operand { + name: "ofm" + type: FLOAT32 + shape { dim: 1 dim: 3 dim: 3 dim: 2 } +} +operation { + type: "HardSwish" + input: "ifm" + output: "ofm" +} +input: "ifm" +output: "ofm" diff --git a/res/CircleSchema/0.7/circle_schema.fbs b/res/CircleSchema/0.7/circle_schema.fbs index c132c89c829..58af51bff7a 100644 --- a/res/CircleSchema/0.7/circle_schema.fbs +++ b/res/CircleSchema/0.7/circle_schema.fbs @@ -268,6 +268,7 @@ table Tensor { // set of acceptable options. // LINT.IfChange enum BuiltinOperator : int32 { + CIR_GRU = -5, BCQ_GATHER = -4, BCQ_FULLY_CONNECTED = -3, INSTANCE_NORM = -2, @@ -619,6 +620,7 @@ union BuiltinOptions { BCQGatherOptions = 252, BCQFullyConnectedOptions = 253, InstanceNormOptions = 254, + CircleGRUOptions = 255, } union BuiltinOptions2{ @@ -1431,6 +1433,12 @@ table GeluOptions { approximate: bool; } +table CircleGRUOptions { + fused_activation_function:ActivationFunctionType; + return_sequences : bool; + time_major : bool; +} + table DynamicUpdateSliceOptions { }