diff --git a/onert-micro/CMakeLists.txt b/onert-micro/CMakeLists.txt
index d9388173cf4..a43f7e979e7 100644
--- a/onert-micro/CMakeLists.txt
+++ b/onert-micro/CMakeLists.txt
@@ -70,7 +70,7 @@ else ()
   message(STATUS "FOUND FlatBuffers")
 
-  set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.6/circle_schema.fbs")
+  set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.8/circle_schema.fbs")
 
   # NOTE Copy circle_schema.fbs as schema.fbs to generate "schema_generated.fbs" instead of "circle_schema_generated.fbs"
   add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs"
diff --git a/onert-micro/eval-driver/Driver.cpp b/onert-micro/eval-driver/Driver.cpp
index b4570e5776e..9049e9eeb56 100644
--- a/onert-micro/eval-driver/Driver.cpp
+++ b/onert-micro/eval-driver/Driver.cpp
@@ -114,7 +114,7 @@ int entry(int argc, char **argv)
     }
 
     // Do inference.
-    interpreter.run();
+    interpreter.run(config);
   }
 
   // Get output.
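Reviewer note: since `run()` now requires the `OMConfig` that was used at import, every inference caller changes the same way as the driver above. A minimal sketch of the updated flow — only the `importModel`/`run` signatures are taken from this diff; the `allocateInputs()` call is an assumption mirroring `OMRuntimeModule::allocateInputs()`:

```cpp
#include "OMInterpreter.h"

// Sketch of the new inference call sequence (not part of the PR).
onert_micro::OMStatus runOnce(const char *model_data)
{
  onert_micro::OMConfig config{};
  onert_micro::OMInterpreter interpreter;

  interpreter.importModel(model_data, config);
  interpreter.allocateInputs(); // assumption: same helper as the training API
  // ... fill input buffers here ...
  return interpreter.run(config); // run() now takes the config used at import
}
```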
diff --git a/onert-micro/eval-driver/TrainingDriver.cpp b/onert-micro/eval-driver/TrainingDriver.cpp
index 706ad7ecf88..2e3b5c71465 100644
--- a/onert-micro/eval-driver/TrainingDriver.cpp
+++ b/onert-micro/eval-driver/TrainingDriver.cpp
@@ -179,12 +179,10 @@ int entry(int argc, char **argv)
   config.wof_ptr = nullptr;
 
   // Set user defined training settings
-  const uint32_t training_epochs = 30;
+  const uint32_t training_epochs = 50;
   const float lambda = 0.001f;
   const uint32_t BATCH_SIZE = 32;
-  const uint32_t INPUT_SIZE = 180;
-  const uint32_t OUTPUT_SIZE = 4;
-  const uint32_t num_train_layers = 10;
+  const uint32_t num_train_layers = 4;
   const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY;
   const onert_micro::OMTrainOptimizer train_optim = onert_micro::ADAM;
   const float beta = 0.9;
@@ -211,6 +209,9 @@ int entry(int argc, char **argv)
   onert_micro::OMTrainingInterpreter train_interpreter;
   train_interpreter.importTrainModel(circle_model.data(), config);
 
+  const uint32_t OUTPUT_SIZE = train_interpreter.getOutputSizeAt(0);
+  const uint32_t INPUT_SIZE = train_interpreter.getInputSizeAt(0);
+
   // Temporary buffer to read input data from file using BATCH_SIZE
   float training_input[BATCH_SIZE * INPUT_SIZE];
   float training_target[BATCH_SIZE * OUTPUT_SIZE];
@@ -263,11 +264,11 @@ int entry(int argc, char **argv)
     if (CLASSIFICATION_TASK)
     {
       // Evaluate cross_entropy and accuracy metrics
-      train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
-                                       reinterpret_cast<void *>(&cross_entropy_metric),
-                                       cur_batch_size);
-      train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
-                                       cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
+                                       reinterpret_cast<void *>(&cross_entropy_metric),
+                                       cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
+                                       reinterpret_cast<void *>(&accuracy), cur_batch_size);
 
       // Save them into vectors
       accuracy_v.push_back(accuracy);
@@ -276,10 +277,10 @@ int entry(int argc, char **argv)
     else
     {
       // Evaluate mse and mae metrics
-      train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
-                                       cur_batch_size);
-      train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
-                                       cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
+                                       reinterpret_cast<void *>(&mse), cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
+                                       reinterpret_cast<void *>(&mae), cur_batch_size);
 
       // Save them into vectors
       mse_v.push_back(mse);
@@ -335,11 +336,11 @@ int entry(int argc, char **argv)
     if (CLASSIFICATION_TASK)
    {
       // Evaluate cross_entropy and accuracy metrics
-      train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
-                                       reinterpret_cast<void *>(&cross_entropy_metric),
-                                       cur_batch_size);
-      train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
-                                       cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS,
+                                       reinterpret_cast<void *>(&cross_entropy_metric),
+                                       cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::ACCURACY,
+                                       reinterpret_cast<void *>(&accuracy), cur_batch_size);
 
       // Save them into vectors
       accuracy_v.push_back(accuracy);
@@ -348,10 +349,10 @@ int entry(int argc, char **argv)
     else
     {
       // Evaluate mse and mae metrics
-      train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
-                                       cur_batch_size);
-      train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
-                                       cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS,
+                                       reinterpret_cast<void *>(&mse), cur_batch_size);
+      train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS,
+                                       reinterpret_cast<void *>(&mae), cur_batch_size);
 
       // Save them into vectors
       mse_v.push_back(mse);
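Reviewer note: with `INPUT_SIZE`/`OUTPUT_SIZE` now queried from the imported model, `BATCH_SIZE * INPUT_SIZE` is no longer a compile-time constant, so the driver's plain `float` arrays rely on a VLA compiler extension. A hedged sketch of a portable alternative, assuming float inputs and targets:

```cpp
#include <cstdint>
#include <vector>
#include "OMTrainingInterpreter.h"

// Illustrative helper (not part of the PR): size batch buffers from the model.
void makeBatchBuffers(onert_micro::OMTrainingInterpreter &train_interpreter,
                      uint32_t batch_size, std::vector<float> &input_buf,
                      std::vector<float> &target_buf)
{
  input_buf.resize(batch_size * train_interpreter.getInputSizeAt(0));
  target_buf.resize(batch_size * train_interpreter.getOutputSizeAt(0));
}
```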
diff --git a/onert-micro/onert-micro/include/OMInterpreter.h b/onert-micro/onert-micro/include/OMInterpreter.h
index d450c3403e4..cb63cc1bf20 100644
--- a/onert-micro/onert-micro/include/OMInterpreter.h
+++ b/onert-micro/onert-micro/include/OMInterpreter.h
@@ -40,7 +40,7 @@ class OMInterpreter
 
   OMStatus importModel(const char *model_ptr, const OMConfig &config);
 
-  OMStatus run();
+  OMStatus run(const OMConfig &config);
 
   OMStatus reset();
 
diff --git a/onert-micro/onert-micro/include/OMTrainingInterpreter.h b/onert-micro/onert-micro/include/OMTrainingInterpreter.h
index b3dbcd8987b..815908ef70b 100644
--- a/onert-micro/onert-micro/include/OMTrainingInterpreter.h
+++ b/onert-micro/onert-micro/include/OMTrainingInterpreter.h
@@ -68,9 +68,10 @@ class OMTrainingInterpreter
   // Note: calculation will be done on test_size number of test samples
   // Warning: before using evaluateMetric call: 1) importTrainModel; 2) setInput; 3) setTarget
   // Note: number of the samples in data should be equal to the test_size
-  OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size)
+  OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val,
+                          uint32_t test_size)
   {
-    return _training_runtime_module.evaluateMetric(metric, metric_val, test_size);
+    return _training_runtime_module.evaluateMetric(config, metric, metric_val, test_size);
   }
 
   // To get input and output flat size
@@ -86,7 +87,7 @@ class OMTrainingInterpreter
   // Load current status from checkpoint and save it in current model and in current config
   OMStatus loadCheckpoint(OMConfig &config, const char *load_path);
 
-  OMStatus run() { return _training_runtime_module.run(); }
+  OMStatus run(const OMConfig &config) { return _training_runtime_module.run(config); }
 
   OMStatus allocateInputs() { return _training_runtime_module.allocateInputs(); }
 
   void *getInputData(uint32_t position);
diff --git a/onert-micro/onert-micro/include/core/OMRuntimeModule.h b/onert-micro/onert-micro/include/core/OMRuntimeModule.h
index 53ba519a1ef..c5e368d7238 100644
--- a/onert-micro/onert-micro/include/core/OMRuntimeModule.h
+++ b/onert-micro/onert-micro/include/core/OMRuntimeModule.h
@@ -43,7 +43,7 @@ class OMRuntimeModule
   ~OMRuntimeModule() = default;
 
   OMStatus importModel(const char *model_ptr, const OMConfig &config);
-  OMStatus run();
+  OMStatus run(const OMConfig &config);
   OMStatus reset();
 
   uint32_t getNumberOfInputs();
diff --git a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h
index 9b53c3bd888..c3b00ce5450 100644
--- a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h
+++ b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h
@@ -68,7 +68,8 @@ class OMTrainingRuntimeModule : public OMRuntimeModule
   // 2) metric_val should be initialized with some value before calling this method because,
   // after the calculation for the current batch_num (the sequence number of the current sample),
   // this value is added to metric_val
-  OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size);
+  OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val,
+                          uint32_t test_size);
 
   // Set input data for input with input_index
   // Note: number of the samples in data should be equal to the batch_size in config structure
diff --git a/onert-micro/onert-micro/include/core/reader/OMCircleReader.h b/onert-micro/onert-micro/include/core/reader/OMCircleReader.h
index 90c1d8acc47..5d32d516c05 100644
--- a/onert-micro/onert-micro/include/core/reader/OMCircleReader.h
+++ b/onert-micro/onert-micro/include/core/reader/OMCircleReader.h
@@ -55,7 +55,6 @@ class OMCircleReader
   const CircleOperators *operators() const { return _current_subgraph->operators(); }
   const CircleValues *inputs() const { return _current_subgraph->inputs(); }
   const CircleValues *outputs() const { return _current_subgraph->outputs(); }
-  const circle::DataFormat data_format() const { return _current_subgraph->data_format(); }
   const CircleMetadataSet *metadata() const { return _model->metadata(); }
 
   uint32_t num_subgraph() const { return _model->subgraphs()->size(); }
diff --git a/onert-micro/onert-micro/include/execute/OMExecuteArgs.h b/onert-micro/onert-micro/include/execute/OMExecuteArgs.h
index 196f68a87a3..5d9b5f30d6e 100644
--- a/onert-micro/onert-micro/include/execute/OMExecuteArgs.h
+++ b/onert-micro/onert-micro/include/execute/OMExecuteArgs.h
@@ -33,6 +33,8 @@ struct OMExecuteArgs
   core::OMRuntimeContext &runtime_context;
   uint16_t kernel_index;
   core::OMRuntimeModule &runtime_module;
+  uint16_t num_train_layers = 0;
+  bool is_train_mode = false;
 };
 
 } // namespace execute
diff --git a/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h b/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h
index b6f63cdaa48..e33239f7256 100644
--- a/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h
+++ b/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h
@@ -23,7 +23,7 @@
 
 #include <cstdint>
 
-constexpr static uint32_t maxInputSize = 5;
+constexpr static uint32_t maxInputSize = 6;
 constexpr static uint32_t maxOutputSize = 5;
 
 namespace onert_micro
diff --git a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h
index 5f134f0e0df..8379b16c0dc 100644
--- a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h
+++ b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h
@@ -48,7 +48,7 @@ void inline FullyConnectedWeightGrad(
     float cur_dloss_doutput = dloss_doutput_data[o + depth_bounds.first];
     for (uint32_t i = 0; i < accum_depth; ++i)
     {
-      dloss_dweight_data[i + o * accum_depth] = cur_dloss_doutput * input_data[i];
+      dloss_dweight_data[i + o * accum_depth] += cur_dloss_doutput * input_data[i];
     }
   }
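Reviewer note: the `=` → `+=` change makes `FullyConnectedWeightGrad` accumulate rather than overwrite. The recurrent GRU backward pass needs this because the same weight matrices are applied at every time step, so callers must zero the gradient buffer before the first call (see the `std::memset` added in `train/kernels/FullyConnected.cpp` at the end of this diff). In formula form, with layer inputs x_t and output gradients δ_t:

```latex
\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \delta_t \, x_t^{\top},
\qquad \delta_t = \frac{\partial L}{\partial y_t}
```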
diff --git a/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h b/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h
new file mode 100644
index 00000000000..945122c32d7
--- /dev/null
+++ b/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h
@@ -0,0 +1,209 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H
+#define ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H
+
+#include "OMStatus.h"
+#include "core/OMRuntimeShape.h"
+
+#include "PALUtils.h"
+#include "ProcessBroadcastShapes.h"
+#include "PALFullyConnected.h"
+#include "PALLogistic.h"
+
+namespace onert_micro
+{
+namespace execute
+{
+namespace pal
+{
+namespace
+{
+void calculateGRU(const float *input_data, const float *weight_input_data,
+                  const float *weight_hidden_data, const float *bias_input_data,
+                  const float *bias_hidden_data, float *output_data,
+                  const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
+                  const core::OMRuntimeShape &weight_input_shape,
+                  const core::OMRuntimeShape &weight_hidden_shape, float *output_input_data,
+                  float *output_hidden_data, const core::OMRuntimeShape &output_shape_fc,
+                  float *intermediate_buffer)
+{
+  core::FullyConnectedParams op_params{};
+  // As FC nodes don't have any activations inside GRU, let's just use numeric limits
+  op_params.float_activation_min = std::numeric_limits<float>::lowest();
+  op_params.float_activation_max = std::numeric_limits<float>::max();
+
+  // If intermediate_buffer != nullptr, then it is train mode and we need to save intermediate
+  // information
+  bool is_train_mode = intermediate_buffer != nullptr;
+  if (is_train_mode)
+  {
+    // Copy input for FC Input to calculate weights gradients
+    std::memcpy(intermediate_buffer, output_data, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+  }
+  // FC Input
+  FullyConnected(op_params, output_data, weight_input_shape, weight_input_data, bias_input_data,
+                 output_shape_fc, output_input_data);
+
+  // FC Hidden
+  // Note: input for this FC node will be saved without intermediate buffer
+  FullyConnected(op_params, input_data, weight_hidden_shape, weight_hidden_data, bias_hidden_data,
+                 output_shape_fc, output_hidden_data);
+
+  int num_elements = output_shape_fc.dims(1) / 3;
+
+  float *second_hidden_part = output_hidden_data + num_elements;
+  float *second_input_part = output_input_data + num_elements;
+
+  float *third_hidden_part = second_hidden_part + num_elements;
+  float *third_input_part = second_input_part + num_elements;
+
+  // Calculate Left part
+  for (int i = 0; i < num_elements; ++i)
+  {
+    output_input_data[i] += output_hidden_data[i];
+  }
+
+  // If train mode - save logistic input
+  if (is_train_mode)
+  {
+    std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+  }
+  Logistic(num_elements, output_input_data, output_input_data);
+
+  // If train mode - save most left mul input (right input)
+  if (is_train_mode)
+  {
+    std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+  }
+  // Calculate most left mul
+  float *most_left_part_final = output_input_data;
+  float *first_part = output_input_data;
+  for (int i = 0; i < num_elements; ++i)
+  {
+    output_data[i] *= most_left_part_final[i];
+    first_part[i] = 1.0f - first_part[i];
+  }
+
+  // Calc second part
+  for (int i = 0; i < num_elements; ++i)
+  {
+    second_hidden_part[i] += second_input_part[i];
+  }
+  // If train mode - save logistic input
+  if (is_train_mode)
+  {
+    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+  }
+  Logistic(num_elements, second_hidden_part, second_hidden_part);
+
+  // If train mode - save mul input (left and right)
+  if (is_train_mode)
+  {
+    // Left input
+    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+
+    // Right input
+    std::memcpy(intermediate_buffer, third_input_part, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+  }
+  for (int i = 0; i < num_elements; ++i)
+  {
+    second_hidden_part[i] *= third_input_part[i];
+    second_hidden_part[i] += third_hidden_part[i];
+  }
+  // If train mode - save tanh input
+  if (is_train_mode)
+  {
+    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+  }
+  for (int i = 0; i < num_elements; ++i)
+  {
+    second_hidden_part[i] = std::tanh(second_hidden_part[i]);
+  }
+
+  // If train mode - save mul input (left and right)
+  if (is_train_mode)
+  {
+    // Left input
+    std::memcpy(intermediate_buffer, first_part, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+
+    // Right input
+    std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float));
+    // Move intermediate_buffer pointer
+    intermediate_buffer += output_shape.flatSize();
+  }
+  for (int i = 0; i < num_elements; ++i)
+  {
+    second_hidden_part[i] *= first_part[i];
+    output_data[i] += second_hidden_part[i];
+  }
+}
+
+} // namespace
+
+OMStatus GRU(const float *input_data, const float *weight_input_data,
+             const float *weight_hidden_data, const float *bias_input_data,
+             const float *bias_hidden_data, const float *hidden_state_data, float *output_data,
+             float *output_input_data, float *output_hidden_data,
+             const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape,
+             const core::OMRuntimeShape &weight_input_shape,
+             const core::OMRuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size,
+             float *intermediate_buffer)
+{
+  const int32_t time = input_shape.dims(0);
+
+  core::OMRuntimeShape output_shape_fc(2);
+  output_shape_fc.setDim(0, 1);
+  output_shape_fc.setDim(1, weight_hidden_shape.dims(0));
+
+  std::memcpy(output_data, hidden_state_data, output_shape.flatSize() * sizeof(float));
+
+  for (int i = 0; i < time; ++i)
+  {
+    calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data,
+                 bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape,
+                 weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc,
+                 intermediate_buffer);
+    input_data += input_shape.dims(2);
+    if (intermediate_buffer_size != 0)
+    {
+      assert(intermediate_buffer != nullptr);
+      intermediate_buffer += intermediate_buffer_size;
+    }
+  }
+  return Ok;
+}
+
+} // namespace pal
+} // namespace execute
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H
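Reviewer note: as I read `calculateGRU`, writing a = W_i h_{t-1} + b_i for the "FC Input" call and b = W_h x_t + b_h for the "FC Hidden" call, with each 3H-wide FC output split into thirds a^(1..3), b^(1..3), the kernel computes the standard GRU cell:

```latex
z_t = \sigma\!\left(a^{(1)} + b^{(1)}\right), \qquad
r_t = \sigma\!\left(a^{(2)} + b^{(2)}\right),
\qquad
\tilde h_t = \tanh\!\left(r_t \odot a^{(3)} + b^{(3)}\right), \qquad
h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde h_t
```

The nine `intermediate_buffer` snapshots taken in train mode are exactly the operand/pre-activation values needed to replay these equations backwards in `PALGRUWeightGrad.h`.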
diff --git a/onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h b/onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h
new file mode 100644
index 00000000000..82155fa7652
--- /dev/null
+++ b/onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H
+#define ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H
+
+#include "OMStatus.h"
+#include "core/OMRuntimeShape.h"
+#include "core/OMKernelType.h"
+
+#include "PALUtils.h"
+#include "ProcessBroadcastShapes.h"
+#include "PALFullyConnectedWeightGrad.h"
+
+namespace onert_micro
+{
+namespace train
+{
+namespace pal
+{
+namespace
+{
+
+void calculateGRUWeightGrads(
+  const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data,
+  const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data,
+  float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data,
+  const float *input_data, float *input_grad_data, float *state_grad_data,
+  const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_fc_shape,
+  const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape,
+  const core::OMRuntimeShape &weight_hidden_shape, float *output_data, float *left_logistic_data,
+  float *left_mul_data, float *right_logistic_data, const float *right_mul_left_input_data,
+  const float *right_mul_right_input_data, float *tanh_data, const float *middle_mul_left_input,
+  const float *middle_mul_right_input, float *left_fc_output_grad_buffer,
+  float *right_fc_output_grad_buffer)
+{
+  int num_elements = output_shape.flatSize();
+  for (int i = 0; i < num_elements; ++i)
+  {
+    // Middle Mul left input grad
+    float left_middle_mul = output_grad_data[i];
+    left_middle_mul *= middle_mul_right_input[i];
+
+    // Middle Mul right input grad
+    float right_middle_mul = output_grad_data[i];
+    right_middle_mul *= middle_mul_left_input[i];
+
+    // Tanh' = 1 / (cosh(x) ^ 2) = 1 - tanh(x) ^ 2
+    float tanh_grad_value;
+    {
+      float tanh = std::tanh(tanh_data[i]);
+      tanh_grad_value = (1 - tanh * tanh) * right_middle_mul;
+    }
+
+    // Left mul
+    float left_mul_grad_value = output_grad_data[i] * output_data[i];
+
+    // Sub' = -1
+    // Left Logistic: Logistic' = exp(-x) * (1 / (1 + exp(-x))) ^ 2
+    float left_logistic_grad_value;
+    {
+      float log_value = (1 / (1 + std::exp(-left_logistic_data[i])));
+      left_logistic_grad_value =
+        log_value * (1 - log_value) * (left_middle_mul + left_mul_grad_value);
+    }
+
+    // Right mul left input
+    float right_mul_left_input = tanh_grad_value;
+    right_mul_left_input *= right_mul_right_input_data[i];
+
+    // Right mul right input
+    float right_mul_right_input = tanh_grad_value;
+    right_mul_right_input *= right_mul_left_input_data[i];
+
+    // Right logistic
+    float right_logistic_grad_value;
+    {
+      float log_value = (1 / (1 + std::exp(-right_logistic_data[i])));
+      right_logistic_grad_value = log_value * (1 - log_value) * right_mul_left_input;
+    }
+
+    // Left concatenation
+    left_fc_output_grad_buffer[i] = left_logistic_grad_value;
+    left_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value;
+    left_fc_output_grad_buffer[i + 2 * num_elements] = right_mul_right_input;
+
+    // Right concatenation
+    right_fc_output_grad_buffer[i] = left_logistic_grad_value;
+    right_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value;
+    right_fc_output_grad_buffer[i + 2 * num_elements] = tanh_grad_value;
+  }
+
+  // Left fc weight grad
+  FullyConnectedWeightGrad(left_fc_output_grad_buffer, output_fc_shape, output_data, output_shape,
+                           weight_input_grad_data, weight_input_shape,
+                           core::OpTrainableRankType::ALL);
+  // Right fc weight grad
+  FullyConnectedWeightGrad(right_fc_output_grad_buffer, output_fc_shape, input_data, input_shape,
+                           weight_hidden_grad_data, weight_hidden_shape,
+                           core::OpTrainableRankType::ALL);
+
+  // Set state grad to zero
+  std::memset(state_grad_data, 0, output_shape.flatSize() * sizeof(float));
+  // TODO: calculate input grads
+}
+
+} // namespace
+
+OMStatus GRUGrads(const float *output_grad_data, const float *weight_input_data,
+                  float *weight_input_grad_data, const float *weight_hidden_data,
+                  float *weight_hidden_grad_data, const float *bias_input_data,
+                  float *bias_input_grad_data, const float *bias_hidden_data,
+                  float *bias_hidden_grad_data, const float *input_data, float *input_grad_data,
+                  float *state_grad_data, const core::OMRuntimeShape &input_shape,
+                  const core::OMRuntimeShape &output_shape,
+                  const core::OMRuntimeShape &weight_input_shape,
+                  const core::OMRuntimeShape &weight_hidden_shape,
+                  const core::OMRuntimeShape &output_shape_fc, float *intermediate_buffer,
+                  float *left_fc_output_grad_buffer, float *right_fc_output_grad_buffer)
+{
+  const int32_t time = input_shape.dims(0);
+
+  // Init pointers to intermediate values
+  size_t offset = output_shape.flatSize();
+
+  size_t data_type_size = sizeof(float);
+  const int32_t num_of_intermediate_tensors = 9;
+  size_t time_offset = num_of_intermediate_tensors * offset;
+
+  core::OMRuntimeShape two_dim_input_shape(2);
+  auto dim_count = input_shape.dimensionsCount();
+  if (dim_count < 2)
+    return UnsupportedType;
+  two_dim_input_shape.setDim(0, input_shape.dims(dim_count - 2));
+  two_dim_input_shape.setDim(1, input_shape.dims(dim_count - 1));
+
+  core::OMRuntimeShape two_dim_output_shape(2);
+  dim_count = output_shape.dimensionsCount();
+  if (dim_count < 2)
+    return UnsupportedType;
+  two_dim_output_shape.setDim(0, output_shape.dims(dim_count - 2));
+  two_dim_output_shape.setDim(1, output_shape.dims(dim_count - 1));
+
+  std::memset(weight_input_grad_data, 0, output_shape.flatSize() * sizeof(float) * time);
+  std::memset(weight_hidden_grad_data, 0, input_shape.dims(2) * sizeof(float) * time);
+
+  for (int i = 0; i < time; ++i)
+  {
+    float *output_data = intermediate_buffer;
+    float *left_logistic_data = output_data + offset;
+    float *left_mul_data = left_logistic_data + offset;
+    float *right_logistic_data = left_mul_data + offset;
+    float *right_mul_left_input_data = right_logistic_data + offset;
+    float *right_mul_right_input_data = right_mul_left_input_data + offset;
+    float *tanh_data = right_mul_right_input_data + offset;
+    float *middle_mul_left_input = tanh_data + offset;
+    float *middle_mul_right_input = middle_mul_left_input + offset;
+
+    calculateGRUWeightGrads(
+      output_grad_data, weight_input_data, weight_input_grad_data, weight_hidden_data,
+      weight_hidden_grad_data, bias_input_data, bias_input_grad_data, bias_hidden_data,
+      bias_hidden_grad_data, input_data, input_grad_data, state_grad_data, two_dim_input_shape,
+      output_shape_fc, two_dim_output_shape, weight_input_shape, weight_hidden_shape, output_data,
+      left_logistic_data, left_mul_data, right_logistic_data, right_mul_left_input_data,
+      right_mul_right_input_data, tanh_data, middle_mul_left_input, middle_mul_right_input,
+      left_fc_output_grad_buffer, right_fc_output_grad_buffer);
+    input_data += input_shape.dims(2);
+    intermediate_buffer += time_offset;
+  }
+  return Ok;
+}
+
+} // namespace pal
+} // namespace train
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H
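Reviewer note: the local derivatives used by `calculateGRUWeightGrads` (the saved forward values are the pre-activations x here), matching the corrected comments above:

```latex
\sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr) = \frac{e^{-x}}{(1 + e^{-x})^2},
\qquad
\tanh'(x) = \frac{1}{\cosh^2(x)} = 1 - \tanh^2(x)
```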
diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst
index 2560554777a..836c2a246c8 100644
--- a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst
+++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst
@@ -23,6 +23,7 @@ REGISTER_KERNEL(GATHER_ND, GatherND)
 REGISTER_KERNEL(EXP, Exp)
 REGISTER_KERNEL(GREATER, Greater)
 REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)
+REGISTER_KERNEL(GRU, GRU)
 REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)
 REGISTER_KERNEL(ELU, Elu)
 REGISTER_KERNEL(EQUAL, Equal)
diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst
index 26476df574f..508cd246dbe 100644
--- a/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst
+++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst
@@ -3,3 +3,5 @@ REGISTER_TRAIN_KERNEL(SOFTMAX, Softmax)
 REGISTER_TRAIN_KERNEL(RESHAPE, Reshape)
 REGISTER_TRAIN_KERNEL(CONV_2D, Conv2D)
 REGISTER_TRAIN_KERNEL(MAX_POOL_2D, MaxPool2D)
+REGISTER_TRAIN_KERNEL(GRU, GRU)
+REGISTER_TRAIN_KERNEL(STRIDED_SLICE, StridedSlice)
diff --git a/onert-micro/onert-micro/include/pal/mcu/PALGRU.h b/onert-micro/onert-micro/include/pal/mcu/PALGRU.h
new file mode 100644
index 00000000000..75389fe5f7e
--- /dev/null
+++ b/onert-micro/onert-micro/include/pal/mcu/PALGRU.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_H
+#define ONERT_MICRO_EXECUTE_PAL_GRU_H
+
+#include "PALGRUCommon.h"
+
+#endif // ONERT_MICRO_EXECUTE_PAL_GRU_H
diff --git a/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h b/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h
new file mode 100644
index 00000000000..fa49b29c1d0
--- /dev/null
+++ b/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H
+#define ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H
+
+#include "TestDataGRUBase.h"
+
+namespace onert_micro
+{
+namespace test_model
+{
+
+namespace gru_float
+{
+/*
+ * GRU Kernel:
+ *
+ * Input(1, 1, 6)
+ *       |
+ *      GRU
+ *       |
+ * Output(1, 1, 5)
+ */
+unsigned char test_kernel_model_circle[] = {
+  0x1c, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x12, 0x00, 0x18, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x12, 0x00, 0x00, 0x00,
+  0x54, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+  0x30, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0xb4, 0x06, 0x00, 0x00, 0x98, 0x06, 0x00, 0x00,
+  0xe0, 0x04, 0x00, 0x00, 0x70, 0x03, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0xd8, 0x02, 0x00, 0x00,
+  0x90, 0x02, 0x00, 0x00, 0x38, 0x02, 0x00, 0x00, 0xe8, 0x01, 0x00, 0x00, 0xa8, 0x01, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x6e, 0x6e, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67,
+  0x65, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
+  0xf4, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0xfb, 0xfb, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x0c, 0x00,
+  0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2d,
+  0x2d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
+  0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00,
+  0x3c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00,
+  0x02, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00,
+  0xac, 0x05, 0x00, 0x00, 0xf4, 0x03, 0x00, 0x00, 0x88, 0x02, 0x00, 0x00, 0x2c, 0x02, 0x00, 0x00,
+  0xf0, 0x01, 0x00, 0x00, 0xa4, 0x01, 0x00, 0x00, 0x48, 0x01, 0x00, 0x00, 0xf8, 0x00, 0x00, 0x00,
+  0xb4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x1a, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00,
+  0x07, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x01, 0x00, 0x00, 0x00,
+  0x34, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
+  0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0e, 0x00, 0x00, 0x00,
+  0x06, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
+  0x06, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
+  0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc4, 0xfa, 0xff, 0xff,
+  0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
+  0x03, 0x00, 0x00, 0x00, 0x34, 0xfb, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x53, 0x74, 0x61, 0x74, 0x65, 0x66, 0x75, 0x6c,
+  0x50, 0x61, 0x72, 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x65, 0x64, 0x43, 0x61, 0x6c, 0x6c, 0x3a,
+  0x30, 0x00, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x05, 0x00, 0x00, 0x00, 0x68, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x73, 0x74, 0x72, 0x69,
+  0x64, 0x65, 0x64, 0x5f, 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x32, 0x00, 0x00, 0x00, 0x00,
+  0x26, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+  0xb4, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x73, 0x74, 0x72, 0x69, 0x64, 0x65, 0x64, 0x5f,
+  0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0x72, 0xfd, 0xff, 0xff,
+  0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00,
+  0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
+  0x38, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00,
+  0x73, 0x74, 0x72, 0x69, 0x64, 0x65, 0x64, 0x5f, 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x00,
+  0xc6, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
+  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+  0x60, 0xfc, 0xff, 0xff, 0x24, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x0e, 0x00, 0x00, 0x00, 0x46, 0x75, 0x73, 0x65, 0x64, 0x43, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x47,
+  0x52, 0x55, 0x00, 0x00, 0x3c, 0xfc, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x98, 0xfc, 0xff, 0xff, 0x48, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75,
+  0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x2f, 0x67, 0x72, 0x75, 0x2f, 0x7a, 0x65, 0x72, 0x6f, 0x73,
+  0x00, 0x00, 0x00, 0x00, 0x4a, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
+  0xf0, 0xfc, 0xff, 0xff, 0x58, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x0f, 0x00, 0x00, 0x00, 0x77, 0x68, 0x69, 0x6c, 0x65, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c,
+  0x5f, 0x31, 0x31, 0x00, 0x9a, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x2c, 0x01, 0x00, 0x00,
+  0xc0, 0xfb, 0x12, 0xbd, 0x4b, 0xb1, 0x0c, 0x3f, 0x51, 0xbe, 0xa0, 0x3d, 0xdb, 0xcd, 0xca, 0xbe,
+  0x77, 0xa7, 0x8d, 0x3e, 0xd8, 0x24, 0xe8, 0x3e, 0xc6, 0xe3, 0xfe, 0x3d, 0xa8, 0x41, 0xf0, 0xbd,
+  0x9e, 0x70, 0xf3, 0xbd, 0x50, 0xfc, 0x4b, 0x3e, 0x7f, 0x8b, 0xf0, 0x3d, 0xae, 0xc0, 0x83, 0x3d,
+  0xe4, 0xf0, 0x98, 0xbe, 0xd4, 0xd0, 0x7f, 0xbe, 0x80, 0xca, 0x98, 0x39, 0xe6, 0x2c, 0x08, 0xbe,
+  0x61, 0x44, 0xdf, 0xbd, 0x67, 0x32, 0x32, 0xbe, 0x6a, 0x61, 0xdf, 0x3e, 0xc3, 0x0c, 0x55, 0x3e,
+  0x6c, 0x28, 0x0e, 0xbf, 0xb6, 0x52, 0xf1, 0x3d, 0xb7, 0xd1, 0x3f, 0xbd, 0xa6, 0xf0, 0x9d, 0xbe,
+  0xa0, 0xdd, 0xb1, 0x3e, 0xa3, 0x7d, 0x50, 0xbd, 0x3e, 0xd7, 0xe6, 0x3e, 0xe4, 0xb0, 0xe6, 0x3d,
+  0x2a, 0xd6, 0xeb, 0x3e, 0xa8, 0xc8, 0x49, 0xbb, 0xdd, 0xdc, 0x6b, 0xbe, 0x66, 0x48, 0xc1, 0x3d,
+  0x26, 0x6e, 0x52, 0x3e, 0xfc, 0xd6, 0x64, 0x3d, 0x4f, 0x1d, 0x1f, 0xbf, 0x5f, 0xf0, 0x9e, 0x3e,
+  0xe0, 0x6e, 0xad, 0x3c, 0x48, 0x37, 0xe7, 0xbd, 0x36, 0xea, 0x0b, 0xbe, 0x3b, 0x81, 0xf2, 0xbd,
+  0x52, 0xe1, 0x56, 0xbc, 0x75, 0x2e, 0xa3, 0xbd, 0x8c, 0x71, 0xc5, 0x3d, 0xf0, 0xaf, 0x0b, 0x3e,
+  0x6b, 0x7d, 0xba, 0x3e, 0x4e, 0xbd, 0x93, 0xbe, 0xb3, 0x5c, 0x9c, 0xbe, 0x3c, 0xe2, 0xf3, 0x3c,
+  0x39, 0xf1, 0xa0, 0x3d, 0xa0, 0x35, 0x50, 0x3e, 0xfa, 0x87, 0x0e, 0xbe, 0x76, 0xc2, 0x12, 0xbd,
+  0x2a, 0xd6, 0x01, 0x3f, 0xa0, 0x77, 0xd0, 0x3c, 0x5a, 0x1f, 0x26, 0x3e, 0x02, 0x59, 0x0b, 0x3e,
+  0xef, 0x6c, 0x41, 0xbe, 0x6e, 0x40, 0x4a, 0xbd, 0x2f, 0x89, 0x33, 0x3e, 0x50, 0x54, 0x8a, 0x3e,
+  0x4d, 0xbb, 0x9f, 0xbe, 0xfd, 0x54, 0xb3, 0x3e, 0xc8, 0x5b, 0x66, 0xbe, 0xf0, 0xb0, 0x44, 0x3d,
+  0x8a, 0x4d, 0x14, 0xbe, 0x9d, 0xf7, 0xd4, 0xbd, 0x38, 0xec, 0xc7, 0xbe, 0xb2, 0x79, 0x76, 0x3e,
+  0xb2, 0xc2, 0xdd, 0xbe, 0x44, 0xd9, 0x05, 0xbe, 0x59, 0x34, 0x89, 0x3e, 0x71, 0xf8, 0x2b, 0x3e,
+  0x1d, 0x62, 0x24, 0x3f, 0x40, 0xe6, 0x02, 0x3d, 0xef, 0x03, 0xb8, 0x3d, 0x02, 0x00, 0x00, 0x00,
+  0x0f, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x58, 0xfe, 0xff, 0xff, 0x98, 0x01, 0x00, 0x00,
+  0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x77, 0x68, 0x69, 0x6c,
+  0x65, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00,
+  0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x68, 0x01, 0x00, 0x00, 0x9c, 0xb9, 0x96, 0xbe,
+  0x30, 0x84, 0xdc, 0x3e, 0xb8, 0xe2, 0xc0, 0x3d, 0x00, 0xce, 0xf1, 0x3a, 0x28, 0xd6, 0xfb, 0x3d,
+  0x49, 0x84, 0x95, 0xbe, 0xcc, 0x0d, 0x52, 0x3e, 0x7c, 0x4e, 0x6e, 0xbe, 0xde, 0xda, 0x4c, 0xbe,
+  0x84, 0x5e, 0xda, 0x3e, 0x46, 0x2b, 0xd1, 0x3e, 0x78, 0xc8, 0x71, 0xbe, 0x00, 0xfd, 0x53, 0x3d,
+  0x28, 0x4e, 0x91, 0x3e, 0x00, 0x46, 0xd6, 0xba, 0x20, 0x06, 0x97, 0xbe, 0xf4, 0x04, 0xdc, 0xbe,
+  0xde, 0xf8, 0x05, 0xbf, 0x62, 0x20, 0x1d, 0xbe, 0x28, 0x28, 0xf9, 0x3d, 0xc6, 0xa0, 0x86, 0xbe,
+  0xa2, 0x2f, 0x7f, 0xbe, 0xa0, 0xa1, 0x1d, 0xbd, 0x3c, 0x03, 0xb2, 0x3e, 0xe6, 0xe6, 0x7c, 0xbe,
+  0x2e, 0x37, 0xbe, 0xbe, 0x84, 0xb2, 0x86, 0xbd, 0x10, 0x19, 0x56, 0x3e, 0x59, 0x86, 0x01, 0x3f,
+  0xfc, 0x54, 0x15, 0x3e, 0xc3, 0xbd, 0x07, 0x3f, 0xa0, 0xcb, 0x5f, 0x3e, 0x6c, 0x19, 0xbb, 0x3e,
+  0x9c, 0x98, 0x24, 0xbe, 0x40, 0x57, 0xd1, 0xbc, 0xb0, 0x9c, 0xec, 0xbd, 0x90, 0x19, 0xb4, 0x3d,
+  0x59, 0x11, 0xe7, 0xbe, 0x04, 0x11, 0xd7, 0xbd, 0x6a, 0xd8, 0x46, 0xbe, 0xb9, 0xf2, 0x01, 0xbf,
+  0x40, 0xe0, 0x2e, 0xbd, 0x9e, 0xe6, 0x9a, 0x3e, 0xa0, 0x27, 0xda, 0xbe, 0x39, 0xe9, 0x04, 0x3f,
+  0x5c, 0x2f, 0x2d, 0x3e, 0x18, 0x35, 0x95, 0x3e, 0x5c, 0x67, 0x14, 0x3e, 0xd0, 0xb1, 0x92, 0xbd,
+  0xa8, 0x99, 0xe2, 0xbd, 0x00, 0x1e, 0x0e, 0x3e, 0x80, 0x85, 0x7a, 0x3c, 0x88, 0xde, 0xde, 0x3e,
+  0x0a, 0x10, 0xc9, 0x3e, 0x28, 0x29, 0x3c, 0xbd, 0xbe, 0x3a, 0xfd, 0x3e, 0x36, 0x76, 0xef, 0xbe,
+  0x6e, 0x44, 0xb4, 0x3e, 0xdc, 0xd6, 0x9c, 0xbd, 0xf0, 0xed, 0x9a, 0x3e, 0x90, 0x9c, 0x6b, 0x3d,
+  0x0c, 0xc3, 0x32, 0x3e, 0x8a, 0x27, 0x1f, 0xbe, 0x00, 0x64, 0x5f, 0x3a, 0x8e, 0x71, 0xcc, 0x3e,
+  0xcf, 0xe7, 0xe1, 0xbe, 0xc6, 0x65, 0xb4, 0x3e, 0xa4, 0x65, 0x6d, 0x3e, 0x31, 0xd8, 0x03, 0x3f,
+  0x2c, 0x2a, 0xa8, 0xbd, 0x38, 0x1b, 0xac, 0x3e, 0x60, 0xcc, 0x64, 0x3e, 0x18, 0x4c, 0x0e, 0xbd,
+  0x82, 0x5e, 0xa2, 0x3e, 0xde, 0x70, 0xb0, 0xbe, 0x46, 0x07, 0xe6, 0xbe, 0xf6, 0x4a, 0xa8, 0xbe,
+  0x90, 0xfa, 0x3f, 0x3e, 0x5c, 0x9a, 0xe9, 0xbe, 0x63, 0x1e, 0xd3, 0xbe, 0x20, 0x74, 0x7e, 0x3d,
+  0x20, 0x9c, 0x02, 0xbf, 0xf7, 0x65, 0x02, 0x3f, 0xb6, 0x45, 0xdf, 0x3e, 0x4e, 0xc2, 0x48, 0xbe,
+  0xe3, 0x90, 0xa9, 0xbe, 0xc8, 0x36, 0xab, 0x3d, 0xca, 0xc0, 0x22, 0xbe, 0xec, 0x99, 0x26, 0x3e,
+  0xd0, 0x91, 0x35, 0x3e, 0x02, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x10, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00,
+  0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00,
+  0x73, 0x65, 0x72, 0x76, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f,
+  0x78, 0x3a, 0x30, 0x00, 0xec, 0xff, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00};
+
+const std::vector<float> input_data = {7.899295, -4.584313, -2.9251342,
+                                       -2.1820352, -10.649105, 1.3530581};
+
+const std::vector<float> reference_output_data = {-0.9979859, -0.90550894, 0.025957875,
+                                                  -0.39570245, -0.8868108};
+
+} // namespace gru_float
+
+class TestDataFloatGRU : public TestDataGRUBase<float>
+{
+public:
+  TestDataFloatGRU()
+  {
+    _input_data = gru_float::input_data;
+    _reference_output_data = gru_float::reference_output_data;
+    _test_kernel_model_circle = gru_float::test_kernel_model_circle;
+  }
+
+  ~TestDataFloatGRU() override = default;
+};
+
+} // namespace test_model
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H
diff --git a/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h b/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h
new file mode 100644
index 00000000000..5c3da425d4a
--- /dev/null
+++ b/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H
+#define ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H
+
+#include "test_models/TestDataBase.h"
+
+namespace onert_micro
+{
+namespace test_model
+{
+
+template <typename T> class TestDataGRUBase : public TestDataBase<T>
+{
+public:
+  TestDataGRUBase() = default;
+
+  const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }
+
+  const std::vector<T> &get_input_data_by_index(int i) override final
+  {
+    switch (i)
+    {
+      case 0:
+        return _input_data;
+      default:
+        assert(false && "Wrong input index");
+    }
+  }
+
+  const std::vector<T> &get_output_data_by_index(int i) override final
+  {
+    assert(i == 0);
+    return _reference_output_data;
+  }
+
+protected:
+  std::vector<T> _input_data;
+  std::vector<T> _reference_output_data;
+  const unsigned char *_test_kernel_model_circle;
+};
+
+} // namespace test_model
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H
diff --git a/onert-micro/onert-micro/src/OMInterpreter.cpp b/onert-micro/onert-micro/src/OMInterpreter.cpp
index 1e7ee420d4a..bb7076f6416 100644
--- a/onert-micro/onert-micro/src/OMInterpreter.cpp
+++ b/onert-micro/onert-micro/src/OMInterpreter.cpp
@@ -27,7 +27,7 @@ OMStatus OMInterpreter::importModel(const char *model_ptr, const OMConfig &confi
   return _runtime_module.importModel(model_ptr, config);
 }
 
-OMStatus OMInterpreter::run() { return _runtime_module.run(); }
+OMStatus OMInterpreter::run(const OMConfig &config) { return _runtime_module.run(config); }
 
 OMStatus OMInterpreter::reset() { return _runtime_module.reset(); }
 
diff --git a/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp b/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp
index 857d7413240..f1b1a5b0a1e 100644
--- a/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp
+++ b/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp
@@ -148,7 +148,7 @@ OMStatus OMRuntimeModule::allocateInputs()
   return _graphs.at(0).allocateGraphInputs();
 }
 
-OMStatus OMRuntimeModule::run()
+OMStatus OMRuntimeModule::run(const OMConfig &config)
 {
   OMStatus status = Ok;
 
@@ -158,7 +158,11 @@ OMStatus OMRuntimeModule::run()
   core::OMRuntimeGraph &main_graph = _graphs.at(0);
 
   execute::OMExecuteArgs execute_args = {main_graph.getRuntimeStorage(),
-                                         main_graph.getRuntimeContext(), 0, *this};
+                                         main_graph.getRuntimeContext(),
+                                         0,
+                                         *this,
+                                         config.training_context.num_of_train_layers,
+                                         config.train_mode};
 
   status = execute::OMKernelExecute::runForward(execute_args, main_graph.getRuntimeAllocator());
   if (status != Ok)
diff --git a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp
index 87e3c27aef1..c3492956539 100644
--- a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp
+++ b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp
@@ -126,7 +126,7 @@ OMStatus OMTrainingRuntimeModule::trainSingleStep(OMConfig &config)
 
   // b. Run forward graph
   {
-    status = run();
+    status = run(config);
     assert(status == Ok);
     if (status != Ok)
       return status;
@@ -203,8 +203,8 @@ OMStatus OMTrainingRuntimeModule::trainSingleStep(OMConfig &config)
 * after calculation for current batch_num (the sequence number of the current sample)
 * this value is added to metric_val
 */
-OMStatus OMTrainingRuntimeModule::evaluateMetric(OMMetrics metric, void *metric_val,
-                                                 uint32_t test_size)
+OMStatus OMTrainingRuntimeModule::evaluateMetric(const OMConfig &config, OMMetrics metric,
+                                                 void *metric_val, uint32_t test_size)
 {
   OMStatus status = Ok;
   OMRuntimeGraph &forward_main_graph = _graphs.at(0);
@@ -234,7 +234,7 @@ OMStatus OMTrainingRuntimeModule::evaluateMetric(OMMetrics metric, void *metric_
 
   // b. Run forward graph
   {
-    status = run();
+    status = run(config);
     assert(status == Ok);
     if (status != Ok)
       return status;
diff --git a/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp b/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp
index db61df5a2b1..35c130e74ba 100644
--- a/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp
+++ b/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp
@@ -49,6 +49,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode)
   {
     case circle::BuiltinOperator_FULLY_CONNECTED:
     case circle::BuiltinOperator_CONV_2D:
+    case circle::BuiltinOperator_GRU:
      return true;
    default:
      return false;
diff --git a/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp b/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp
index ca32c5eeec6..6d3f781d851 100644
--- a/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp
+++ b/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp
@@ -34,6 +34,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode)
   {
     case circle::BuiltinOperator_FULLY_CONNECTED:
     case circle::BuiltinOperator_CONV_2D:
+    case circle::BuiltinOperator_GRU:
       return true;
     default:
       return false;
diff --git a/onert-micro/onert-micro/src/execute/kernels/GRU.cpp b/onert-micro/onert-micro/src/execute/kernels/GRU.cpp
new file mode 100644
index 00000000000..0a86e6c188e
--- /dev/null
+++ b/onert-micro/onert-micro/src/execute/kernels/GRU.cpp
@@ -0,0 +1,207 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include "OMStatus.h"
+
+#include "core/OMUtils.h"
+#include "core/OMKernelData.h"
+#include "core/memory/OMMemoryManager.h"
+
+#include "execute/OMKernelExecutionBuilder.h"
+#include "execute/OMUtils.h"
+#include "execute/OMRuntimeKernel.h"
+
+#include "PALGRU.h"
+
+using namespace onert_micro;
+using namespace onert_micro::core;
+using namespace onert_micro::execute;
+
+namespace
+{
+
+constexpr uint32_t inputTensorIdx = 0;
+constexpr uint32_t hiddenHiddenTensorIdx = 1;
+constexpr uint32_t hiddenHiddenBiasTensorIdx = 2;
+constexpr uint32_t hiddenInputTensorIdx = 3;
+constexpr uint32_t hiddenInputBiasTensorIdx = 4;
+constexpr uint32_t stateTensorIdx = 5;
+
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+// NOTE: doesn't currently support dynamic shapes
+OMStatus onert_micro::execute::execute_kernel_CircleGRU(const OMExecuteArgs &execute_args)
+{
+  core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
+  core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
+  uint16_t op_index = execute_args.kernel_index;
+
+  const circle::Tensor *input;
+  const circle::Tensor *hidden_hidden;
+  const circle::Tensor *hidden_hidden_bias;
+  const circle::Tensor *hidden_input;
+  const circle::Tensor *hidden_input_bias;
+  const circle::Tensor *state;
+
+  const circle::Tensor *output;
+
+  uint8_t *input_data;
+  uint8_t *hidden_hidden_data;
+  uint8_t *hidden_hidden_bias_data;
+  uint8_t *hidden_input_data;
+  uint8_t *hidden_input_bias_data;
+  uint8_t *state_data;
+  uint8_t *output_data;
+
+  uint16_t state_tensor_index = 0;
+
+  // Read kernel
+  {
+    execute::OMRuntimeKernel runtime_kernel;
+    runtime_kernel.readKernel(op_index, runtime_context);
+
+    input = runtime_kernel.inputs[inputTensorIdx];
+    hidden_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx];
+    hidden_hidden_bias = runtime_kernel.inputs[hiddenHiddenBiasTensorIdx];
+    hidden_input = runtime_kernel.inputs[hiddenInputTensorIdx];
+    hidden_input_bias = runtime_kernel.inputs[hiddenInputBiasTensorIdx];
+    state = runtime_kernel.inputs[stateTensorIdx];
+
+    output = runtime_kernel.outputs[outputTensorIdx];
+    assert(input != nullptr);
+    assert(hidden_hidden != nullptr);
+    assert(hidden_input != nullptr);
+    assert(state != nullptr);
+    // Biases can be nullptr
+    assert(output != nullptr);
+
+    runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
+
+    input_data = runtime_kernel.inputs_data[inputTensorIdx];
+    hidden_hidden_data = runtime_kernel.inputs_data[hiddenHiddenTensorIdx];
+    hidden_hidden_bias_data = runtime_kernel.inputs_data[hiddenHiddenBiasTensorIdx];
+    hidden_input_data = runtime_kernel.inputs_data[hiddenInputTensorIdx];
+    hidden_input_bias_data = runtime_kernel.inputs_data[hiddenInputBiasTensorIdx];
+    state_data = runtime_kernel.inputs_data[stateTensorIdx];
+
+    output_data = runtime_kernel.outputs_data[outputTensorIdx];
+    assert(input_data != nullptr);
+    assert(hidden_hidden_data != nullptr);
+    assert(hidden_input_data != nullptr);
+    assert(state_data != nullptr);
+    // Bias can be nullptr
+    assert(output_data != nullptr);
+
+    state_tensor_index = runtime_kernel.inputs_index[stateTensorIdx];
+  }
+
+  OMStatus status;
+
+  uint8_t *output_hidden_data;
+  uint8_t *output_input_data;
+
+  status =
+    core::memory::OMMemoryManager::allocateMemory(core::OMRuntimeShape(hidden_hidden).flatSize() *
+                                                    sizeof(core::OMDataType(hidden_hidden->type())),
+                                                  &output_hidden_data);
+  if (status != Ok)
+    return status;
+  status = core::memory::OMMemoryManager::allocateMemory(
+    core::OMRuntimeShape(hidden_input).flatSize() * sizeof(core::OMDataType(hidden_input->type())),
+    &output_input_data);
+  if (status != Ok)
+    return status;
+
+  // If train mode, need to allocate memory for internal intermediate tensors to calculate
+  // gradients further
+  // Number of intermediate tensors
+  const int32_t num_of_intermediate_tensors = 9;
+  // Note: size of the intermediate is equal to output size (should be checked during import phase)
+  const int32_t size_of_intermediate_tensors = core::OMRuntimeShape(output).flatSize();
+  assert(size_of_intermediate_tensors > 0);
+  if (size_of_intermediate_tensors == 0)
+    return UnknownError;
+
+  const int32_t input_size = core::OMRuntimeShape(input).flatSize();
+  const int32_t output_size = size_of_intermediate_tensors;
+
+  // Allocate buffer with following schema:
+  // times * [output_size * sizeof(data_type),
+  //          num_of_intermediate_tensors * size_of_intermediate_tensors * sizeof(data_type)]
+  // Note: need to save all necessary intermediate data to calculate gradients
+  // Deallocation should be performed by the train/GRU kernel
+  const size_t data_type_size = sizeof(core::OMDataType(input->type()));
+  const int32_t time = OMRuntimeShape(input).dims(0);
+  size_t intermediate_buffer_size = 0;
+  uint8_t *intermediate_buffer = nullptr;
+  if (execute_args.is_train_mode)
+  {
+    const auto num_operators = runtime_context.getCircleOperators()->size();
+
+    uint32_t num_train_layers =
+      execute_args.num_train_layers == 0 ? num_operators : execute_args.num_train_layers;
+    uint32_t last_node_pos = std::min(num_operators, num_train_layers);
+    uint32_t last_train_op_index = num_operators - last_node_pos;
+
+    if (execute_args.kernel_index >= last_train_op_index)
+    {
+      intermediate_buffer_size = num_of_intermediate_tensors * size_of_intermediate_tensors;
+
+      status = core::memory::OMMemoryManager::allocateMemory(
+        time * intermediate_buffer_size * data_type_size, &intermediate_buffer);
+      if (status != Ok)
+        return status;
+
+      // Save its buffer to state tensor index
+      runtime_storage.saveDataToTensorIndex(intermediate_buffer, state_tensor_index);
+    }
+  }
+
+  switch (input->type())
+  {
+#ifndef DIS_FLOAT
+    case circle::TensorType_FLOAT32:
+    {
+      status =
+        pal::GRU(core::utils::castInputData<float>(input_data),
+                 core::utils::castInputData<float>(hidden_input_data),
+                 core::utils::castInputData<float>(hidden_hidden_data),
+                 core::utils::castInputData<float>(hidden_input_bias_data),
+                 core::utils::castInputData<float>(hidden_hidden_bias_data),
+                 core::utils::castInputData<float>(state_data),
+                 core::utils::castOutputData<float>(output_data),
+                 core::utils::castOutputData<float>(output_input_data),
+                 core::utils::castOutputData<float>(output_hidden_data),
+                 core::OMRuntimeShape(input), core::OMRuntimeShape(output),
+                 core::OMRuntimeShape(hidden_input), core::OMRuntimeShape(hidden_hidden),
+                 intermediate_buffer_size, core::utils::castOutputData<float>(intermediate_buffer));
+    }
+    break;
+#endif // DIS_FLOAT
+    default:
+    {
+      status = UnsupportedType;
+      assert(false && "Unsupported type.");
+    }
+  }
+
+  core::memory::OMMemoryManager::deallocateMemory(output_input_data);
+  core::memory::OMMemoryManager::deallocateMemory(output_hidden_data);
+
+  return status;
+}
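Reviewer note: the forward kernel and `PALGRUWeightGrad.h` share an implicit layout contract for `intermediate_buffer`: per time step, nine slices of `output` flat size each, written in the order output, left-logistic input, left-mul input, right-logistic input, right-mul left/right inputs, tanh input, middle-mul left/right inputs. A sketch of the indexing, with an illustrative helper name not taken from the PR:

```cpp
#include <cstdint>

// Illustrative (not part of the PR): pointer to intermediate slice
// `tensor_idx` (0..8) saved at time step `step`.
inline float *intermediateSlice(float *intermediate_buffer, int32_t output_size,
                                int32_t step, int32_t tensor_idx)
{
  const int32_t num_of_intermediate_tensors = 9;
  const int32_t time_offset = num_of_intermediate_tensors * output_size;
  return intermediate_buffer + step * time_offset + tensor_idx * output_size;
}
```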
diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp
new file mode 100644
index 00000000000..d9d49621947
--- /dev/null
+++ b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "execute/OMTestUtils.h"
+#include "test_models/gru/FloatGRUKernel.h"
+
+namespace onert_micro
+{
+namespace execute
+{
+namespace testing
+{
+
+using namespace testing;
+
+class GRUTest : public ::testing::Test
+{
+  // Do nothing
+};
+
+TEST_F(GRUTest, Float_P)
+{
+  onert_micro::test_model::TestDataFloatGRU test_data_kernel;
+  std::vector<float> output_data_vector =
+    onert_micro::execute::testing::checkKernel<float>(1, &test_data_kernel);
+  EXPECT_THAT(output_data_vector,
+              FloatArrayNear(test_data_kernel.get_output_data_by_index(0), 0.0001f));
+}
+
+} // namespace testing
+} // namespace execute
+} // namespace onert_micro
diff --git a/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp b/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp
index a83efa06fa1..5aad12f25a2 100644
--- a/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp
+++ b/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp
@@ -34,6 +34,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode)
   {
     case circle::BuiltinOperator_FULLY_CONNECTED:
     case circle::BuiltinOperator_CONV_2D:
+    case circle::BuiltinOperator_GRU:
       return true;
     default:
       return false;
@@ -255,8 +256,9 @@ OMStatus OMExecutionPlanCreator::createForwardExecutionPlan(
     // of the graph)
     bool need_to_save_input_data =
       (index >= last_train_op_indx) and
-      ((trainable_ops_config.find(index) != trainable_ops_config.end() and
-        trainable_ops_config[index] != ONLY_BIAS) or
+      ((trainable_ops_config.empty() or
+        trainable_ops_config.find(index) != trainable_ops_config.end() and
+          trainable_ops_config[index] != ONLY_BIAS) or
       isOpNeedSaveInputData(opcode));
 
     // Flag to determine is current operation needed to save output data (is this op in training
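Reviewer note: together with the `OMBackpropExecute` change below, an empty `trainable_ops_config` now means "train the last `num_of_train_layers` operators with full rank". A sketch of the tail-selection arithmetic that the GRU execute kernel above also applies (illustrative helper, not part of the PR):

```cpp
#include <algorithm>
#include <cstdint>

// Operators with index >= the returned value take part in training;
// num_train_layers == 0 is treated as "train everything".
uint32_t lastTrainOpIndex(uint32_t num_operators, uint32_t num_train_layers)
{
  const uint32_t k = num_train_layers == 0 ? num_operators : num_train_layers;
  return num_operators - std::min(num_operators, k);
}
```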
diff --git a/onert-micro/onert-micro/src/import/kernels/GRU.cpp b/onert-micro/onert-micro/src/import/kernels/GRU.cpp
new file mode 100644
index 00000000000..2c1167c98b5
--- /dev/null
+++ b/onert-micro/onert-micro/src/import/kernels/GRU.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OMStatus.h"
+
+#include "import/OMKernelConfigureBuilder.h"
+
+#include "core/OMUtils.h"
+#include "core/OMKernelData.h"
+
+#include "execute/OMRuntimeKernel.h"
+
+using namespace onert_micro;
+using namespace onert_micro::core;
+
+namespace
+{
+
+constexpr uint32_t inputTensorIdx = 0;
+constexpr uint32_t hiddenHiddenTensorIdx = 1;
+constexpr uint32_t hiddenHiddenBiasTensorIdx = 2;
+constexpr uint32_t hiddenInputTensorIdx = 3;
+constexpr uint32_t hiddenInputBiasTensorIdx = 4;
+constexpr uint32_t stateTensorIdx = 5;
+
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+OMStatus onert_micro::import::configure_kernel_CircleGRU(const OMConfigureArgs &config_args)
+{
+  core::OMRuntimeContext &runtime_context = config_args.runtime_context;
+  uint16_t op_index = config_args.kernel_index;
+
+  const circle::Tensor *input;
+  const circle::Tensor *hidden_hidden;
+  const circle::Tensor *hidden_hidden_bias;
+  const circle::Tensor *hidden_input;
+  const circle::Tensor *hidden_input_bias;
+  const circle::Tensor *state;
+
+  const circle::Tensor *output;
+
+  // Read kernel
+  execute::OMRuntimeKernel runtime_kernel;
+  runtime_kernel.readKernel(op_index, runtime_context);
+
+  input = runtime_kernel.inputs[inputTensorIdx];
+  hidden_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx];
+  hidden_hidden_bias = runtime_kernel.inputs[hiddenHiddenBiasTensorIdx];
+  hidden_input = runtime_kernel.inputs[hiddenInputTensorIdx];
+  hidden_input_bias = runtime_kernel.inputs[hiddenInputBiasTensorIdx];
+  state = runtime_kernel.inputs[stateTensorIdx];
+
+  output = runtime_kernel.outputs[outputTensorIdx];
+  assert(input != nullptr);
+  assert(hidden_hidden != nullptr);
+  assert(hidden_input != nullptr);
+  assert(state != nullptr);
+  // Biases can be nullptr
+  assert(output != nullptr);
+
+  OMStatus status = Ok;
+
+  OMRuntimeShape hidden_hidden_shape(hidden_hidden);
+  OMRuntimeShape hidden_input_shape(hidden_input);
+  OMRuntimeShape output_shape(output);
+  OMRuntimeShape state_shape(state);
+
+  status = utils::checkCondition(hidden_hidden_shape.dims(0) == hidden_input_shape.dims(0));
+  if (status != Ok)
+    return status;
+
+  const int32_t div_factor = 3;
+  status =
+    utils::checkCondition(hidden_hidden_shape.dims(0) ==
+                          (div_factor * output_shape.dims(output_shape.dimensionsCount() - 1)));
+  if (status != Ok)
+    return status;
+
+  status = utils::checkCondition(output_shape.dims(output_shape.dimensionsCount() - 1) ==
+                                 state_shape.dims(state_shape.dimensionsCount() - 1));
+  if (status != Ok)
+    return status;
+
+  status = utils::checkCondition(input->type() == output->type());
+  if (status != Ok)
+    return status;
+
+  return status;
+}
diff --git a/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp b/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp
index cd37f875a8c..a8f0dd3ac9f 100644
--- a/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp
+++ b/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp
@@ -93,7 +93,12 @@ OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecut
       args.is_last_layer = false;
     }

-    if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
+    if (trainable_ops_config.empty())
+    {
+      args.is_trainable_layer = true;
+      args.train_rank_type = core::OpTrainableRankType::ALL;
+    }
+    else if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
     {
       args.is_trainable_layer = true;
       args.train_rank_type = core::OpTrainableRankType(trainable_ops_config[cur_op_index]);
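The `div_factor = 3` check above encodes the fact that the hidden-to-hidden and input-to-hidden weight matrices stack rows for the three GRU gates. As a reference for what the runtime kernel computes, here is a self-contained single-step GRU in one common formulation; the gate ordering, reset-gate placement, and blend direction are assumptions, and the PAL implementation may differ:

#include <cmath>
#include <vector>

static float sigmoidf(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// One GRU step. w_hh: [3*hidden x hidden], w_ih: [3*hidden x input_size],
// biases: [3*hidden] (may be empty); rows packed as [update z | reset r | candidate h~].
std::vector<float> gruStep(const std::vector<float> &state, const std::vector<float> &input,
                           const std::vector<float> &w_hh, const std::vector<float> &w_ih,
                           const std::vector<float> &b_hh, const std::vector<float> &b_ih,
                           int hidden, int input_size)
{
  // The two FullyConnected products performed per time step.
  std::vector<float> hh(3 * hidden), ih(3 * hidden);
  for (int r = 0; r < 3 * hidden; ++r)
  {
    float acc = b_hh.empty() ? 0.0f : b_hh[r];
    for (int c = 0; c < hidden; ++c)
      acc += w_hh[r * hidden + c] * state[c];
    hh[r] = acc;
    acc = b_ih.empty() ? 0.0f : b_ih[r];
    for (int c = 0; c < input_size; ++c)
      acc += w_ih[r * input_size + c] * input[c];
    ih[r] = acc;
  }
  std::vector<float> next(hidden);
  for (int i = 0; i < hidden; ++i)
  {
    const float z = sigmoidf(hh[i] + ih[i]);                                // update gate
    const float r = sigmoidf(hh[hidden + i] + ih[hidden + i]);              // reset gate
    const float h = std::tanh(r * hh[2 * hidden + i] + ih[2 * hidden + i]); // candidate
    next[i] = z * state[i] + (1.0f - z) * h; // blend previous state and candidate
  }
  return next;
}

The gate activations and the partial products around them are, presumably, the kind of per-step values the forward kernel stashes in its nine intermediate tensors for the backward pass.
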
diff --git a/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp b/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp
index 7190422c3d8..0e4608d225d 100644
--- a/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp
+++ b/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp
@@ -152,6 +152,8 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE
     weight_shape = dynamic_shapes;

   // 2. Calculate weight gradient
+  std::memset(dloss_dweight_data, 0,
+              output_shape.dims(1) * input_shape.dims(1) * sizeof(float));
   pal::FullyConnectedWeightGrad(
     core::utils::castInputData<float>(dloss_doutput_data), output_shape,
     core::utils::castInputData<float>(input_data), input_shape,
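Why the new memset: the weight-gradient PAL kernel presumably accumulates into its destination buffer, so without zeroing first, gradients from a previous batch would bleed into the current one. Schematically (an illustration of the accumulation pattern, not the PAL code):

#include <cstring>

// Accumulating weight-gradient loop: dW[o][i] += dY[o] * X[i].
// If dW is reused across batches, it must be zeroed before each accumulation pass.
void weightGradAccumulate(float *dW, const float *dY, const float *X, int out, int in)
{
  std::memset(dW, 0, sizeof(float) * out * in); // mirrors the fix in the hunk above
  for (int o = 0; o < out; ++o)
    for (int i = 0; i < in; ++i)
      dW[o * in + i] += dY[o] * X[i];
}
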
diff --git a/onert-micro/onert-micro/src/train/kernels/GRU.cpp b/onert-micro/onert-micro/src/train/kernels/GRU.cpp
new file mode 100644
index 00000000000..709c4e5d8bf
--- /dev/null
+++ b/onert-micro/onert-micro/src/train/kernels/GRU.cpp
@@ -0,0 +1,155 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OMStatus.h"
+
+#include "core/OMUtils.h"
+#include "core/OMDataType.h"
+#include "core/memory/OMMemoryManager.h"
+
+#include "train/OMBackpropExecutionBuilder.h"
+#include "execute/OMRuntimeKernel.h"
+
+#include "PALGRUWeightGrad.h"
+
+using namespace onert_micro;
+using namespace onert_micro::train;
+
+namespace
+{
+
+constexpr uint32_t inputTensorIdx = 0;
+constexpr uint32_t hiddenHiddenTensorIdx = 1;
+constexpr uint32_t hiddenHiddenBiasTensorIdx = 2;
+constexpr uint32_t hiddenInputTensorIdx = 3;
+constexpr uint32_t hiddenInputBiasTensorIdx = 4;
+constexpr uint32_t stateTensorIdx = 5;
+
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+OMStatus onert_micro::train::train_kernel_CircleGRU(const OMBackpropExecuteArgs &args)
+{
+  // Check whether this is the last layer for training
+  core::OMRuntimeContext &runtime_context = args.backward_context;
+  core::OMRuntimeStorage &backward_storage = args.backward_storage;
+  core::OMRuntimeStorage &forward_storage = args.forward_storage;
+  uint16_t op_index = args.kernel_index;
+
+  execute::OMRuntimeKernel runtime_kernel;
+  runtime_kernel.readKernel(op_index, runtime_context);
+
+  const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
+  const circle::Tensor *weight_input = runtime_kernel.inputs[hiddenInputTensorIdx];
+  const circle::Tensor *weight_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx];
+  const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
+
+  assert(input != nullptr);
+  assert(output != nullptr);
+
+  OMStatus status = Ok;
+
+  // Read forward data
+  status = runtime_kernel.getDataFromStorage(op_index, forward_storage, runtime_context);
+  if (status != Ok)
+    return status;
+  uint8_t *input_data = runtime_kernel.inputs_data[inputTensorIdx];
+  uint8_t *weight_input_data = runtime_kernel.inputs_data[hiddenInputTensorIdx];
+  uint8_t *weight_hidden_data = runtime_kernel.inputs_data[hiddenHiddenTensorIdx];
+  uint8_t *bias_input_data = runtime_kernel.inputs_data[hiddenInputBiasTensorIdx];
+  uint8_t *bias_hidden_data = runtime_kernel.inputs_data[hiddenHiddenBiasTensorIdx];
+  uint8_t *intermediate_buffer = runtime_kernel.inputs_data[stateTensorIdx];
+  // Bias data can be nullptr
+  assert(input_data != nullptr);
+  assert(weight_input_data != nullptr);
+  assert(weight_hidden_data != nullptr);
+  assert(intermediate_buffer != nullptr);
+
+  // Read backward data
+  status = runtime_kernel.getDataFromStorage(op_index, backward_storage, runtime_context);
+  if (status != Ok)
+    return status;
+  uint8_t *output_grad_data = runtime_kernel.outputs_data[outputTensorIdx];
+  uint8_t *weight_input_grad_data = runtime_kernel.inputs_data[hiddenInputTensorIdx];
+  uint8_t *weight_hidden_grad_data = runtime_kernel.inputs_data[hiddenHiddenTensorIdx];
+  uint8_t *bias_input_grad_data = runtime_kernel.inputs_data[hiddenInputBiasTensorIdx];
+  uint8_t *bias_hidden_grad_data = runtime_kernel.inputs_data[hiddenHiddenBiasTensorIdx];
+  uint8_t *state_grad_data = runtime_kernel.inputs_data[stateTensorIdx];
+  uint8_t *input_grad_data = runtime_kernel.inputs_data[inputTensorIdx];
+  // Bias data and input_grad_data can be nullptr
+  // Note: input_grad_data can be nullptr because this can be the last trainable node
+  assert(output_grad_data != nullptr);
+  assert(weight_input_grad_data != nullptr);
+  assert(weight_hidden_grad_data != nullptr);
+  assert(state_grad_data != nullptr);
+
+  // Obtain shapes
+  core::OMRuntimeShape input_shape(input);
+  core::OMRuntimeShape output_shape(output);
+  core::OMRuntimeShape weight_input_shape(weight_input);
+  core::OMRuntimeShape weight_hidden_shape(weight_hidden);
+
+  // Init output shape for the FullyConnected layers
+  core::OMRuntimeShape output_shape_fc(2);
+  output_shape_fc.setDim(0, 1);
+  output_shape_fc.setDim(1, weight_hidden_shape.dims(0));
+
+  // Currently only float training is supported; check before allocating scratch buffers
+  if (input->type() != circle::TensorType_FLOAT32)
+    return UnsupportedType;
+
+  // Allocate memory for the temporary output gradients of the FullyConnected layers
+  uint8_t *left_fc_output_grad_buffer;
+  uint8_t *right_fc_output_grad_buffer;
+  // Checked during import
+  assert(weight_hidden_shape.dims(0) == weight_input_shape.dims(0));
+  size_t allocation_size =
+    getOMDataTypeSize(core::onertMicroDatatype(input->type())) * weight_hidden_shape.dims(0);
+  status =
+    core::memory::OMMemoryManager::allocateMemory(allocation_size, &left_fc_output_grad_buffer);
+  if (status != Ok)
+    return status;
+  status =
+    core::memory::OMMemoryManager::allocateMemory(allocation_size, &right_fc_output_grad_buffer);
+  if (status != Ok)
+    return status;
+
+  assert(left_fc_output_grad_buffer != nullptr and right_fc_output_grad_buffer != nullptr);
+
+  status = pal::GRUGrads(core::utils::castInputData<float>(output_grad_data),
+                         core::utils::castInputData<float>(weight_input_data),
+                         core::utils::castOutputData<float>(weight_input_grad_data),
+                         core::utils::castInputData<float>(weight_hidden_data),
+                         core::utils::castOutputData<float>(weight_hidden_grad_data),
+                         core::utils::castInputData<float>(bias_input_data),
+                         core::utils::castOutputData<float>(bias_input_grad_data),
+                         core::utils::castInputData<float>(bias_hidden_data),
+                         core::utils::castOutputData<float>(bias_hidden_grad_data),
+                         core::utils::castInputData<float>(input_data),
+                         core::utils::castOutputData<float>(input_grad_data),
+                         core::utils::castOutputData<float>(state_grad_data), input_shape,
+                         output_shape, weight_input_shape, weight_hidden_shape, output_shape_fc,
+                         core::utils::castOutputData<float>(intermediate_buffer),
+                         core::utils::castOutputData<float>(left_fc_output_grad_buffer),
+                         core::utils::castOutputData<float>(right_fc_output_grad_buffer));
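+  // Note: pal::GRUGrads consumes the per-time-step intermediate buffer that the forward
+  // (execute) GRU kernel saved under the state tensor index; the two scratch buffers hold
+  // the FullyConnected output gradients for the input-to-hidden and hidden-to-hidden paths.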
+
+  // Deallocate
+  core::memory::OMMemoryManager::deallocateMemory(intermediate_buffer);
+  core::memory::OMMemoryManager::deallocateMemory(left_fc_output_grad_buffer);
+  core::memory::OMMemoryManager::deallocateMemory(right_fc_output_grad_buffer);
+
+  forward_storage.removeTensorFromTensorIndexToData(runtime_kernel.inputs_index[stateTensorIdx]);
+
+  return status;
+}
diff --git a/onert-micro/onert-micro/src/train/kernels/StridedSlice.cpp b/onert-micro/onert-micro/src/train/kernels/StridedSlice.cpp
new file mode 100644
index 00000000000..b2a8ff4fa14
--- /dev/null
+++ b/onert-micro/onert-micro/src/train/kernels/StridedSlice.cpp
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OMStatus.h"
+
+#include "core/OMUtils.h"
+#include "core/OMDataType.h"
+
+#include "train/OMBackpropExecutionBuilder.h"
+#include "execute/OMRuntimeKernel.h"
+
+#include <cstring>
+
+using namespace onert_micro;
+using namespace onert_micro::train;
+
+namespace
+{
+
+constexpr uint32_t inputTensorIdx = 0;
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+/*
+ * - Calculate input gradient - Optional (not required if it is the last op)
+ * Note: currently supported only when the op acts as a reshape: there is a single
+ * output tensor and its flat size equals the input's flat size
+ */
+// TODO: support the general case
+OMStatus onert_micro::train::train_kernel_CircleStridedSlice(const OMBackpropExecuteArgs &args)
+{
+  // Check whether this is the last layer for training
+  if (args.is_last_layer)
+    return Ok;
+
+  core::OMRuntimeContext &runtime_context = args.backward_context;
+  core::OMRuntimeStorage &runtime_storage = args.backward_storage;
+  uint16_t op_index = args.kernel_index;
+
+  execute::OMRuntimeKernel runtime_kernel;
+  runtime_kernel.readKernel(op_index, runtime_context);
+
+  const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
+  const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
+
+  assert(input != nullptr);
+  assert(output != nullptr);
+
+  // Note: currently supported only when the op acts as a reshape: single output tensor,
+  // output flat size equal to input flat size
+  assert(runtime_kernel.outputs_num == 1);
+  const core::OMRuntimeShape shape(input);
+  const core::OMRuntimeShape output_shape(output);
+  assert(shape.flatSize() == output_shape.flatSize());
+  if (runtime_kernel.outputs_num > 1 or shape.flatSize() != output_shape.flatSize())
+    return UnsupportedType;
+
+  OMStatus status = Ok;
+
+  status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
+  if (status != Ok)
+    return status;
+
+  uint8_t *input_data = runtime_kernel.inputs_data[inputTensorIdx];
+  uint8_t *output_data = runtime_kernel.outputs_data[outputTensorIdx];
+
+  assert(input_data != nullptr);
+  assert(output_data != nullptr);
+
+  // Check whether the kernel ran inplace
+  if (input_data == output_data)
+    return Ok;
+
+  const size_t element_size =
+    static_cast<size_t>(getOMDataTypeSize(core::onertMicroDatatype(input->type())));
+  const int32_t num_elements = shape.flatSize();
+  std::memcpy(input_data, output_data, num_elements * element_size);
+
+  return status;
+}
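Since the StridedSlice handled here is restricted to the reshape-like case, its backward pass is an identity on elements: the input gradient is the output gradient reinterpreted in the input's shape, which is why a single memcpy suffices. A standalone illustration under that assumption (function name is illustrative, not onert-micro API):

#include <cstring>
#include <vector>

// Backprop through a reshape-like op: element order is unchanged, so
// dL/dinput is dL/doutput copied verbatim (shapes differ, flat sizes match).
void reshapeLikeBackward(const std::vector<float> &d_output, std::vector<float> &d_input)
{
  // Precondition mirrored from the kernel: equal flat sizes.
  d_input.resize(d_output.size());
  std::memcpy(d_input.data(), d_output.data(), d_output.size() * sizeof(float));
}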