Skip to content

Commit

Permalink
[DRAFT] Support CircleGRU
Browse files Browse the repository at this point in the history
This draft supports CircleGRU in compiler.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Dec 18, 2023
1 parent e2df8b3 commit 586baac
Show file tree
Hide file tree
Showing 27 changed files with 655 additions and 0 deletions.
7 changes: 7 additions & 0 deletions compiler/luci/export/src/CircleBuiltinTypesExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,13 @@ class BuiltinOptionsExtractor final
to_circle_actfunc(node->fusedActivationFunction()))
.Union();
}
flatbuffers::Offset<void> visit(luci::CircleGRU *node)
{
return circle::CreateCircleGRUOptions(_builder,
to_circle_actfunc(node->fusedActivationFunction()),
node->returnSequences(), node->timeMajor())
.Union();
}

protected:
flatbuffers::FlatBufferBuilder &_builder;
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/export/src/CircleOps.lst
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLik
CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions)
CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions)
CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions)
CIRCLE_NODE(CircleGRU, BuiltinOperator_CIR_GRU, BuiltinOptions_CircleGRUOptions)
// Virtual node(s)
CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut)
CIRCLE_VNODE(CircleConst)
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/import/include/luci/Import/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "Nodes/CircleGelu.h"
#include "Nodes/CircleGreater.h"
#include "Nodes/CircleGreaterEqual.h"
#include "Nodes/CircleGRU.h"
#include "Nodes/CircleHardSwish.h"
#include "Nodes/CircleIf.h"
#include "Nodes/CircleInstanceNorm.h"
Expand Down
37 changes: 37 additions & 0 deletions compiler/luci/import/include/luci/Import/Nodes/CircleGRU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_IMPORT_OP_CIRCLE_GRU_H__
#define __LUCI_IMPORT_OP_CIRCLE_GRU_H__

#include "luci/Import/GraphBuilder.h"

namespace luci
{

class CircleGRUGraphBuilder : public GraphBuilder
{
public:
bool validate(const ValidateArgs &args) const final;

private:
CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs,
loco::Graph *graph) const final;
};

} // namespace luci

#endif // __LUCI_IMPORT_OP_CIRCLE_GRU_H__
1 change: 1 addition & 0 deletions compiler/luci/import/src/GraphBuilderRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
CIRCLE_NODE(GREATER, CircleGreaterGraphBuilder); // 61
CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualGraphBuilder); // 62
CIRCLE_NODE(HARD_SWISH, CircleHardSwishGraphBuilder); // 117
CIRCLE_NODE(CIR_GRU, CircleGRUGraphBuilder); // 255
CIRCLE_NODE(IF, CircleIfGraphBuilder); // 118
CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormGraphBuilder); // 254
CIRCLE_NODE(L2_NORMALIZATION, CircleL2NormalizeGraphBuilder); // 11
Expand Down
44 changes: 44 additions & 0 deletions compiler/luci/import/src/Nodes/CircleGRU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Import/Nodes/CircleGRU.h"

#include <luci/IR/Nodes/CircleHardSwish.h>

#include <loco.h>

namespace luci
{

bool CircleGRUGraphBuilder::validate(const ValidateArgs &args) const
{
return GraphBuilder::validate(args, 4);
}

CircleNode *CircleGRUGraphBuilder::build_node(const circle::OperatorT &,
const std::vector<CircleNode *> &inputs,
loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleGRU>();
node->input(inputs.at(0));
node->hidden_hidden(inputs.at(1));
node->hidden_input(inputs.at(2));
node->state(inputs.at(3));

return node;
}

} // namespace luci
1 change: 1 addition & 0 deletions compiler/luci/lang/include/luci/IR/CircleNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
#include "Nodes/CircleBCQFullyConnected.h"
#include "Nodes/CircleBCQGather.h"
#include "Nodes/CircleInstanceNorm.h"
#include "Nodes/CircleGRU.h"
// Virtual nodes
#include "Nodes/CircleConst.h"
#include "Nodes/CircleInput.h"
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/lang/include/luci/IR/CircleNodes.lst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ CIRCLE_NODE(GATHER_ND, CircleGatherNd)
CIRCLE_NODE(GELU, CircleGelu)
CIRCLE_NODE(GREATER, CircleGreater)
CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqual)
CIRCLE_NODE(CIR_GRU, CircleGRU)
CIRCLE_NODE(HARD_SWISH, CircleHardSwish)
CIRCLE_NODE(IF, CircleIf)
CIRCLE_NODE(L2_NORMALIZATION, CircleL2Normalize)
Expand Down
64 changes: 64 additions & 0 deletions compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_IR_CIRCLEGRU_H__
#define __LUCI_IR_CIRCLEGRU_H__

#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"

#include "luci/IR/CircleNodeMixins.h"

namespace luci
{

/**
* @brief GRU in Circle
*/
class CircleGRU final : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::CIR_GRU>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }

loco::Node *hidden_hidden(void) const { return at(1)->node(); }
void hidden_hidden(loco::Node *node) { at(1)->node(node); }

loco::Node *hidden_input(void) const { return at(2)->node(); }
void hidden_input(loco::Node *node) { at(2)->node(node); }

loco::Node *state(void) const { return at(3)->node(); }
void state(loco::Node *node) { at(3)->node(node); }

public:
FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }

bool returnSequences() const { return _return_sequences; }
void returnSequences(bool return_sequences) { _return_sequences = return_sequences; }

bool timeMajor() const { return _time_major; }
void timeMajor(bool time_major) { _time_major = time_major; }

private:
FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
bool _return_sequences = false;
bool _time_major = false;
};

} // namespace luci

#endif // __LUCI_IR_CIRCLEGRU_H__
82 changes: 82 additions & 0 deletions compiler/luci/lang/src/Nodes/CircleGRU.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/IR/Nodes/CircleGRU.h"

#include "luci/IR/CircleDialect.h"
#include "luci/IR/CircleNodeVisitor.h"

#include <gtest/gtest.h>

TEST(CircleGRUTest, constructor_P)
{
luci::CircleGRU gru_node;

ASSERT_EQ(luci::CircleDialect::get(), gru_node.dialect());
ASSERT_EQ(luci::CircleOpcode::CIR_GRU, gru_node.opcode());

ASSERT_EQ(nullptr, gru_node.input());
ASSERT_EQ(nullptr, gru_node.hidden_hidden());
ASSERT_EQ(nullptr, gru_node.hidden_input());
ASSERT_EQ(nullptr, gru_node.state());
}

TEST(CircleGRUTest, input_NEG)
{
luci::CircleGRU gru_node;
luci::CircleGRU node;

gru_node.input(&node);
ASSERT_NE(nullptr, gru_node.input());

gru_node.input(nullptr);
ASSERT_EQ(nullptr, gru_node.input());
}

TEST(CircleGRUTest, arity_NEG)
{
luci::CircleGRU gru_node;

ASSERT_NO_THROW(gru_node.arg(0));
ASSERT_NO_THROW(gru_node.arg(1));
ASSERT_NO_THROW(gru_node.arg(2));
ASSERT_NO_THROW(gru_node.arg(3));
ASSERT_THROW(gru_node.arg(5), std::out_of_range);
}

TEST(CircleGRUTest, visit_mutable_NEG)
{
struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
{
};

luci::CircleGRU gru_node;

TestVisitor tv;
ASSERT_THROW(gru_node.accept(&tv), std::exception);
}

TEST(CircleGRUTest, visit_NEG)
{
struct TestVisitor final : public luci::CircleNodeVisitor<void>
{
};

luci::CircleGRU gru_node;

TestVisitor tv;
ASSERT_THROW(gru_node.accept(&tv), std::exception);
}
1 change: 1 addition & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node)
CIRCLE_NODE(GELU, CircleGeluSummaryBuilder)
CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder)
CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder)
CIRCLE_NODE(CIR_GRU, CircleGRUSummaryBuilder)
CIRCLE_NODE(HARD_SWISH, CircleHardSwishSummaryBuilder)
CIRCLE_NODE(IF, CircleIfSummaryBuilder)
CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder)
Expand Down
13 changes: 13 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,19 @@ void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci:
s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs()));
}

std::vector<std::string> CircleGRUSummaryBuilder::get_input_names(const luci::CircleNode *)
{
return {"input", "hidden_hidden", "hidden_input", "state"};
}

void CircleGRUSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
{
auto gru = loco::must_cast<const luci::CircleGRU *>(node);
s.args().append("fused_act_function", to_str(gru->fusedActivationFunction()));
s.args().append("return_sequence", to_str(gru->returnSequences()));
s.args().append("time_major", to_str(gru->timeMajor()));
}

std::vector<std::string> CircleBroadcastToSummaryBuilder::get_input_names(const luci::CircleNode *)
{
return {"input", "shape"};
Expand Down
7 changes: 7 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilders.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBui
{
};

class CircleGRUSummaryBuilder final : public CircleNodeSummaryBuilder
{
private:
std::vector<std::string> get_input_names(const luci::CircleNode *);
void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
};

class CircleHardSwishSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
{
};
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/partition/include/luci/ConnectNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class ConnectNode final : public luci::CircleNodeVisitor<void>
void visit(const luci::CircleBCQFullyConnected *) final;
void visit(const luci::CircleBCQGather *) final;
void visit(const luci::CircleInstanceNorm *) final;
void visit(const luci::CircleGRU *) final;

// NOTE CircleInput and CircleOutput are not handled here as these need
// link with graph I/O
Expand Down
44 changes: 44 additions & 0 deletions compiler/luci/partition/src/Nodes/CircleGRU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/ConnectNode.h"

namespace
{

void connect(luci::ConnectNode *cn, const luci::CircleGRU *node)
{
auto *cloned = loco::must_cast<luci::CircleGRU *>(cn->find_clone(node));

luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
luci::CircleNode *hidden_input = loco::must_cast<luci::CircleNode *>(node->hidden_input());
luci::CircleNode *hidden_hidden = loco::must_cast<luci::CircleNode *>(node->hidden_hidden());
luci::CircleNode *state = loco::must_cast<luci::CircleNode *>(node->state());

cloned->input(cn->find_clone(input));
cloned->hidden_input(cn->find_clone(hidden_input));
cloned->hidden_hidden(cn->find_clone(hidden_hidden));
cloned->state(cn->find_clone(state));
}

} // namespace

namespace luci
{

void ConnectNode::visit(const luci::CircleGRU *node) { connect(this, node); }

} // namespace luci
Loading

0 comments on commit 586baac

Please sign in to comment.