Skip to content

Commit

Permalink
correct one
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Balyshev committed Oct 31, 2023
1 parent 15d5e8b commit f5d27fa
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 159 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
namespace luci
{

/**
* @brief Class to fuse an unrolled GRU pattern into a single custom GRU operation
*/
struct FuseUnrolledGRUAsCustomGRUPass final : public logo::Pass
{
const char *name(void) const final { return "luci::FuseUnrolledGRUAsCustomGRUPass"; }
Expand Down
61 changes: 2 additions & 59 deletions compiler/luci/pass/src/FuseUnrolledGRUAsCustomGRU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,65 +90,6 @@ luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph
namespace
{

/**
* Fuse Transpose with Mean if possible
*
* BEFORE
* |
* [CircleTranspose, perm<0, 2, 3, 1>]
* |
* [CircleMean, axis<3>]
* |
*
* AFTER
* | |
* [CircleMean, axis<1>] [CircleTranspose, perm<0, 2, 3, 1>]
* | |
* [CircleMean, axis<3>]
*
*/

/**
 * @brief Build a constant that carries the fused reduction indices
 *
 * The result is a clone of @p rindices, renamed to "<original name>_fused",
 * whose S32 elements are overwritten one by one with @p fused_rindices.
 *
 * @param rindices       source constant; must not be null
 * @param fused_rindices replacement values; the count must equal the number
 *                       of S32 elements in @p rindices
 * @return the newly cloned constant, or nullptr when @p rindices is not S32
 */
luci::CircleConst *create_fused_indices(luci::CircleConst *rindices,
                                        const std::vector<uint32_t> &fused_rindices)
{
  assert(rindices != nullptr); // FIX_CALLER_UNLESS

  // Only S32 index constants are handled here
  const auto index_dtype = rindices->dtype();
  if (index_dtype != loco::DataType::S32)
    return nullptr;

  assert(fused_rindices.size() == rindices->size<loco::DataType::S32>());

  auto cloned = luci::clone(rindices);

  // Derive the clone's name from the source constant's name
  auto base_name = rindices->name();
  assert(base_name.length() > 0); // FIX_CALLER_UNLESS
  cloned->name(base_name + "_fused");

  // Copy every fused index into the matching slot of the clone
  uint32_t slot = 0;
  for (const auto fused_index : fused_rindices)
  {
    cloned->at<loco::DataType::S32>(slot) = fused_index;
    ++slot;
  }

  return cloned;
}

/**
 * @brief Tell whether any element of an S32 constant equals the given value
 *
 * @param circle_const constant to scan
 * @param value        value to look for
 * @return true if at least one S32 element equals @p value; false when no
 *         element matches or when the constant's dtype is not S32
 */
bool const_has_value_s32(const luci::CircleConst *circle_const, int32_t value)
{
  // Constants of any other dtype never match
  const auto dtype = circle_const->dtype();
  if (dtype != loco::DataType::S32)
    return false;

  const uint32_t count = circle_const->size<loco::DataType::S32>();
  uint32_t idx = 0;
  while (idx < count)
  {
    const auto element = circle_const->at<loco::DataType::S32>(idx);
    if (element == value)
      return true;
    ++idx;
  }

  return false;
}

bool create_custom_op(luci::CircleStridedSlice *strided_slice_node)
{
auto while_out_node = dynamic_cast<luci::CircleWhileOut *>(strided_slice_node->input());
Expand Down Expand Up @@ -255,6 +196,8 @@ bool FuseUnrolledGRUAsCustomGRUPass::run(loco::Graph *g)
if (not strided_slice)
continue;

// TODO: add pattern checks

if (create_custom_op(strided_slice))
changed = true;
}
Expand Down
25 changes: 11 additions & 14 deletions onert-micro/luci-interpreter/pal/common/PALGRUCell.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,21 @@ void calculateGRU(const float *input_data, const float *weight_input_data,
first_part[i] = 1.0f - first_part[i];
}


// Calc third part
float *third_part = third_hidden_part;
// Calc second part
float *second_part = second_hidden_part;
for (int i = 0; i < num_elements; ++i)
{
third_part[i] += third_input_part[i];
second_part[i] += second_input_part[i];
}
Logistic(num_elements, third_part, third_part);
Logistic(num_elements, second_part, second_part);

for (int i = 0; i < num_elements; ++i)
{
third_part[i] *= second_hidden_part[i];
third_part[i] += second_input_part[i];
third_part[i] = std::tanh(third_part[i]);
third_part[i] *= first_part[i];
output_data[i] += third_part[i];
second_part[i] *= third_hidden_part[i];
second_part[i] += third_input_part[i];
second_part[i] = std::tanh(second_part[i]);
second_part[i] *= first_part[i];
output_data[i] += second_part[i];
}
}

Expand All @@ -110,18 +109,16 @@ void GRU(float *input_data, const float *weight_input_data,
auto output_input_data = std::make_unique<float []>(weight_hidden_shape[0]);
auto output_hidden_data = std::make_unique<float []>(weight_hidden_shape[0]);

int32_t output_shape_fc[] = {1, 96};
int32_t output_shape_fc[] = {output_shape[0], weight_input_shape[0]};

std::memcpy(output_data, hidden_state_data, output_shape[1]);
std::memcpy(output_data, hidden_state_data, output_shape[1] * sizeof(float));

for (int i = 0; i < time; ++i)
{
// input_shape should be (1, 6)
calculateGRU(input_data, weight_input_data, weight_hidden_data,
bias_input_data, bias_hidden_data, output_data, input_shape,
output_shape, weight_input_shape, weight_hidden_shape, output_input_data.get(),
output_hidden_data.get(), output_shape_fc);
auto tmp = input_shape[1];
input_data += input_shape[1];
}
}
Expand Down
81 changes: 0 additions & 81 deletions onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst
Original file line number Diff line number Diff line change
@@ -1,87 +1,6 @@
REGISTER_KERNEL(ABS, Abs)
REGISTER_KERNEL(ADD, Add)
REGISTER_KERNEL(AVERAGE_POOL_2D, AveragePool2D)
REGISTER_KERNEL(ARG_MAX, ArgMax)
REGISTER_KERNEL(ARG_MIN, ArgMin)
REGISTER_KERNEL(CUSTOM, Custom)
REGISTER_KERNEL(BATCH_TO_SPACE_ND, BatchToSpaceND)
REGISTER_KERNEL(CEIL, Ceil)
REGISTER_KERNEL(COS, Cos)
REGISTER_KERNEL(CAST, Cast)
REGISTER_KERNEL(DIV, Div)
REGISTER_KERNEL(DEPTHWISE_CONV_2D, DepthwiseConv2D)
REGISTER_KERNEL(DEPTH_TO_SPACE, DepthToSpace)
REGISTER_KERNEL(DEQUANTIZE, Dequantize)
REGISTER_KERNEL(ADD_N, AddN)
REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected)
REGISTER_KERNEL(CONV_2D, Conv2D)
REGISTER_KERNEL(LOGISTIC, Logistic)
REGISTER_KERNEL(LOG, Log)
REGISTER_KERNEL(GATHER, Gather)
REGISTER_KERNEL(GATHER_ND, GatherND)
REGISTER_KERNEL(EXP, Exp)
REGISTER_KERNEL(GREATER, Greater)
REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)
REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)
REGISTER_KERNEL(ELU, Elu)
REGISTER_KERNEL(EQUAL, Equal)
REGISTER_KERNEL(FILL, Fill)
REGISTER_KERNEL(FLOOR, Floor)
REGISTER_KERNEL(FLOOR_DIV, FloorDiv)
REGISTER_KERNEL(FLOOR_MOD, FloorMod)
REGISTER_KERNEL(PACK, Pack)
REGISTER_KERNEL(PAD, Pad)
REGISTER_KERNEL(PADV2, PadV2)
REGISTER_KERNEL(PRELU, PRelu)
REGISTER_KERNEL(RESHAPE, Reshape)
REGISTER_KERNEL(RELU, Relu)
REGISTER_KERNEL(RELU6, Relu6)
REGISTER_KERNEL(REDUCE_PROD, ReduceCommon)
REGISTER_KERNEL(REDUCE_MAX, ReduceMax)
REGISTER_KERNEL(ROUND, Round)
REGISTER_KERNEL(LESS, Less)
REGISTER_KERNEL(L2_NORMALIZATION, L2Normalize)
REGISTER_KERNEL(L2_POOL_2D, L2Pool2D)
REGISTER_KERNEL(LESS_EQUAL, LessEqual)
REGISTER_KERNEL(LOGICAL_AND, LogicalAnd)
REGISTER_KERNEL(LOGICAL_NOT, LogicalNot)
REGISTER_KERNEL(LOGICAL_OR, LogicalOr)
REGISTER_KERNEL(LEAKY_RELU, LeakyRelu)
REGISTER_KERNEL(LOG_SOFTMAX, LogSoftmax)
REGISTER_KERNEL(MUL, Mul)
REGISTER_KERNEL(MIRROR_PAD, MirrorPad)
REGISTER_KERNEL(MAXIMUM, Maximum)
REGISTER_KERNEL(MEAN, Mean)
REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)
REGISTER_KERNEL(MINIMUM, Minimum)
REGISTER_KERNEL(CONCATENATION, Concatenation)
REGISTER_KERNEL(SHAPE, Shape)
REGISTER_KERNEL(NOT_EQUAL, NotEqual)
REGISTER_KERNEL(SIN, Sin)
REGISTER_KERNEL(SQUARED_DIFFERENCE, SquaredDifference)
REGISTER_KERNEL(SLICE, Slice)
REGISTER_KERNEL(SUB, Sub)
REGISTER_KERNEL(SPLIT, Split)
REGISTER_KERNEL(SPACE_TO_BATCH_ND, SpaceToBatchND)
REGISTER_KERNEL(STRIDED_SLICE, StridedSlice)
REGISTER_KERNEL(SPLIT_V, SplitV)
REGISTER_KERNEL(SQUARE, Square)
REGISTER_KERNEL(SQRT, Sqrt)
REGISTER_KERNEL(SPACE_TO_DEPTH, SpaceToDepth)
REGISTER_KERNEL(QUANTIZE, Quantize)
REGISTER_KERNEL(TANH, Tanh)
REGISTER_KERNEL(TRANSPOSE, Transpose)
REGISTER_KERNEL(TRANSPOSE_CONV, TransposeConv)
REGISTER_KERNEL(SOFTMAX, Softmax)
REGISTER_KERNEL(SUM, Sum)
REGISTER_KERNEL(SELECT_V2, SelectV2)
REGISTER_KERNEL(SVDF, SVDF)
REGISTER_KERNEL(WHILE, While)
REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM)
REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)
REGISTER_KERNEL(RESIZE_NEAREST_NEIGHBOR, ResizeNearestNeighbor)
REGISTER_KERNEL(RSQRT, Rsqrt)
REGISTER_KERNEL(NEG, Neg)
REGISTER_KERNEL(ZEROS_LIKE, ZerosLike)
REGISTER_KERNEL(SQUEEZE, Squeeze)
REGISTER_KERNEL(UNPACK, Unpack)
4 changes: 2 additions & 2 deletions onert-micro/luci-interpreter/src/loader/ModuleLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void ModuleLoader::load(RuntimeModule *runtime_module, SimpleMemoryManager *memo
if (!reader.parse(model))
assert(false && "Error during parse");

for (size_t i = 0; i < reader.num_subgraph(); ++i)
for (size_t i = 0; i < 1; ++i)
{
if (!reader.select_subgraph(i))
assert(false && "Error during select subgraph");
Expand All @@ -47,7 +47,7 @@ void ModuleLoader::load(RuntimeModule *runtime_module, SimpleMemoryManager *memo

// For Dynamic Memory manager we build memory allocate/deallocate plan and then configure kernels.
// For Static Memory manager we only configure kernels.
for (size_t i = 0; i < reader.num_subgraph(); ++i)
for (size_t i = 0; i < 1; ++i)
{
auto *runtime_graph = runtime_module->getRuntimeGraphAt(i);
#ifdef USE_STATIC_ALLOC
Expand Down

0 comments on commit f5d27fa

Please sign in to comment.