Skip to content

Commit

Permalink
[Draft][onert-micro] Introduce GRU training
Browse files Browse the repository at this point in the history
This draft introduces GRU training for onert-micro.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Aug 21, 2024
1 parent 2b9d81f commit 28da445
Show file tree
Hide file tree
Showing 31 changed files with 1,321 additions and 39 deletions.
2 changes: 1 addition & 1 deletion onert-micro/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ else ()

message(STATUS "FOUND FlatBuffers")

set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.6/circle_schema.fbs")
set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.8/circle_schema.fbs")

# NOTE Copy circle_schema.fbs as schema.fbs to generate "schema_generated.fbs" instead of "circle_schema_generated.fbs"
add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs"
Expand Down
2 changes: 1 addition & 1 deletion onert-micro/eval-driver/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ int entry(int argc, char **argv)
}

// Do inference.
interpreter.run();
interpreter.run(config);
}

// Get output.
Expand Down
37 changes: 19 additions & 18 deletions onert-micro/eval-driver/TrainingDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,10 @@ int entry(int argc, char **argv)
config.wof_ptr = nullptr;

// Set user defined training settings
const uint32_t training_epochs = 30;
const uint32_t training_epochs = 50;
const float lambda = 0.001f;
const uint32_t BATCH_SIZE = 32;
const uint32_t INPUT_SIZE = 180;
const uint32_t OUTPUT_SIZE = 4;
const uint32_t num_train_layers = 10;
const uint32_t num_train_layers = 4;
const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY;
const onert_micro::OMTrainOptimizer train_optim = onert_micro::ADAM;
const float beta = 0.9;
Expand All @@ -211,6 +209,9 @@ int entry(int argc, char **argv)
onert_micro::OMTrainingInterpreter train_interpreter;
train_interpreter.importTrainModel(circle_model.data(), config);

const uint32_t OUTPUT_SIZE = train_interpreter.getOutputSizeAt(0);
const uint32_t INPUT_SIZE = train_interpreter.getInputSizeAt(0);

// Temporary buffer to read input data from file using BATCH_SIZE
float training_input[BATCH_SIZE * INPUT_SIZE];
float training_target[BATCH_SIZE * OUTPUT_SIZE];
Expand Down Expand Up @@ -263,11 +264,11 @@ int entry(int argc, char **argv)
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
reinterpret_cast<void *>(&accuracy), cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
Expand All @@ -276,10 +277,10 @@ int entry(int argc, char **argv)
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
reinterpret_cast<void *>(&mse), cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
reinterpret_cast<void *>(&mae), cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
Expand Down Expand Up @@ -335,11 +336,11 @@ int entry(int argc, char **argv)
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
reinterpret_cast<void *>(&accuracy), cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
Expand All @@ -348,10 +349,10 @@ int entry(int argc, char **argv)
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
reinterpret_cast<void *>(&mse), cur_batch_size);
train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
reinterpret_cast<void *>(&mae), cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
Expand Down
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/OMInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class OMInterpreter

OMStatus importModel(const char *model_ptr, const OMConfig &config);

OMStatus run();
OMStatus run(const OMConfig &config);

OMStatus reset();

Expand Down
7 changes: 4 additions & 3 deletions onert-micro/onert-micro/include/OMTrainingInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ class OMTrainingInterpreter
// Note: calculation will be done on test_size number of test samples
// Warning: before using evaluateMetric call: 1) importTrainModel; 2) setInput; 3) setTarget
// Note: number of the samples in data should be equal to the test_size
OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size)
OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val,
uint32_t test_size)
{
return _training_runtime_module.evaluateMetric(metric, metric_val, test_size);
return _training_runtime_module.evaluateMetric(config, metric, metric_val, test_size);
}

// To get input and output flat size
Expand All @@ -86,7 +87,7 @@ class OMTrainingInterpreter
// Load current status from checkpoint and save it in current model and in current config
OMStatus loadCheckpoint(OMConfig &config, const char *load_path);

OMStatus run() { return _training_runtime_module.run(); }
OMStatus run(const OMConfig &config) { return _training_runtime_module.run(config); }
OMStatus allocateInputs() { return _training_runtime_module.allocateInputs(); }

void *getInputData(uint32_t position);
Expand Down
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/core/OMRuntimeModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class OMRuntimeModule
~OMRuntimeModule() = default;

OMStatus importModel(const char *model_ptr, const OMConfig &config);
OMStatus run();
OMStatus run(const OMConfig &config);
OMStatus reset();

uint32_t getNumberOfInputs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class OMTrainingRuntimeModule : public OMRuntimeModule
// 2) metric_val should be initialized with some value before calling this method due to
// after calculation for current batch_num (the sequence number of the current sample)
// this value is added to metric_val
OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size);
OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val,
uint32_t test_size);

