Skip to content

Commit

Permalink
fix review comments: rename CircleGRU -> CirGru. Remove unnecessary t…
Browse files Browse the repository at this point in the history
…ypes in FuseCirGruPass.
  • Loading branch information
Artem Balyshev committed Jan 16, 2024
1 parent 6ee3e3e commit 4ed032d
Show file tree
Hide file tree
Showing 46 changed files with 125 additions and 176 deletions.
4 changes: 2 additions & 2 deletions compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ int entry(int argc, char **argv)
"This will fuse BatchNorm operators of pre-activations to Convolution operator");
add_switch(arser, "--fuse_prelu", "This will fuse operators to PReLU operator");
add_switch(arser, "--fuse_gelu", "This will fuse operators to GeLU operator");
add_switch(arser, "--fuse_gru", "This will fuse operators to GRU operator");
add_switch(arser, "--fuse_gru", "This will fuse operators to CirGru operator");
add_switch(arser, "--remove_duplicate_const", "This will remove all duplicate constant nodes");
add_switch(arser, "--remove_fakequant", "This will remove FakeQuant operators");
add_switch(arser, "--remove_quantdequant", "This will remove Quantize-Dequantize sequence");
Expand Down Expand Up @@ -305,7 +305,7 @@ int entry(int argc, char **argv)
if (arser.get<bool>("--fuse_gelu"))
options->enable(Algorithms::FuseGelu);
if (arser.get<bool>("--fuse_gru"))
options->enable(Algorithms::FuseGRU);
options->enable(Algorithms::FuseCirGru);
if (arser.get<bool>("--fuse_transpose_with_mean"))
options->enable(Algorithms::FuseTransposeWithMean);
if (arser.get<bool>("--remove_duplicate_const"))
Expand Down
2 changes: 1 addition & 1 deletion compiler/circlechef/circle/src/CircleOpChefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "Op/BatchMatMul.h"
#include "Op/BCQFullyConnected.h"
#include "Op/BCQGather.h"
#include "Op/CircleGRU.h"
#include "Op/CirGru.h"
#include "Op/InstanceNorm.h"

