Commit: draft

Artem Balyshev committed Oct 30, 2023
1 parent fa553ee commit 15d5e8b
Showing 7 changed files with 402 additions and 18 deletions.
5 changes: 5 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
@@ -16,6 +16,7 @@

#include "luci/CircleOptimizer.h"

#include "luci/Pass/FuseUnrolledGRUAsCustomGRU.h"
#include "luci/Pass/ConvertNCHWToNHWCPass.h"
#include "luci/Pass/ExpandBroadcastConstPass.h"
#include "luci/Pass/FoldAddV2Pass.h"
@@ -302,6 +303,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FusePReluPass>());
}
if (_options->query(Options::Algorithm::FuseUnrolledGRUAsCustomGRU))
{
phase.emplace_back(std::make_unique<FuseUnrolledGRUAsCustomGRUPass>());
}
if (_options->query(Options::Algorithm::FuseGelu))
{
phase.emplace_back(std::make_unique<FuseGeluPass>());
152 changes: 136 additions & 16 deletions compiler/luci/pass/src/FuseUnrolledGRUAsCustomGRU.cpp
@@ -15,6 +15,7 @@
*/

#include "luci/Pass/FuseUnrolledGRUAsCustomGRU.h"
#include "luci/Service/CircleNodeClone.h"

#include <luci/IR/CircleNode.h>
#include <luci/Profile/CircleNodeOrigin.h>
@@ -23,6 +24,72 @@
namespace
{

template <loco::DataType T>
void copy_values(const luci::CircleConst *node, luci::CircleConst *cloned)
{
assert(T == node->dtype());
assert(T == cloned->dtype());

const auto size = node->size<T>();
cloned->size<T>(size);
for (uint32_t i = 0; i < size; i++)
cloned->at<T>(i) = node->at<T>(i);
}

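// Clone a CircleConst's dtype, rank, and raw values into a new constant in `graph`.
// Callers fill in the remaining attributes (name, shape dims, quantization, etc.) via
// luci::copy_common_attributes.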
luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph)
{
auto cloned = graph->nodes()->create<luci::CircleConst>();

if (cloned != nullptr)
{
// dtype/shape
cloned->dtype(node->dtype());
cloned->rank(node->rank());

// values
switch (node->dtype())
{
case loco::DataType::FLOAT32:
copy_values<loco::DataType::FLOAT32>(node, cloned);
break;

case loco::DataType::U8:
copy_values<loco::DataType::U8>(node, cloned);
break;

case loco::DataType::S8:
copy_values<loco::DataType::S8>(node, cloned);
break;

case loco::DataType::S16:
copy_values<loco::DataType::S16>(node, cloned);
break;

case loco::DataType::S32:
copy_values<loco::DataType::S32>(node, cloned);
break;

case loco::DataType::S64:
copy_values<loco::DataType::S64>(node, cloned);
break;

case loco::DataType::BOOL:
copy_values<loco::DataType::BOOL>(node, cloned);
break;

default:
assert(false);
}
}

return cloned;
}

} // namespace

namespace
{

/**
* Fuse an unrolled GRU subgraph into a single custom GRU operation if possible
*
@@ -82,15 +149,22 @@ bool const_has_value_s32(const luci::CircleConst *circle_const, int32_t value)
return false;
}

bool create_custom_op(luci::CircleWhile *while_node)
bool create_custom_op(luci::CircleStridedSlice *strided_slice_node)
{
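// Pattern being matched: a StridedSlice fed by a CircleWhileOut of the CircleWhile that
// unrolls the GRU. The FullyConnected nodes inside the While body supply the
// input-to-hidden and hidden-to-hidden weights/biases, told apart by the last dimension
// of their weight tensors (input size vs. hidden size).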
auto while_out_node = dynamic_cast<luci::CircleWhileOut *>(strided_slice_node->input());
if (while_out_node == nullptr)
return false;

auto while_node = dynamic_cast<luci::CircleWhile *>(while_out_node->input());
if (while_node == nullptr)
return false;
auto input_node = dynamic_cast<luci::CircleNode *>(while_node->input(4));
auto input_state_node = dynamic_cast<luci::CircleNode *>(while_node->input(3));
loco::Node *weight_ih;
loco::Node *bias_ih;
luci::CircleConst *weight_ih = nullptr;
luci::CircleConst *bias_ih = nullptr;

loco::Node *weight_hh;
loco::Node *bias_hh;
luci::CircleConst *weight_hh = nullptr;
luci::CircleConst *bias_hh = nullptr;

auto input_size = input_node->dim(input_node->rank() - 1).value();
auto hidden_size = input_state_node->dim(input_state_node->rank() - 1).value();
@@ -101,22 +175,68 @@ bool create_custom_op(luci::CircleWhile *while_node)
auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
if (not fc)
continue;
if (fc->dim(fc->rank() - 1) == input_size)
auto fc_weights = dynamic_cast<luci::CircleNode *>(fc->weights());
if (fc_weights == nullptr)
continue;

if (fc_weights->dim(fc->rank() - 1) == input_size)
{
weight_ih = fc->weights();
bias_ih = fc->bias();
weight_ih = dynamic_cast<luci::CircleConst *>(fc->weights());
bias_ih = dynamic_cast<luci::CircleConst *>(fc->bias());
}
if (fc->dim(fc->rank() - 1) == hidden_size)
if (fc_weights->dim(fc->rank() - 1) == hidden_size)
{
weight_hh = fc->weights();
bias_hh = fc->bias();
weight_hh = dynamic_cast<luci::CircleConst *>(fc->weights());
bias_hh = dynamic_cast<luci::CircleConst *>(fc->bias());
}
}

// Create and configure new CircleMean operation.
auto fused_gru = while_node->graph()->nodes()->create<luci::CircleCustom>(6, 2);
assert(weight_hh != nullptr);
assert(weight_ih != nullptr);
assert(bias_ih != nullptr);
assert(bias_hh != nullptr);

auto weight_ih_cloned = clone_circleconst(weight_ih, strided_slice_node->graph());
luci::copy_common_attributes(weight_ih, weight_ih_cloned);

auto weight_hh_cloned = clone_circleconst(weight_hh, strided_slice_node->graph());
luci::copy_common_attributes(weight_hh, weight_hh_cloned);

auto bias_ih_cloned = clone_circleconst(bias_ih, strided_slice_node->graph());
luci::copy_common_attributes(bias_ih, bias_ih_cloned);

auto bias_hh_cloned = clone_circleconst(bias_hh, strided_slice_node->graph());
luci::copy_common_attributes(bias_hh, bias_hh_cloned);

// Create and configure new CircleCustom operation.
auto fused_gru = while_node->graph()->nodes()->create<luci::CircleCustom>(6, 1);
auto custom_out = while_node->graph()->nodes()->create<luci::CircleCustomOut>();

fused_gru->custom_code("custom_gru");
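// Input layout wired below: 0 = input sequence, 1 = input-to-hidden weights,
// 2 = hidden-to-hidden weights, 3 = input-to-hidden bias, 4 = hidden-to-hidden bias,
// 5 = initial hidden state.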
fused_gru->inputs(0, input_node);
fused_gru->inputs(1, weight_ih_cloned);
fused_gru->inputs(2, weight_hh_cloned);
fused_gru->inputs(3, bias_ih_cloned);
fused_gru->inputs(4, bias_hh_cloned);
fused_gru->inputs(5, input_state_node);

fused_gru->name("gru");

fused_gru->shape_status(luci::ShapeStatus::VALID);
fused_gru->rank(2);
fused_gru->dim(0).set(strided_slice_node->dim(0).value());
fused_gru->dim(1).set(strided_slice_node->dim(1).value());

fused_gru->dtype(loco::DataType::FLOAT32);

custom_out->input(fused_gru);
custom_out->rank(2);
custom_out->name("out");
custom_out->dim(0).set(strided_slice_node->dim(0).value());
custom_out->dim(1).set(strided_slice_node->dim(1).value());
custom_out->dtype(loco::DataType::FLOAT32);
custom_out->shape_status(luci::ShapeStatus::VALID);
custom_out->index(0);

replace(strided_slice_node).with(custom_out);

return true;
}
@@ -131,11 +251,11 @@ bool FuseUnrolledGRUAsCustomGRUPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto mean = dynamic_cast<luci::CircleWhile *>(node);
if (not mean)
auto strided_slice = dynamic_cast<luci::CircleStridedSlice *>(node);
if (not strided_slice)
continue;

if (create_custom_op(mean))
if (create_custom_op(strided_slice))
changed = true;
}

2 changes: 1 addition & 1 deletion onert-micro/luci-interpreter/CMakeLists.txt
@@ -30,7 +30,7 @@ if (ENABLE_TRAINING)
endif()

add_compile_options(-fno-exceptions)
add_compile_options(-Os)
#add_compile_options(-Os)

# AFAIK, this will enable leak sanitizer, too
if(ENABLE_SANITIZER)
131 changes: 131 additions & 0 deletions onert-micro/luci-interpreter/pal/common/PALGRUCell.h
@@ -0,0 +1,131 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_PAL_GRU_H
#define LUCI_INTERPRETER_PAL_GRU_H

#include "PALUtils.h"
#include "ProcessBroadcastShapes.h"
#include "PALFullyConnected.h"
#include "PALLogistic.h"

namespace luci_interpreter_pal
{

namespace
{
void calculateGRU(const float *input_data, const float *weight_input_data,
const float *weight_hidden_data, const float *bias_input_data,
const float *bias_hidden_data, float *output_data,
const int32_t *input_shape, const int32_t *output_shape, const int32_t *weight_input_shape,
const int32_t *weight_hidden_shape,
float *output_input_data,
float *output_hidden_data, const int32_t *output_shape_fc)
{
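// Each FullyConnected below produces three concatenated chunks of num_elements values (one per
// gate). The arithmetic that follows appears to implement a "reset-after" GRU cell:
//   z     = sigmoid(hidden chunk 0 + input chunk 0)   (update gate)
//   r     = sigmoid(hidden chunk 2 + input chunk 2)   (reset-like gate)
//   h~    = tanh(input chunk 1 + r * hidden chunk 1)  (candidate state)
//   h_new = z * h_prev + (1 - z) * h~
// where h_prev is the previous hidden state that the caller has already placed in output_data.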
// Calculate FC for hidden (output)
FullyConnectedParams op_params{};
float activation_min{};
float activation_max{};
luci_interpreter::kernels::calculateActivationRange(luci_interpreter::FusedActFunc::NONE,
&activation_min, &activation_max);

op_params.float_activation_min = activation_min;
op_params.float_activation_max = activation_max;

FullyConnected(op_params, output_shape, output_data, weight_hidden_shape, weight_hidden_data,
bias_hidden_data, output_shape_fc, output_hidden_data);

// Calculate FC for input
FullyConnected(op_params, input_shape, input_data, weight_input_shape, weight_input_data,
bias_input_data, output_shape_fc, output_input_data);

int num_elements = output_shape_fc[1] / 3;

float *second_hidden_part = output_hidden_data + num_elements;
float *second_input_part = output_input_data + num_elements;

float *third_hidden_part = second_hidden_part + num_elements;
float *third_input_part = second_input_part + num_elements;

// Update gate: z = sigmoid(hidden chunk 0 + input chunk 0)
for (int i = 0; i < num_elements; ++i)
{
output_hidden_data[i] += output_input_data[i];
}
Logistic(num_elements, output_hidden_data, output_hidden_data);

// Scale the previous hidden state by z and turn the buffer into (1 - z)
float *most_left_part_final = output_hidden_data;
float *first_part = output_hidden_data;
for (int i = 0; i < num_elements; ++i)
{
output_data[i] *= most_left_part_final[i];
first_part[i] = 1.0f - first_part[i];
}


// Reset-like gate: r = sigmoid(hidden chunk 2 + input chunk 2)
float *third_part = third_hidden_part;
for (int i = 0; i < num_elements; ++i)
{
third_part[i] += third_input_part[i];
}
Logistic(num_elements, third_part, third_part);

for (int i = 0; i < num_elements; ++i)
{
third_part[i] *= second_hidden_part[i];
third_part[i] += second_input_part[i];
third_part[i] = std::tanh(third_part[i]);
third_part[i] *= first_part[i];
output_data[i] += third_part[i];
}
}

} // namespace

void GRU(float *input_data, const float *weight_input_data,
const float *weight_hidden_data, const float *bias_input_data,
const float *bias_hidden_data, const float *hidden_state_data, float *output_data,
const int32_t *input_shape, const int32_t *output_shape, const int32_t *weight_input_shape,
const int32_t *weight_hidden_shape)
{
const int32_t time = input_shape[0];
input_shape += 1;

auto output_input_data = std::make_unique<float []>(weight_hidden_shape[0]);
auto output_hidden_data = std::make_unique<float []>(weight_hidden_shape[0]);

int32_t output_shape_fc[] = {1, 96};
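// NOTE: the FC output shape (three concatenated gate pre-activations) appears to be hard-coded
// for a specific model in this draft.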

std::memcpy(output_data, hidden_state_data, output_shape[1] * sizeof(float));

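// Process the sequence one time step at a time; output_data holds the running hidden state and
// is updated in place, so it contains the final hidden state after the loop.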
for (int i = 0; i < time; ++i)
{
// input_shape should be (1, 6)
calculateGRU(input_data, weight_input_data, weight_hidden_data,
bias_input_data, bias_hidden_data, output_data, input_shape,
output_shape, weight_input_shape, weight_hidden_shape, output_input_data.get(),
output_hidden_data.get(), output_shape_fc);
input_data += input_shape[1];
}
}

} // namespace luci_interpreter_pal

#endif // LUCI_INTERPRETER_PAL_GRU_H
2 changes: 1 addition & 1 deletion onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst
@@ -3,7 +3,7 @@ REGISTER_KERNEL(ADD, Add)
REGISTER_KERNEL(AVERAGE_POOL_2D, AveragePool2D)
REGISTER_KERNEL(ARG_MAX, ArgMax)
REGISTER_KERNEL(ARG_MIN, ArgMin)
REGISTER_KERNEL(CUSTOM, BroadcastTo)
REGISTER_KERNEL(CUSTOM, Custom)
REGISTER_KERNEL(BATCH_TO_SPACE_ND, BatchToSpaceND)
REGISTER_KERNEL(CEIL, Ceil)
REGISTER_KERNEL(COS, Cos)