// Set input data for input with input_index
// Note: number of the samples in data should be equal to the batch_size in config structure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class OMCircleReader
const CircleOperators *operators() const { return _current_subgraph->operators(); }
const CircleValues *inputs() const { return _current_subgraph->inputs(); }
const CircleValues *outputs() const { return _current_subgraph->outputs(); }
const circle::DataFormat data_format() const { return _current_subgraph->data_format(); }
const CircleMetadataSet *metadata() const { return _model->metadata(); }

uint32_t num_subgraph() const { return _model->subgraphs()->size(); }
Expand Down
2 changes: 2 additions & 0 deletions onert-micro/onert-micro/include/execute/OMExecuteArgs.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct OMExecuteArgs
core::OMRuntimeContext &runtime_context;
uint16_t kernel_index;
core::OMRuntimeModule &runtime_module;
uint16_t num_train_layers = 0;
bool is_train_mode = false;
};

} // namespace execute
Expand Down
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/execute/OMRuntimeKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

#include <cstdint>

constexpr static uint32_t maxInputSize = 5;
constexpr static uint32_t maxInputSize = 6;
constexpr static uint32_t maxOutputSize = 5;

namespace onert_micro
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void inline FullyConnectedWeightGrad(
float cur_dloss_doutput = dloss_doutput_data[o + depth_bounds.first];
for (uint32_t i = 0; i < accum_depth; ++i)
{
dloss_dweight_data[i + o * accum_depth] = cur_dloss_doutput * input_data[i];
dloss_dweight_data[i + o * accum_depth] += cur_dloss_doutput * input_data[i];
}
}

Expand Down
209 changes: 209 additions & 0 deletions onert-micro/onert-micro/include/pal/common/PALGRUCommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H
#define ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H

#include "OMStatus.h"
#include "core/OMRuntimeShape.h"

#include "PALUtils.h"
#include "ProcessBroadcastShapes.h"
#include "PALFullyConnected.h"
#include "PALLogistic.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{
namespace
{
// Computes a single GRU time step in place.
//
// output_data holds the current hidden state on entry and the updated hidden
// state on exit. output_input_data / output_hidden_data are scratch buffers
// receiving the two fully-connected results (each holds the three gate
// pre-activations concatenated along the last dimension, see num_elements).
//
// When intermediate_buffer is non-null (training mode), the inputs of every
// non-linear/multiplicative stage are snapshotted into it, in execution order,
// so the backward pass can recompute gradients. The buffer pointer is advanced
// by output_shape.flatSize() floats after each snapshot — the backward code
// must consume the slices in exactly this order.
//
// NOTE(review): the "FC Input" call consumes output_data (the hidden state)
// with weight_input_data, while "FC Hidden" consumes input_data (the time-step
// input) with weight_hidden_data — the input/hidden naming looks swapped
// relative to the usual GRU convention; confirm against the kernel definition.
void calculateGRU(const float *input_data, const float *weight_input_data,
                  const float *weight_hidden_data, const float *bias_input_data,
                  const float *bias_hidden_data, float *output_data,
                  const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
                  const core::OMRuntimeShape &weight_input_shape,
                  const core::OMRuntimeShape &weight_hidden_shape, float *output_input_data,
                  float *output_hidden_data, const core::OMRuntimeShape &output_shape_fc,
                  float *intermediate_buffer)
{
  core::FullyConnectedParams op_params{};
  // The FC nodes inside the GRU have no fused activation, so disable clamping
  // by using the full float range.
  op_params.float_activation_min = std::numeric_limits<float>::lowest();
  op_params.float_activation_max = std::numeric_limits<float>::max();
  // If intermediate_buffer != nullptr - then it is train mode and we need to
  // save intermediate information
  bool is_train_mode = intermediate_buffer != nullptr;
  if (is_train_mode)
  {
    // Snapshot 1: the hidden state fed into the first FC, needed for its
    // weight gradients.
    std::memcpy(intermediate_buffer, output_data, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();
  }
  // FC Input: hidden state x weight_input -> all three gate contributions.
  FullyConnected(op_params, output_data, weight_input_shape, weight_input_data, bias_input_data,
                 output_shape_fc, output_input_data);

  // FC Hidden: time-step input x weight_hidden -> all three gate contributions.
  // Note: input for this FC node is saved without the intermediate buffer
  // (input_data itself remains available to the backward pass).
  FullyConnected(op_params, input_data, weight_hidden_shape, weight_hidden_data, bias_hidden_data,
                 output_shape_fc, output_hidden_data);

  // Each FC output row concatenates three equally-sized gate slices.
  int num_elements = output_shape_fc.dims(1) / 3;

  float *second_hidden_part = output_hidden_data + num_elements;
  float *second_input_part = output_input_data + num_elements;

  float *third_hidden_part = second_hidden_part + num_elements;
  float *third_input_part = second_input_part + num_elements;

  // First gate: sum the two FC contributions (presumably the update gate —
  // TODO confirm gate naming against the model definition).
  for (int i = 0; i < num_elements; ++i)
  {
    output_input_data[i] += output_hidden_data[i];
  }

  // If train mode - save logistic input (snapshot 2).
  if (is_train_mode)
  {
    std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();
  }
  Logistic(num_elements, output_input_data, output_input_data);

  // If train mode - save the right operand of the leftmost mul (snapshot 3).
  if (is_train_mode)
  {
    std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();
  }
  // Leftmost mul: scale the previous hidden state by the gate, and turn the
  // same buffer into (1 - gate) for the final interpolation below.
  // NOTE: most_left_part_final and first_part alias output_input_data — the
  // two updates in the loop body must stay in this order.
  float *most_left_part_final = output_input_data;
  float *first_part = output_input_data;
  for (int i = 0; i < num_elements; ++i)
  {
    output_data[i] *= most_left_part_final[i];
    first_part[i] = 1.0f - first_part[i];
  }

  // Second gate: sum the two FC contributions in place.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] += second_input_part[i];
  }
  // If train mode - save logistic input (snapshot 4).
  if (is_train_mode)
  {
    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();
  }
  Logistic(num_elements, second_hidden_part, second_hidden_part);

  // If train mode - save both mul operands (snapshots 5 and 6).
  if (is_train_mode)
  {
    // Left input
    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();

    // Right input
    std::memcpy(intermediate_buffer, third_input_part, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();
  }
  // Candidate pre-activation: gate * third_input + third_hidden, in place.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] *= third_input_part[i];
    second_hidden_part[i] += third_hidden_part[i];
  }
  // If train mode - save tanh input (snapshot 7).
  if (is_train_mode)
  {
    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();
  }
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] = std::tanh(second_hidden_part[i]);
  }

  // If train mode - save both operands of the final mul (snapshots 8 and 9).
  if (is_train_mode)
  {
    // Left input
    std::memcpy(intermediate_buffer, first_part, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();

    // Right input
    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
    // Move intermediate_buffer pointer
    intermediate_buffer += output_shape.flatSize();
  }
  // Final interpolation: output = gate * prev_hidden + (1 - gate) * candidate.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] *= first_part[i];
    output_data[i] += second_hidden_part[i];
  }
}

} // namespace