#endif // __CIRCLE_OP_CHEFS_H__
2 changes: 1 addition & 1 deletion compiler/circlechef/circle/src/CircleOpRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CircleOpRegistry
REG_TFL_OP(BATCH_MATMUL, CircleOpBatchMatMul);
REG_TFL_OP(BCQ_FULLY_CONNECTED, CircleOpBCQFullyConnected);
REG_TFL_OP(BCQ_GATHER, CircleOpBCQGather);
REG_TFL_OP(CIR_GRU, CircleOpCircleGRU);
REG_TFL_OP(CIR_GRU, CircleOpCirGru);
REG_TFL_OP(INSTANCE_NORM, CircleOpInstanceNorm);
#undef REG_TFL_OP
}
Expand Down
3 changes: 1 addition & 2 deletions compiler/circlechef/circle/src/Convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ circlechef::Activation as_circlechef_activation(const circle::ActivationFunction
return circlechef::RELU;
case circle::ActivationFunctionType_RELU6:
return circlechef::RELU6;
case circle::ActivationFunctionType_TANH:
return circlechef::TANH;
// TODO handle other types
// ActivationFunctionType_RELU_N1_TO_1
// ActivationFunctionType_TANH
// ActivationFunctionType_SIGN_BIT
default:
throw std::runtime_error{"unsupported activation type"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
* limitations under the License.
*/

#include "CircleGRU.h"
#include "CirGru.h"

#include "Convert.h"

namespace circlechef
{

void CircleOpCircleGRU::filler(const circle::Operator *op, CircleImport *import,
circlechef::ModelRecipe *model_recipe) const
void CircleOpCirGru::filler(const circle::Operator *op, CircleImport *import,
circlechef::ModelRecipe *model_recipe) const
{
// index 1, 2, 3, 4, 5 maybe constant
const std::vector<int32_t> &inputs = as_index_vector(op->inputs());
Expand All @@ -35,15 +35,15 @@ void CircleOpCircleGRU::filler(const circle::Operator *op, CircleImport *import,
import->set_tensor_filler(inputs[5]);
}

circlechef::Operation *CircleOpCircleGRU::build(const circle::Operator *op, CircleImport *import,
circlechef::ModelRecipe *model_recipe) const
circlechef::Operation *CircleOpCirGru::build(const circle::Operator *op, CircleImport *import,
circlechef::ModelRecipe *model_recipe) const
{
auto op_params = op->builtin_options_as_CircleGRUOptions();
auto op_params = op->builtin_options_as_CirGruOptions();
assert(op_params != nullptr);

auto operation = model_recipe->add_operation();

operation->set_type("CircleGRU");
operation->set_type("CirGru");

auto op_options = operation->mutable_circle_gru_options();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ namespace circlechef
{

/**
* @brief circlechef operator builder for CircleGRU
* @brief circlechef operator builder for CirGru
*/
class CircleOpCircleGRU : public CircleOpChef
class CircleOpCirGru : public CircleOpChef
{
public:
void filler(const circle::Operator *op, CircleImport *import,
Expand Down
2 changes: 0 additions & 2 deletions compiler/circlechef/core/src/Convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ circle::ActivationFunctionType as_circle_activation(const circlechef::Activation
return circle::ActivationFunctionType_RELU;
case circlechef::RELU6:
return circle::ActivationFunctionType_RELU6;
case circlechef::TANH:
return circle::ActivationFunctionType_TANH;
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
* limitations under the License.
*/

#include "CircleGRU.h"
#include "CirGru.h"

#include "Convert.h"

flatbuffers::Offset<void> CircleGRUChef::value(flatbuffers::FlatBufferBuilder &fbb) const
flatbuffers::Offset<void> CirGruChef::value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

Expand All @@ -27,15 +27,15 @@ flatbuffers::Offset<void> CircleGRUChef::value(flatbuffers::FlatBufferBuilder &f
auto return_sequences = operation.circle_gru_options().return_sequences();
auto time_major = operation.circle_gru_options().time_major();

circle::CircleGRUOptionsBuilder options_builder{fbb};
circle::CirGruOptionsBuilder options_builder{fbb};
options_builder.add_fused_activation_function(circle_activation);
options_builder.add_return_sequences(return_sequences);
options_builder.add_time_major(time_major);

return options_builder.Finish().Union();
}

std::unique_ptr<OpChef> CircleGRUChefFactory::create(const circlechef::Operation *operation) const
std::unique_ptr<OpChef> CirGruChefFactory::create(const circlechef::Operation *operation) const
{
return std::unique_ptr<OpChef>{new CircleGRUChef{operation}};
return std::unique_ptr<OpChef>{new CirGruChef{operation}};
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,26 @@

#include "OpChef.h"

class CircleGRUChef final : public OpChef
class CirGruChef final : public OpChef
{
public:
explicit CircleGRUChef(const circlechef::Operation *operation) : _operation{operation}
explicit CirGruChef(const circlechef::Operation *operation) : _operation{operation}
{
// DO NOTHING
}

public:
circle::BuiltinOperator code(void) const override { return circle::BuiltinOperator_CIR_GRU; }

circle::BuiltinOptions type(void) const override
{
return circle::BuiltinOptions_CircleGRUOptions;
}
circle::BuiltinOptions type(void) const override { return circle::BuiltinOptions_CirGruOptions; }

flatbuffers::Offset<void> value(flatbuffers::FlatBufferBuilder &fbb) const override;

private:
const circlechef::Operation *_operation;
};

struct CircleGRUChefFactory final : public OpChefFactory
struct CirGruChefFactory final : public OpChefFactory
{
std::unique_ptr<OpChef> create(const circlechef::Operation *operation) const override;
};
Expand Down
2 changes: 1 addition & 1 deletion compiler/circlechef/core/src/OpChef.def
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
OP_CHEF(BatchMatMul, BatchMatMulChefFactory)
OP_CHEF(BCQFullyConnected, BCQFullyConnectedChefFactory)
OP_CHEF(BCQGather, BCQGatherChefFactory)
OP_CHEF(CircleGRU, CircleGRUChefFactory)
OP_CHEF(CirGru, CirGruChefFactory)
OP_CHEF(InstanceNorm, InstanceNormChefFactory)
2 changes: 1 addition & 1 deletion compiler/circlechef/core/src/OpChefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "Op/BatchMatMul.h"
#include "Op/BCQFullyConnected.h"
#include "Op/BCQGather.h"
#include "Op/CircleGRU.h"
#include "Op/CirGru.h"
#include "Op/InstanceNorm.h"

#endif // __OP_CHEFS_H__
5 changes: 2 additions & 3 deletions compiler/circlechef/proto/circlechef.proto
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ enum Activation {
NONE = 0;
RELU = 1;
RELU6 = 3;
TANH = 4;
}

message BatchMatMulOptions {
Expand All @@ -77,7 +76,7 @@ message InstanceNormOptions {
optional Activation activation = 2 [default = NONE];
}

message CircleGRUOptions {
message CirGruOptions {
optional Activation activation = 1 [default = NONE];
optional bool return_sequences = 2 [default = false];
optional bool time_major = 3 [default = false];
Expand All @@ -104,7 +103,7 @@ message Operation {
optional InstanceNormOptions instance_norm_options = 101;
optional BCQFullyConnectedOptions bcq_fully_connected_options = 102;
optional BCQGatherOptions bcq_gather_options = 103;
optional CircleGRUOptions circle_gru_options = 104;
optional CirGruOptions circle_gru_options = 104;
}

// For additional subgraphs
Expand Down
6 changes: 3 additions & 3 deletions compiler/circledump/src/OpPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,12 +807,12 @@ class InstanceNormPrinter : public OpPrinter
}
};

class CircleGRUPrinter : public OpPrinter
class CirGruPrinter : public OpPrinter
{
public:
void options(const circle::Operator *op, std::ostream &os) const override
{
if (auto *params = op->builtin_options_as_CircleGRUOptions())
if (auto *params = op->builtin_options_as_CirGruOptions())
{
os << " ";
os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
Expand Down Expand Up @@ -911,7 +911,7 @@ OpPrinterRegistry::OpPrinterRegistry()
_op_map[circle::BuiltinOperator_BCQ_FULLY_CONNECTED] = make_unique<BCQFullyConnectedPrinter>();
_op_map[circle::BuiltinOperator_BCQ_GATHER] = make_unique<BCQGatherPrinter>();
_op_map[circle::BuiltinOperator_INSTANCE_NORM] = make_unique<InstanceNormPrinter>();
_op_map[circle::BuiltinOperator_CIR_GRU] = make_unique<CircleGRUPrinter>();
_op_map[circle::BuiltinOperator_CIR_GRU] = make_unique<CirGruPrinter>();
}

} // namespace circledump
2 changes: 1 addition & 1 deletion compiler/common-artifacts/exclude.lst
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,6 @@ tcgenerate(ZerosLike_000)
tcgenerate(BCQFullyConnected_000)
tcgenerate(BCQFullyConnected_001)
tcgenerate(BCQGather_000)
tcgenerate(CircleGRU_000) # luci-interpreter does not support custom CircleGRU
tcgenerate(CirGru_000) # luci-interpreter does not support custom CirGru
tcgenerate(InstanceNorm_000)
tcgenerate(InstanceNorm_001)
7 changes: 3 additions & 4 deletions compiler/luci/export/src/CircleBuiltinTypesExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,10 @@ class BuiltinOptionsExtractor final
to_circle_actfunc(node->fusedActivationFunction()))
.Union();
}
flatbuffers::Offset<void> visit(luci::CircleGRU *node)
flatbuffers::Offset<void> visit(luci::CircleCirGru *node)
{
return circle::CreateCircleGRUOptions(_builder,
to_circle_actfunc(node->fusedActivationFunction()),
node->returnSequences(), node->timeMajor())
return circle::CreateCirGruOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()),
node->returnSequences(), node->timeMajor())
.Union();
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/luci/export/src/CircleOps.lst
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLik
CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions)
CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions)
CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions)
CIRCLE_NODE(CircleGRU, BuiltinOperator_CIR_GRU, BuiltinOptions_CircleGRUOptions)
CIRCLE_NODE(CircleCirGru, BuiltinOperator_CIR_GRU, BuiltinOptions_CirGruOptions)
// Virtual node(s)
CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut)
CIRCLE_VNODE(CircleConst)
Expand Down
2 changes: 1 addition & 1 deletion compiler/luci/import/include/luci/Import/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
#include "Nodes/CircleGelu.h"
#include "Nodes/CircleGreater.h"
#include "Nodes/CircleGreaterEqual.h"
#include "Nodes/CircleGRU.h"
#include "Nodes/CircleCirGru.h"
#include "Nodes/CircleHardSwish.h"
#include "Nodes/CircleIf.h"
#include "Nodes/CircleInstanceNorm.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
namespace luci
{

class CircleGRUGraphBuilder : public GraphBuilder
class CircleCirGruGraphBuilder : public GraphBuilder
{
public:
bool validate(const ValidateArgs &args) const final;
Expand Down
2 changes: 1 addition & 1 deletion compiler/luci/import/src/GraphBuilderRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
CIRCLE_NODE(GREATER, CircleGreaterGraphBuilder); // 61
CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualGraphBuilder); // 62
CIRCLE_NODE(HARD_SWISH, CircleHardSwishGraphBuilder); // 117
CIRCLE_NODE(CIR_GRU, CircleGRUGraphBuilder); // 255
CIRCLE_NODE(CIR_GRU, CircleCirGruGraphBuilder); // 251
CIRCLE_NODE(IF, CircleIfGraphBuilder); // 118
CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormGraphBuilder); // 254
CIRCLE_NODE(L2_NORMALIZATION, CircleL2NormalizeGraphBuilder); // 11
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "luci/Import/Nodes/CircleGRU.h"
#include "luci/Import/Nodes/CircleCirGru.h"

#include <luci/IR/Nodes/CircleHardSwish.h>

Expand All @@ -23,16 +23,16 @@
namespace luci
{

bool CircleGRUGraphBuilder::validate(const ValidateArgs &args) const
bool CircleCirGruGraphBuilder::validate(const ValidateArgs &args) const
{
return GraphBuilder::validate(args, 6);
}

CircleNode *CircleGRUGraphBuilder::build_node(const circle::OperatorT &,
const std::vector<CircleNode *> &inputs,
loco::Graph *graph) const
CircleNode *CircleCirGruGraphBuilder::build_node(const circle::OperatorT &,
const std::vector<CircleNode *> &inputs,
loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleGRU>();
auto *node = graph->nodes()->create<CircleCirGru>();
node->input(inputs.at(0));
node->hidden_hidden(inputs.at(1));
node->hidden_hidden_bias(inputs.at(2));
Expand Down
2 changes: 1 addition & 1 deletion compiler/luci/lang/include/luci/IR/CircleNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
#include "Nodes/CircleBCQFullyConnected.h"
#include "Nodes/CircleBCQGather.h"
#include "Nodes/CircleInstanceNorm.h"
#include "Nodes/CircleGRU.h"
#include "Nodes/CircleCirGru.h"
// Virtual nodes
#include "Nodes/CircleConst.h"
#include "Nodes/CircleInput.h"
Expand Down
2 changes: 1 addition & 1 deletion compiler/luci/lang/include/luci/IR/CircleNodes.lst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ CIRCLE_NODE(GATHER_ND, CircleGatherNd)
CIRCLE_NODE(GELU, CircleGelu)
CIRCLE_NODE(GREATER, CircleGreater)
CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqual)
CIRCLE_NODE(CIR_GRU, CircleGRU)
CIRCLE_NODE(CIR_GRU, CircleCirGru)
CIRCLE_NODE(HARD_SWISH, CircleHardSwish)
CIRCLE_NODE(IF, CircleIf)
CIRCLE_NODE(L2_NORMALIZATION, CircleL2Normalize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef __LUCI_IR_CIRCLEGRU_H__
#define __LUCI_IR_CIRCLEGRU_H__
#ifndef __LUCI_IR_CIRCLE_CIR_GRU_H__
#define __LUCI_IR_CIRCLE_CIR_GRU_H__

#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
Expand All @@ -28,7 +28,7 @@ namespace luci
/**
* @brief GRU in Circle
*/
class CircleGRU final : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::CIR_GRU>>
class CircleCirGru final : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::CIR_GRU>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
Expand Down Expand Up @@ -67,4 +67,4 @@ class CircleGRU final : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::CI

} // namespace luci

#endif // __LUCI_IR_CIRCLEGRU_H__
#endif // __LUCI_IR_CIRCLE_CIR_GRU_H__
Loading

0 comments on commit 4ed032d

Please sign in to comment.