Skip to content

Commit

Permalink
[luci/import] Support GRU operation (#12745)
Browse files Browse the repository at this point in the history
This adds support for GRU operation in luci importer.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem authored Mar 13, 2024
1 parent 6765b4e commit fbcbd6f
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/import/include/luci/Import/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "Nodes/CircleGelu.h"
#include "Nodes/CircleGreater.h"
#include "Nodes/CircleGreaterEqual.h"
#include "Nodes/CircleGRU.h"
#include "Nodes/CircleHardSwish.h"
#include "Nodes/CircleIf.h"
#include "Nodes/CircleInstanceNorm.h"
Expand Down
37 changes: 37 additions & 0 deletions compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_IMPORT_OP_CIRCLE_GRU_H__
#define __LUCI_IMPORT_OP_CIRCLE_GRU_H__

#include "luci/Import/GraphBuilder.h"

namespace luci
{

class CircleGRUGraphBuilder : public GraphBuilder
{
public:
bool validate(const ValidateArgs &args) const final;

private:
CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs,
loco::Graph *graph) const final;
};

} // namespace luci

#endif // __LUCI_IMPORT_OP_CIRCLE_GRU_H__
1 change: 1 addition & 0 deletions compiler/luci/import/src/GraphBuilderRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
CIRCLE_NODE(GELU, CircleGeluGraphBuilder); // 150
CIRCLE_NODE(GREATER, CircleGreaterGraphBuilder); // 61
CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualGraphBuilder); // 62
CIRCLE_NODE(GRU, CircleGRUGraphBuilder); // 251
CIRCLE_NODE(HARD_SWISH, CircleHardSwishGraphBuilder); // 117
CIRCLE_NODE(IF, CircleIfGraphBuilder); // 118
CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormGraphBuilder); // 254
Expand Down
46 changes: 46 additions & 0 deletions compiler/luci/import/src/Nodes/CircleGRU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Import/Nodes/CircleGRU.h"

#include <luci/IR/Nodes/CircleHardSwish.h>

#include <loco.h>

namespace luci
{

bool CircleGRUGraphBuilder::validate(const ValidateArgs &args) const
{
return GraphBuilder::validate(args, 6);
}

CircleNode *CircleGRUGraphBuilder::build_node(const circle::OperatorT &,
const std::vector<CircleNode *> &inputs,
loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleGRU>();
node->input(inputs.at(0));
node->hidden_hidden(inputs.at(1));
node->hidden_hidden_bias(inputs.at(2));
node->hidden_input(inputs.at(3));
node->hidden_input_bias(inputs.at(4));
node->state(inputs.at(5));

return node;
}

} // namespace luci

0 comments on commit fbcbd6f

Please sign in to comment.