// Runs a full GRU sequence: seeds the hidden state from hidden_state_data,
// then applies calculateGRU once per time step, feeding each step's hidden
// state (kept in output_data) into the next.
//
// In training mode (intermediate_buffer_size != 0) each step writes its
// activation snapshots into a fresh slice of intermediate_buffer of
// intermediate_buffer_size floats; in inference mode intermediate_buffer may
// be nullptr.
//
// Returns Ok unconditionally.
OMStatus GRU(const float *input_data, const float *weight_input_data,
             const float *weight_hidden_data, const float *bias_input_data,
             const float *bias_hidden_data, const float *hidden_state_data, float *output_data,
             float *output_input_data, float *output_hidden_data,
             const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
             const core::OMRuntimeShape &weight_input_shape,
             const core::OMRuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size,
             float *intermediate_buffer)
{
  // Each FC inside the cell produces a single row holding the three
  // concatenated gate slices.
  core::OMRuntimeShape fc_output_shape(2);
  fc_output_shape.setDim(0, 1);
  fc_output_shape.setDim(1, weight_hidden_shape.dims(0));

  // The running hidden state lives in output_data; start from the provided
  // initial hidden state.
  std::memcpy(output_data, hidden_state_data, output_shape.flatSize() * sizeof(float));

  const int32_t num_steps = input_shape.dims(0);
  const float *step_input = input_data;
  float *step_intermediate = intermediate_buffer;

  for (int32_t step = 0; step < num_steps; ++step)
  {
    calculateGRU(step_input, weight_input_data, weight_hidden_data, bias_input_data,
                 bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape,
                 weight_hidden_shape, output_input_data, output_hidden_data, fc_output_shape,
                 step_intermediate);
    // Advance to the next time slice; assumes the per-step feature count is
    // input_shape.dims(2) — TODO confirm input layout (time, 1, features).
    step_input += input_shape.dims(2);
    if (intermediate_buffer_size != 0)
    {
      assert(step_intermediate != nullptr);
      step_intermediate += intermediate_buffer_size;
    }
  }
  return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H
Loading

0 comments on commit 28da445

Please sign in to comment.