Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onert-micro] Add GRU backward execution #13757

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions onert-micro/eval-driver/TrainingDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,10 @@ int entry(int argc, char **argv)
config.wof_ptr = nullptr;

// Set user defined training settings
const uint32_t training_epochs = 30;
const uint32_t training_epochs = 50;
const float lambda = 0.001f;
const uint32_t BATCH_SIZE = 32;
const uint32_t INPUT_SIZE = 180;
const uint32_t OUTPUT_SIZE = 4;
const uint32_t num_train_layers = 10;
const uint32_t BATCH_SIZE = 64;
const uint32_t num_train_layers = 4;
const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY;
const onert_micro::OMTrainOptimizer train_optim = onert_micro::ADAM;
const float beta = 0.9;
Expand All @@ -211,6 +209,9 @@ int entry(int argc, char **argv)
onert_micro::OMTrainingInterpreter train_interpreter;
train_interpreter.importTrainModel(circle_model.data(), config);

const uint32_t OUTPUT_SIZE = train_interpreter.getOutputSizeAt(0);
const uint32_t INPUT_SIZE = train_interpreter.getInputSizeAt(0);

// Temporary buffer to read input data from file using BATCH_SIZE
float training_input[BATCH_SIZE * INPUT_SIZE];
float training_target[BATCH_SIZE * OUTPUT_SIZE];
Expand Down Expand Up @@ -263,11 +264,11 @@ int entry(int argc, char **argv)
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to introduce these changes to tests for target board

reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
reinterpret_cast<void *>(&accuracy), cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
Expand All @@ -276,10 +277,10 @@ int entry(int argc, char **argv)
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
reinterpret_cast<void *>(&mse), cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
reinterpret_cast<void *>(&mae), cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
Expand Down Expand Up @@ -335,11 +336,11 @@ int entry(int argc, char **argv)
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
reinterpret_cast<void *>(&accuracy), cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
Expand All @@ -348,10 +349,10 @@ int entry(int argc, char **argv)
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
reinterpret_cast<void *>(&mse), cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
reinterpret_cast<void *>(&mae), cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace train
namespace pal
{

// Note: dloss_dweight_data should be initialized
void inline FullyConnectedWeightGrad(
const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape,
const float *input_data, const core::OMRuntimeShape &input_shape, float *dloss_dweight_data,
Expand All @@ -48,7 +49,7 @@ void inline FullyConnectedWeightGrad(
float cur_dloss_doutput = dloss_doutput_data[o + depth_bounds.first];
for (uint32_t i = 0; i < accum_depth; ++i)
{
dloss_dweight_data[i + o * accum_depth] = cur_dloss_doutput * input_data[i];
dloss_dweight_data[i + o * accum_depth] += cur_dloss_doutput * input_data[i];
}
}

Expand Down
187 changes: 187 additions & 0 deletions onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H
#define ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H

#include "OMStatus.h"
#include "core/OMRuntimeShape.h"
#include "core/OMKernelType.h"

#include "PALUtils.h"
#include "ProcessBroadcastShapes.h"
#include "PALFullyConnectedWeightGrad.h"

namespace onert_micro
{
namespace train
{
namespace pal
{
namespace
{

// Backpropagates the loss gradient through one GRU time step.
//
// The per-step forward activations saved by the forward pass (gate pre-activations,
// mul operands, tanh input, hidden output) are read element-wise to build the
// gradients of the two concatenated gate vectors, which are then pushed through
// FullyConnectedWeightGrad to accumulate the weight gradients.
//
// Gradient outputs:
//   weight_input_grad_data / weight_hidden_grad_data - accumulated FC weight grads
//     (FullyConnectedWeightGrad accumulates with +=, so both buffers must be
//     zero-initialized by the caller before the first call)
//   left_fc_output_grad_buffer / right_fc_output_grad_buffer - scratch buffers of
//     3 * output_shape.flatSize() floats each, holding the concatenated per-gate grads
//   state_grad_data - zeroed at the end of this call
//
// NOTE(review): weight_input_data, weight_hidden_data, bias_*_data and
// input_grad_data are accepted but never read/written here — presumably reserved
// for a full backward pass (bias/input grads); confirm before relying on them.
void calculateGRUWeightGrads(
  const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data,
  const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data,
  float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data,
  const float *input_data, float *input_grad_data, float *state_grad_data,
  const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_fc_shape,
  const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape,
  const core::OMRuntimeShape &weight_hidden_shape, float *output_data, float *left_logistic_data,
  float *left_mul_data, float *right_logistic_data, const float *right_mul_left_input_data,
  const float *right_mul_right_input_data, float *tanh_data, const float *middle_mul_left_input,
  const float *middle_mul_right_input, float *left_fc_output_grad_buffer,
  float *right_fc_output_grad_buffer)
{
  int num_elements = output_shape.flatSize();
  for (int i = 0; i < num_elements; ++i)
  {
    // Product rule through the "middle" Mul: grad w.r.t. each operand is the
    // incoming output grad times the other saved operand.
    // Middle Mul left input grad
    float left_middle_mul = output_grad_data[i];
    left_middle_mul *= middle_mul_right_input[i];

    // Middle Mul right input grad
    float right_middle_mul = output_grad_data[i];
    right_middle_mul *= middle_mul_left_input[i];

    // Tanh' (x) = 1 / cosh(x)^2 = 1 - tanh(x)^2; tanh_data holds the saved
    // pre-activation input of the tanh node.
    float tanh_grad_value;
    {
      float tanh = std::tanh(tanh_data[i]);
      tanh_grad_value = (1 - tanh * tanh) * right_middle_mul;
    }

    // Grad flowing into the left Mul's gate operand (other operand is the saved
    // hidden output). NOTE(review): presumably output_data is the second Mul
    // operand here — confirm against the forward GRU kernel.
    // Left mul
    float left_mul_grad_value = output_grad_data[i] * output_data[i];

    // Logistic'(x) = sigma(x) * (1 - sigma(x)); the gate output fans out to two
    // consumers, so both contributions are summed.
    // NOTE(review): the comment below mentions Sub' = -1 for the (1 - z) branch,
    // but no sign flip is applied to left_middle_mul in the code — verify this is
    // intentional.
    // Sub` = -1
    // Left Logistic: Logistic` = (exp(-x) * (1 / (1 + exp(-x))) ^ 2)
    float left_logistic_grad_value;
    {
      float log_value = (1 / (1 + std::exp(-left_logistic_data[i])));
      left_logistic_grad_value =
        log_value * (1 - log_value) * (left_middle_mul + left_mul_grad_value);
    }

    // Product rule through the "right" Mul, using the saved forward operands.
    // Right mul left input
    float right_mul_left_input = tanh_grad_value;
    right_mul_left_input *= right_mul_right_input_data[i];

    // Right mul right input
    float right_mul_right_input = tanh_grad_value;
    right_mul_right_input *= right_mul_left_input_data[i];

    // Right logistic gate: same sigmoid-derivative chain as above.
    float right_logistic_grad_value;
    {
      float log_value = (1 / (1 + std::exp(-right_logistic_data[i])));
      right_logistic_grad_value = log_value * (1 - log_value) * right_mul_left_input;
    }

    // Concatenated gate grads feeding the hidden-side FC weight grad
    // (layout: [left gate | right gate | candidate], each num_elements long).
    left_fc_output_grad_buffer[i] = left_logistic_grad_value;
    left_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value;
    left_fc_output_grad_buffer[i + 2 * num_elements] = right_mul_right_input;

    // Concatenated gate grads feeding the input-side FC weight grad; differs from
    // the left buffer only in the third slot (raw tanh grad).
    right_fc_output_grad_buffer[i] = left_logistic_grad_value;
    right_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value;
    right_fc_output_grad_buffer[i + 2 * num_elements] = tanh_grad_value;
  }

  // NOTE(review): the hidden-state activations (output_data) accumulate into
  // weight_INPUT_grad_data while the input activations accumulate into
  // weight_HIDDEN_grad_data — the naming looks swapped; confirm against the
  // forward kernel's weight layout.
  // Left fc weight grad
  FullyConnectedWeightGrad(left_fc_output_grad_buffer, output_fc_shape, output_data, output_shape,
                           weight_input_grad_data, weight_input_shape,
                           core::OpTrainableRankType::ALL);
  // Right fc weight grad
  FullyConnectedWeightGrad(right_fc_output_grad_buffer, output_fc_shape, input_data, input_shape,
                           weight_hidden_grad_data, weight_hidden_shape,
                           core::OpTrainableRankType::ALL);

  // The recurrent state gradient is not propagated through time yet: reset it.
  // Set state grad to zero
  std::memset(state_grad_data, 0, output_shape.flatSize() * sizeof(float));
}

} // namespace

// Backward pass over the whole unrolled GRU sequence.
//
// For each of the `time` steps (outermost input dimension) this locates the nine
// per-step activation tensors the forward pass stored in intermediate_buffer and
// delegates to calculateGRUWeightGrads, which accumulates the FC weight
// gradients. Returns UnsupportedType when input/output shapes have rank < 2,
// Ok otherwise.
OMStatus GRUWeightGrads(
  const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data,
  const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data,
  float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data,
  const float *input_data, float *input_grad_data, float *state_grad_data,
  const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
  const core::OMRuntimeShape &weight_input_shape, const core::OMRuntimeShape &weight_hidden_shape,
  const core::OMRuntimeShape &output_shape_fc, float *intermediate_buffer,
  float *left_fc_output_grad_buffer, float *right_fc_output_grad_buffer)
{
  // Number of unrolled time steps.
  const int32_t time = input_shape.dims(0);

  // One saved activation tensor holds output_shape.flatSize() floats; the forward
  // pass stores nine of them per step, laid out back to back.
  const size_t tensor_size = output_shape.flatSize();
  const int32_t saved_tensors_per_step = 9;
  const size_t step_stride = saved_tensors_per_step * tensor_size;

  // Collapse input/output shapes to their two innermost dimensions for the
  // per-step FC weight-grad computation.
  const auto input_rank = input_shape.dimensionsCount();
  if (input_rank < 2)
    return UnsupportedType;
  core::OMRuntimeShape two_dim_input_shape(2);
  two_dim_input_shape.setDim(0, input_shape.dims(input_rank - 2));
  two_dim_input_shape.setDim(1, input_shape.dims(input_rank - 1));

  const auto output_rank = output_shape.dimensionsCount();
  if (output_rank < 2)
    return UnsupportedType;
  core::OMRuntimeShape two_dim_output_shape(2);
  two_dim_output_shape.setDim(0, output_shape.dims(output_rank - 2));
  two_dim_output_shape.setDim(1, output_shape.dims(output_rank - 1));

  // Zero the accumulators: the per-step weight grads are summed with += inside
  // FullyConnectedWeightGrad.
  // NOTE(review): sizes are derived from output/input shapes rather than the
  // weight shapes — confirm they cover the full weight buffers.
  std::memset(weight_input_grad_data, 0, output_shape.flatSize() * sizeof(float) * time);
  std::memset(weight_hidden_grad_data, 0, input_shape.dims(2) * sizeof(float) * time);

  // Per-step input slice length (innermost input dimension).
  const int32_t input_step = input_shape.dims(2);

  for (int32_t step = 0; step < time; ++step)
  {
    // Base of this step's nine saved activation tensors, in forward-pass order.
    float *saved = intermediate_buffer + step * step_stride;
    float *output_data = saved;
    float *left_logistic_data = saved + tensor_size;
    float *left_mul_data = saved + 2 * tensor_size;
    float *right_logistic_data = saved + 3 * tensor_size;
    float *right_mul_left_input_data = saved + 4 * tensor_size;
    float *right_mul_right_input_data = saved + 5 * tensor_size;
    float *tanh_data = saved + 6 * tensor_size;
    float *middle_mul_left_input = saved + 7 * tensor_size;
    float *middle_mul_right_input = saved + 8 * tensor_size;

    calculateGRUWeightGrads(
      output_grad_data, weight_input_data, weight_input_grad_data, weight_hidden_data,
      weight_hidden_grad_data, bias_input_data, bias_input_grad_data, bias_hidden_data,
      bias_hidden_grad_data, input_data + step * input_step, input_grad_data, state_grad_data,
      two_dim_input_shape, output_shape_fc, two_dim_output_shape, weight_input_shape,
      weight_hidden_shape, output_data, left_logistic_data, left_mul_data, right_logistic_data,
      right_mul_left_input_data, right_mul_right_input_data, tanh_data, middle_mul_left_input,
      middle_mul_right_input, left_fc_output_grad_buffer, right_fc_output_grad_buffer);
  }
  return Ok;
}

} // namespace pal
} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H
2 changes: 2 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ REGISTER_TRAIN_KERNEL(SOFTMAX, Softmax)
REGISTER_TRAIN_KERNEL(RESHAPE, Reshape)
REGISTER_TRAIN_KERNEL(CONV_2D, Conv2D)
REGISTER_TRAIN_KERNEL(MAX_POOL_2D, MaxPool2D)
REGISTER_TRAIN_KERNEL(GRU, GRU)
REGISTER_TRAIN_KERNEL(STRIDED_SLICE, StridedSlice)
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode)
{
case circle::BuiltinOperator_FULLY_CONNECTED:
case circle::BuiltinOperator_CONV_2D:
case circle::BuiltinOperator_GRU:
return true;
default:
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode)
{
case circle::BuiltinOperator_FULLY_CONNECTED:
case circle::BuiltinOperator_CONV_2D:
case circle::BuiltinOperator_GRU:
return true;
default:
return false;
Expand Down
7 changes: 6 additions & 1 deletion onert-micro/onert-micro/src/train/OMBackpropExecute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecut
args.is_last_layer = false;
}

if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
if (trainable_ops_config.empty())
{
args.is_trainable_layer = true;
args.train_rank_type = core::OpTrainableRankType::ALL;
}
else if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
{
args.is_trainable_layer = true;
args.train_rank_type = core::OpTrainableRankType(trainable_ops_config[cur_op_index]);
Expand Down
3 changes: 3 additions & 0 deletions onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE
weight_shape = dynamic_shapes;

// 2. Calculate weight gradient
// Init weight grads with zeros
std::memset(dloss_dweight_data, 0,
output_shape.dims(1) * input_shape.dims(1) * sizeof(float));
pal::FullyConnectedWeightGrad(
core::utils::castInputData<float>(dloss_doutput_data), output_shape,
core::utils::castInputData<float>(input_data), input_shape,
Expand Down
Loading