diff --git a/onert-micro/eval-driver/Driver.cpp b/onert-micro/eval-driver/Driver.cpp index b4570e5776e..9049e9eeb56 100644 --- a/onert-micro/eval-driver/Driver.cpp +++ b/onert-micro/eval-driver/Driver.cpp @@ -114,7 +114,7 @@ int entry(int argc, char **argv) } // Do inference. - interpreter.run(); + interpreter.run(config); } // Get output. diff --git a/onert-micro/onert-micro/include/OMInterpreter.h b/onert-micro/onert-micro/include/OMInterpreter.h index d450c3403e4..cb63cc1bf20 100644 --- a/onert-micro/onert-micro/include/OMInterpreter.h +++ b/onert-micro/onert-micro/include/OMInterpreter.h @@ -40,7 +40,7 @@ class OMInterpreter OMStatus importModel(const char *model_ptr, const OMConfig &config); - OMStatus run(); + OMStatus run(const OMConfig &config); OMStatus reset(); diff --git a/onert-micro/onert-micro/include/OMTrainingInterpreter.h b/onert-micro/onert-micro/include/OMTrainingInterpreter.h index b3dbcd8987b..815908ef70b 100644 --- a/onert-micro/onert-micro/include/OMTrainingInterpreter.h +++ b/onert-micro/onert-micro/include/OMTrainingInterpreter.h @@ -68,9 +68,10 @@ class OMTrainingInterpreter // Note: calculation will be done on test_size number of test samples // Warning: before using evaluateMetric call: 1) importTrainModel; 2) setInput; 3) setTarget // Note: number of the samples in data should be equal to the test_size - OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size) + OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val, + uint32_t test_size) { - return _training_runtime_module.evaluateMetric(metric, metric_val, test_size); + return _training_runtime_module.evaluateMetric(config, metric, metric_val, test_size); } // To get input and output flat size @@ -86,7 +87,7 @@ class OMTrainingInterpreter // Load current status from checkpoint and save it in current model and in current config OMStatus loadCheckpoint(OMConfig &config, const char *load_path); - OMStatus run() { return _training_runtime_module.run(); } + OMStatus run(const OMConfig &config) { return _training_runtime_module.run(config); } OMStatus allocateInputs() { return _training_runtime_module.allocateInputs(); } void *getInputData(uint32_t position); diff --git a/onert-micro/onert-micro/include/core/OMRuntimeModule.h b/onert-micro/onert-micro/include/core/OMRuntimeModule.h index 53ba519a1ef..c5e368d7238 100644 --- a/onert-micro/onert-micro/include/core/OMRuntimeModule.h +++ b/onert-micro/onert-micro/include/core/OMRuntimeModule.h @@ -43,7 +43,7 @@ class OMRuntimeModule ~OMRuntimeModule() = default; OMStatus importModel(const char *model_ptr, const OMConfig &config); - OMStatus run(); + OMStatus run(const OMConfig &config); OMStatus reset(); uint32_t getNumberOfInputs(); diff --git a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h index 9b53c3bd888..c3b00ce5450 100644 --- a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h +++ b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h @@ -68,7 +68,8 @@ class OMTrainingRuntimeModule : public OMRuntimeModule // 2) metric_val should be initialized with some value before calling this method due to // after calculation for current batch_num (the sequence number of the current sample) // this value is added to metric_val - OMStatus evaluateMetric(OMMetrics metric, void *metric_val, uint32_t test_size); + OMStatus evaluateMetric(const OMConfig &config, OMMetrics metric, void *metric_val, + uint32_t test_size); // Set input data for input with input_index // Note: number of the samples in data should be equal to the batch_size in config structure diff --git a/onert-micro/onert-micro/include/execute/OMExecuteArgs.h b/onert-micro/onert-micro/include/execute/OMExecuteArgs.h index 196f68a87a3..feb1318e5b7 100644 --- a/onert-micro/onert-micro/include/execute/OMExecuteArgs.h +++ b/onert-micro/onert-micro/include/execute/OMExecuteArgs.h @@ -33,6 +33,8 @@ struct OMExecuteArgs core::OMRuntimeContext &runtime_context; uint16_t kernel_index; core::OMRuntimeModule &runtime_module; + uint32_t num_train_layers = 0; + bool is_train_mode = false; }; } // namespace execute diff --git a/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h b/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h index b6f63cdaa48..e33239f7256 100644 --- a/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h +++ b/onert-micro/onert-micro/include/execute/OMRuntimeKernel.h @@ -23,7 +23,7 @@ #include -constexpr static uint32_t maxInputSize = 5; +constexpr static uint32_t maxInputSize = 6; constexpr static uint32_t maxOutputSize = 5; namespace onert_micro diff --git a/onert-micro/onert-micro/include/execute/OMTestUtils.h b/onert-micro/onert-micro/include/execute/OMTestUtils.h index eaef8c0d457..e5d685db72a 100644 --- a/onert-micro/onert-micro/include/execute/OMTestUtils.h +++ b/onert-micro/onert-micro/include/execute/OMTestUtils.h @@ -61,7 +61,7 @@ std::vector checkKernel(uint32_t num_inputs, } } - interpreter.run(); + interpreter.run(config); U *output_data = reinterpret_cast(interpreter.getOutputDataAt(0)); const size_t num_elements = interpreter.getOutputSizeAt(0); diff --git a/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h b/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h new file mode 100644 index 00000000000..945122c32d7 --- /dev/null +++ b/onert-micro/onert-micro/include/pal/common/PALGRUCommon.h @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H +#define ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H + +#include "OMStatus.h" +#include "core/OMRuntimeShape.h" + +#include "PALUtils.h" +#include "ProcessBroadcastShapes.h" +#include "PALFullyConnected.h" +#include "PALLogistic.h" + +namespace onert_micro +{ +namespace execute +{ +namespace pal +{ +namespace +{ +void calculateGRU(const float *input_data, const float *weight_input_data, + const float *weight_hidden_data, const float *bias_input_data, + const float *bias_hidden_data, float *output_data, + const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, + const core::OMRuntimeShape &weight_input_shape, + const core::OMRuntimeShape &weight_hidden_shape, float *output_input_data, + float *output_hidden_data, const core::OMRuntimeShape &output_shape_fc, + float *intermediate_buffer) +{ + core::FullyConnectedParams op_params{}; + // As FC nodes doesn't have any activations inside GRU, let' use just numeric limits + op_params.float_activation_min = std::numeric_limits::lowest(); + op_params.float_activation_max = std::numeric_limits::max(); + // If intermediate_buffer != nullptr - then it is train mode and we need save intermediate inform + bool is_train_mode = intermediate_buffer != nullptr; + if (is_train_mode) + { + // Copy input for FC Input to calculate weights gradients + std::memcpy(intermediate_buffer, output_data, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + } + // FC Input + FullyConnected(op_params, output_data, weight_input_shape, weight_input_data, bias_input_data, + output_shape_fc, output_input_data); + + // FC Hidden + // Note: input for this FC node will be saved without intermediate buffer + FullyConnected(op_params, input_data, weight_hidden_shape, weight_hidden_data, bias_hidden_data, + output_shape_fc, output_hidden_data); + + int num_elements = output_shape_fc.dims(1) / 3; + + float *second_hidden_part = output_hidden_data + num_elements; + float *second_input_part = output_input_data + num_elements; + + float *third_hidden_part = second_hidden_part + num_elements; + float *third_input_part = second_input_part + num_elements; + + // Calculate Left part + for (int i = 0; i < num_elements; ++i) + { + output_input_data[i] += output_hidden_data[i]; + } + + // If train mode - save logistic input + if (is_train_mode) + { + std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + } + Logistic(num_elements, output_input_data, output_input_data); + + // If train mode - save most left mul input (right input) + if (is_train_mode) + { + std::memcpy(intermediate_buffer, output_input_data, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + } + // Calculate most left mul + float *most_left_part_final = output_input_data; + float *first_part = output_input_data; + for (int i = 0; i < num_elements; ++i) + { + output_data[i] *= most_left_part_final[i]; + first_part[i] = 1.0f - first_part[i]; + } + + // Calc second part + for (int i = 0; i < num_elements; ++i) + { + second_hidden_part[i] += second_input_part[i]; + } + // If train mode - save logistic input + if (is_train_mode) + { + std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + } + Logistic(num_elements, second_hidden_part, second_hidden_part); + + // If train mode - save mul input (left and right) + if (is_train_mode) + { + // Left input + std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + + // Right input + std::memcpy(intermediate_buffer, third_input_part, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + } + for (int i = 0; i < num_elements; ++i) + { + second_hidden_part[i] *= third_input_part[i]; + second_hidden_part[i] += third_hidden_part[i]; + } + // If train mode - save tanh input + if (is_train_mode) + { + std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + } + for (int i = 0; i < num_elements; ++i) + { + second_hidden_part[i] = std::tanh(second_hidden_part[i]); + } + + // If train mode - save mul input (left and right) + if (is_train_mode) + { + // Left input + std::memcpy(intermediate_buffer, first_part, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + + // Right input + std::memcpy(intermediate_buffer, second_hidden_part, output_shape.flatSize() * sizeof(float)); + // Move intermediate_buffer pointer + intermediate_buffer += output_shape.flatSize(); + } + for (int i = 0; i < num_elements; ++i) + { + second_hidden_part[i] *= first_part[i]; + output_data[i] += second_hidden_part[i]; + } +} + +} // namespace + +OMStatus GRU(const float *input_data, const float *weight_input_data, + const float *weight_hidden_data, const float *bias_input_data, + const float *bias_hidden_data, const float *hidden_state_data, float *output_data, + float *output_input_data, float *output_hidden_data, + const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, + const core::OMRuntimeShape &weight_input_shape, + const core::OMRuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size, + float *intermediate_buffer) +{ + const int32_t time = input_shape.dims(0); + + core::OMRuntimeShape output_shape_fc(2); + output_shape_fc.setDim(0, 1); + output_shape_fc.setDim(1, weight_hidden_shape.dims(0)); + + std::memcpy(output_data, hidden_state_data, output_shape.flatSize() * sizeof(float)); + + for (int i = 0; i < time; ++i) + { + calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data, + bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape, + weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc, + intermediate_buffer); + input_data += input_shape.dims(2); + if (intermediate_buffer_size != 0) + { + assert(intermediate_buffer != nullptr); + intermediate_buffer += intermediate_buffer_size; + } + } + return Ok; +} + +} // namespace pal +} // namespace execute +} // namespace onert_micro + +#endif // ONERT_MICRO_EXECUTE_PAL_GRU_COMMON_H diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst index 2560554777a..836c2a246c8 100644 --- a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst +++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst @@ -23,6 +23,7 @@ REGISTER_KERNEL(GATHER_ND, GatherND) REGISTER_KERNEL(EXP, Exp) REGISTER_KERNEL(GREATER, Greater) REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual) +REGISTER_KERNEL(GRU, GRU) REGISTER_KERNEL(EXPAND_DIMS, ExpandDims) REGISTER_KERNEL(ELU, Elu) REGISTER_KERNEL(EQUAL, Equal) diff --git a/onert-micro/onert-micro/include/pal/mcu/PALGRU.h b/onert-micro/onert-micro/include/pal/mcu/PALGRU.h new file mode 100644 index 00000000000..75389fe5f7e --- /dev/null +++ b/onert-micro/onert-micro/include/pal/mcu/PALGRU.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_H +#define ONERT_MICRO_EXECUTE_PAL_GRU_H + +#include "PALGRUCommon.h" + +#endif // ONERT_MICRO_EXECUTE_PAL_GRU_H diff --git a/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h b/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h new file mode 100644 index 00000000000..fa49b29c1d0 --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/gru/FloatGRUKernel.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H +#define ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H + +#include "TestDataGRUBase.h" + +namespace onert_micro +{ +namespace test_model +{ + +namespace gru_float +{ +/* + * GRU Kernel: + * + * Input(1, 1, 6) + * | + * GRU + * | + * Output(1, 1, 5) + */ +unsigned char test_kernel_model_circle[] = { + 0x1c, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x12, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0xb4, 0x06, 0x00, 0x00, 0x98, 0x06, 0x00, 0x00, + 0xe0, 0x04, 0x00, 0x00, 0x70, 0x03, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0xd8, 0x02, 0x00, 0x00, + 0x90, 0x02, 0x00, 0x00, 0x38, 0x02, 0x00, 0x00, 0xe8, 0x01, 0x00, 0x00, 0xa8, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x6e, 0x6e, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, + 0x65, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0xf4, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0xfb, 0xfb, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x0c, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2d, + 0x2d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0xac, 0x05, 0x00, 0x00, 0xf4, 0x03, 0x00, 0x00, 0x88, 0x02, 0x00, 0x00, 0x2c, 0x02, 0x00, 0x00, + 0xf0, 0x01, 0x00, 0x00, 0xa4, 0x01, 0x00, 0x00, 0x48, 0x01, 0x00, 0x00, 0xf8, 0x00, 0x00, 0x00, + 0xb4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x1a, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, + 0x07, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x01, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc4, 0xfa, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x03, 0x00, 0x00, 0x00, 0x34, 0xfb, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x53, 0x74, 0x61, 0x74, 0x65, 0x66, 0x75, 0x6c, + 0x50, 0x61, 0x72, 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x65, 0x64, 0x43, 0x61, 0x6c, 0x6c, 0x3a, + 0x30, 0x00, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x68, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x73, 0x74, 0x72, 0x69, + 0x64, 0x65, 0x64, 0x5f, 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x32, 0x00, 0x00, 0x00, 0x00, + 0x26, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0xb4, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x73, 0x74, 0x72, 0x69, 0x64, 0x65, 0x64, 0x5f, + 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0x72, 0xfd, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, + 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x38, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x73, 0x74, 0x72, 0x69, 0x64, 0x65, 0x64, 0x5f, 0x73, 0x6c, 0x69, 0x63, 0x65, 0x5f, 0x32, 0x00, + 0xc6, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x60, 0xfc, 0xff, 0xff, 0x24, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x46, 0x75, 0x73, 0x65, 0x64, 0x43, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x47, + 0x52, 0x55, 0x00, 0x00, 0x3c, 0xfc, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x98, 0xfc, 0xff, 0xff, 0x48, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, + 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x2f, 0x67, 0x72, 0x75, 0x2f, 0x7a, 0x65, 0x72, 0x6f, 0x73, + 0x00, 0x00, 0x00, 0x00, 0x4a, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0xf0, 0xfc, 0xff, 0xff, 0x58, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x77, 0x68, 0x69, 0x6c, 0x65, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, + 0x5f, 0x31, 0x31, 0x00, 0x9a, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x2c, 0x01, 0x00, 0x00, + 0xc0, 0xfb, 0x12, 0xbd, 0x4b, 0xb1, 0x0c, 0x3f, 0x51, 0xbe, 0xa0, 0x3d, 0xdb, 0xcd, 0xca, 0xbe, + 0x77, 0xa7, 0x8d, 0x3e, 0xd8, 0x24, 0xe8, 0x3e, 0xc6, 0xe3, 0xfe, 0x3d, 0xa8, 0x41, 0xf0, 0xbd, + 0x9e, 0x70, 0xf3, 0xbd, 0x50, 0xfc, 0x4b, 0x3e, 0x7f, 0x8b, 0xf0, 0x3d, 0xae, 0xc0, 0x83, 0x3d, + 0xe4, 0xf0, 0x98, 0xbe, 0xd4, 0xd0, 0x7f, 0xbe, 0x80, 0xca, 0x98, 0x39, 0xe6, 0x2c, 0x08, 0xbe, + 0x61, 0x44, 0xdf, 0xbd, 0x67, 0x32, 0x32, 0xbe, 0x6a, 0x61, 0xdf, 0x3e, 0xc3, 0x0c, 0x55, 0x3e, + 0x6c, 0x28, 0x0e, 0xbf, 0xb6, 0x52, 0xf1, 0x3d, 0xb7, 0xd1, 0x3f, 0xbd, 0xa6, 0xf0, 0x9d, 0xbe, + 0xa0, 0xdd, 0xb1, 0x3e, 0xa3, 0x7d, 0x50, 0xbd, 0x3e, 0xd7, 0xe6, 0x3e, 0xe4, 0xb0, 0xe6, 0x3d, + 0x2a, 0xd6, 0xeb, 0x3e, 0xa8, 0xc8, 0x49, 0xbb, 0xdd, 0xdc, 0x6b, 0xbe, 0x66, 0x48, 0xc1, 0x3d, + 0x26, 0x6e, 0x52, 0x3e, 0xfc, 0xd6, 0x64, 0x3d, 0x4f, 0x1d, 0x1f, 0xbf, 0x5f, 0xf0, 0x9e, 0x3e, + 0xe0, 0x6e, 0xad, 0x3c, 0x48, 0x37, 0xe7, 0xbd, 0x36, 0xea, 0x0b, 0xbe, 0x3b, 0x81, 0xf2, 0xbd, + 0x52, 0xe1, 0x56, 0xbc, 0x75, 0x2e, 0xa3, 0xbd, 0x8c, 0x71, 0xc5, 0x3d, 0xf0, 0xaf, 0x0b, 0x3e, + 0x6b, 0x7d, 0xba, 0x3e, 0x4e, 0xbd, 0x93, 0xbe, 0xb3, 0x5c, 0x9c, 0xbe, 0x3c, 0xe2, 0xf3, 0x3c, + 0x39, 0xf1, 0xa0, 0x3d, 0xa0, 0x35, 0x50, 0x3e, 0xfa, 0x87, 0x0e, 0xbe, 0x76, 0xc2, 0x12, 0xbd, + 0x2a, 0xd6, 0x01, 0x3f, 0xa0, 0x77, 0xd0, 0x3c, 0x5a, 0x1f, 0x26, 0x3e, 0x02, 0x59, 0x0b, 0x3e, + 0xef, 0x6c, 0x41, 0xbe, 0x6e, 0x40, 0x4a, 0xbd, 0x2f, 0x89, 0x33, 0x3e, 0x50, 0x54, 0x8a, 0x3e, + 0x4d, 0xbb, 0x9f, 0xbe, 0xfd, 0x54, 0xb3, 0x3e, 0xc8, 0x5b, 0x66, 0xbe, 0xf0, 0xb0, 0x44, 0x3d, + 0x8a, 0x4d, 0x14, 0xbe, 0x9d, 0xf7, 0xd4, 0xbd, 0x38, 0xec, 0xc7, 0xbe, 0xb2, 0x79, 0x76, 0x3e, + 0xb2, 0xc2, 0xdd, 0xbe, 0x44, 0xd9, 0x05, 0xbe, 0x59, 0x34, 0x89, 0x3e, 0x71, 0xf8, 0x2b, 0x3e, + 0x1d, 0x62, 0x24, 0x3f, 0x40, 0xe6, 0x02, 0x3d, 0xef, 0x03, 0xb8, 0x3d, 0x02, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x58, 0xfe, 0xff, 0xff, 0x98, 0x01, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x77, 0x68, 0x69, 0x6c, + 0x65, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x68, 0x01, 0x00, 0x00, 0x9c, 0xb9, 0x96, 0xbe, + 0x30, 0x84, 0xdc, 0x3e, 0xb8, 0xe2, 0xc0, 0x3d, 0x00, 0xce, 0xf1, 0x3a, 0x28, 0xd6, 0xfb, 0x3d, + 0x49, 0x84, 0x95, 0xbe, 0xcc, 0x0d, 0x52, 0x3e, 0x7c, 0x4e, 0x6e, 0xbe, 0xde, 0xda, 0x4c, 0xbe, + 0x84, 0x5e, 0xda, 0x3e, 0x46, 0x2b, 0xd1, 0x3e, 0x78, 0xc8, 0x71, 0xbe, 0x00, 0xfd, 0x53, 0x3d, + 0x28, 0x4e, 0x91, 0x3e, 0x00, 0x46, 0xd6, 0xba, 0x20, 0x06, 0x97, 0xbe, 0xf4, 0x04, 0xdc, 0xbe, + 0xde, 0xf8, 0x05, 0xbf, 0x62, 0x20, 0x1d, 0xbe, 0x28, 0x28, 0xf9, 0x3d, 0xc6, 0xa0, 0x86, 0xbe, + 0xa2, 0x2f, 0x7f, 0xbe, 0xa0, 0xa1, 0x1d, 0xbd, 0x3c, 0x03, 0xb2, 0x3e, 0xe6, 0xe6, 0x7c, 0xbe, + 0x2e, 0x37, 0xbe, 0xbe, 0x84, 0xb2, 0x86, 0xbd, 0x10, 0x19, 0x56, 0x3e, 0x59, 0x86, 0x01, 0x3f, + 0xfc, 0x54, 0x15, 0x3e, 0xc3, 0xbd, 0x07, 0x3f, 0xa0, 0xcb, 0x5f, 0x3e, 0x6c, 0x19, 0xbb, 0x3e, + 0x9c, 0x98, 0x24, 0xbe, 0x40, 0x57, 0xd1, 0xbc, 0xb0, 0x9c, 0xec, 0xbd, 0x90, 0x19, 0xb4, 0x3d, + 0x59, 0x11, 0xe7, 0xbe, 0x04, 0x11, 0xd7, 0xbd, 0x6a, 0xd8, 0x46, 0xbe, 0xb9, 0xf2, 0x01, 0xbf, + 0x40, 0xe0, 0x2e, 0xbd, 0x9e, 0xe6, 0x9a, 0x3e, 0xa0, 0x27, 0xda, 0xbe, 0x39, 0xe9, 0x04, 0x3f, + 0x5c, 0x2f, 0x2d, 0x3e, 0x18, 0x35, 0x95, 0x3e, 0x5c, 0x67, 0x14, 0x3e, 0xd0, 0xb1, 0x92, 0xbd, + 0xa8, 0x99, 0xe2, 0xbd, 0x00, 0x1e, 0x0e, 0x3e, 0x80, 0x85, 0x7a, 0x3c, 0x88, 0xde, 0xde, 0x3e, + 0x0a, 0x10, 0xc9, 0x3e, 0x28, 0x29, 0x3c, 0xbd, 0xbe, 0x3a, 0xfd, 0x3e, 0x36, 0x76, 0xef, 0xbe, + 0x6e, 0x44, 0xb4, 0x3e, 0xdc, 0xd6, 0x9c, 0xbd, 0xf0, 0xed, 0x9a, 0x3e, 0x90, 0x9c, 0x6b, 0x3d, + 0x0c, 0xc3, 0x32, 0x3e, 0x8a, 0x27, 0x1f, 0xbe, 0x00, 0x64, 0x5f, 0x3a, 0x8e, 0x71, 0xcc, 0x3e, + 0xcf, 0xe7, 0xe1, 0xbe, 0xc6, 0x65, 0xb4, 0x3e, 0xa4, 0x65, 0x6d, 0x3e, 0x31, 0xd8, 0x03, 0x3f, + 0x2c, 0x2a, 0xa8, 0xbd, 0x38, 0x1b, 0xac, 0x3e, 0x60, 0xcc, 0x64, 0x3e, 0x18, 0x4c, 0x0e, 0xbd, + 0x82, 0x5e, 0xa2, 0x3e, 0xde, 0x70, 0xb0, 0xbe, 0x46, 0x07, 0xe6, 0xbe, 0xf6, 0x4a, 0xa8, 0xbe, + 0x90, 0xfa, 0x3f, 0x3e, 0x5c, 0x9a, 0xe9, 0xbe, 0x63, 0x1e, 0xd3, 0xbe, 0x20, 0x74, 0x7e, 0x3d, + 0x20, 0x9c, 0x02, 0xbf, 0xf7, 0x65, 0x02, 0x3f, 0xb6, 0x45, 0xdf, 0x3e, 0x4e, 0xc2, 0x48, 0xbe, + 0xe3, 0x90, 0xa9, 0xbe, 0xc8, 0x36, 0xab, 0x3d, 0xca, 0xc0, 0x22, 0xbe, 0xec, 0x99, 0x26, 0x3e, + 0xd0, 0x91, 0x35, 0x3e, 0x02, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, + 0x78, 0x3a, 0x30, 0x00, 0xec, 0xff, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00}; + +const std::vector input_data = {7.899295, -4.584313, -2.9251342, + -2.1820352, -10.649105, 1.3530581}; + +const std::vector reference_output_data = {-0.9979859, -0.90550894, 0.025957875, -0.39570245, + -0.8868108}; + +} // namespace gru_float + +class TestDataFloatGRU : public TestDataGRUBase +{ +public: + TestDataFloatGRU() + { + _input_data = gru_float::input_data; + _reference_output_data = gru_float::reference_output_data; + _test_kernel_model_circle = gru_float::test_kernel_model_circle; + } + + ~TestDataFloatGRU() override = default; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_FLOAT_GRU_KERNEL_H diff --git a/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h b/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h new file mode 100644 index 00000000000..5c3da425d4a --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/gru/TestDataGRUBase.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H +#define ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H + +#include "test_models/TestDataBase.h" + +namespace onert_micro +{ +namespace test_model +{ + +template class TestDataGRUBase : public TestDataBase +{ +public: + TestDataGRUBase() = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + + const std::vector &get_input_data_by_index(int i) override final + { + switch (i) + { + case 0: + return _input_data; + default: + assert(false && "Wrong input index"); + } + } + + const std::vector &get_output_data_by_index(int i) override final + { + assert(i == 0); + return _reference_output_data; + } + +protected: + std::vector _input_data; + std::vector _reference_output_data; + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_GRU_KERNEL_BASE_H diff --git a/onert-micro/onert-micro/src/OMInterpreter.cpp b/onert-micro/onert-micro/src/OMInterpreter.cpp index 1e7ee420d4a..bb7076f6416 100644 --- a/onert-micro/onert-micro/src/OMInterpreter.cpp +++ b/onert-micro/onert-micro/src/OMInterpreter.cpp @@ -27,7 +27,7 @@ OMStatus OMInterpreter::importModel(const char *model_ptr, const OMConfig &confi return _runtime_module.importModel(model_ptr, config); } -OMStatus OMInterpreter::run() { return _runtime_module.run(); } +OMStatus OMInterpreter::run(const OMConfig &config) { return _runtime_module.run(config); } OMStatus OMInterpreter::reset() { return _runtime_module.reset(); } diff --git a/onert-micro/onert-micro/src/api/onert-micro.cpp b/onert-micro/onert-micro/src/api/onert-micro.cpp index 538529b32b9..37f2d63e611 100644 --- a/onert-micro/onert-micro/src/api/onert-micro.cpp +++ b/onert-micro/onert-micro/src/api/onert-micro.cpp @@ -307,7 +307,7 @@ NNFW_STATUS nnfw_session::train_run(bool update_weights) float *user_input_data = (float *)_train_interpreter->getInputData(0); memcpy(allocated_input_data, user_input_data, sizeof(float) * _train_interpreter->getInputSizeAt(0)); - _train_interpreter->run(); + _train_interpreter->run(_config); float *calculated_ptr = (float *)_train_interpreter->getOutputDataAt(0); memcpy(outputbuf, calculated_ptr, sizeof(float) * _train_interpreter->getOutputSizeAt(0)); _train_interpreter->reset(); @@ -380,7 +380,7 @@ NNFW_STATUS nnfw_session::train_get_loss(uint32_t index, float *loss) break; } - _train_interpreter->evaluateMetric(m, reinterpret_cast(loss), + _train_interpreter->evaluateMetric(_config, m, reinterpret_cast(loss), _config.training_context.batch_size); return NNFW_STATUS_NO_ERROR; } diff --git a/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp b/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp index 857d7413240..f1b1a5b0a1e 100644 --- a/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp +++ b/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp @@ -148,7 +148,7 @@ OMStatus OMRuntimeModule::allocateInputs() return _graphs.at(0).allocateGraphInputs(); } -OMStatus OMRuntimeModule::run() +OMStatus OMRuntimeModule::run(const OMConfig &config) { OMStatus status = Ok; @@ -158,7 +158,11 @@ OMStatus OMRuntimeModule::run() core::OMRuntimeGraph &main_graph = _graphs.at(0); execute::OMExecuteArgs execute_args = {main_graph.getRuntimeStorage(), - main_graph.getRuntimeContext(), 0, *this}; + main_graph.getRuntimeContext(), + 0, + *this, + config.training_context.num_of_train_layers, + config.train_mode}; status = execute::OMKernelExecute::runForward(execute_args, main_graph.getRuntimeAllocator()); if (status != Ok) diff --git a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp index 87e3c27aef1..c3492956539 100644 --- a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp +++ b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp @@ -126,7 +126,7 @@ OMStatus OMTrainingRuntimeModule::trainSingleStep(OMConfig &config) // b. Run forward graph { - status = run(); + status = run(config); assert(status == Ok); if (status != Ok) return status; @@ -203,8 +203,8 @@ OMStatus OMTrainingRuntimeModule::trainSingleStep(OMConfig &config) * after calculation for current batch_num (the sequence number of the current sample) * this value is added to metric_val */ -OMStatus OMTrainingRuntimeModule::evaluateMetric(OMMetrics metric, void *metric_val, - uint32_t test_size) +OMStatus OMTrainingRuntimeModule::evaluateMetric(const OMConfig &config, OMMetrics metric, + void *metric_val, uint32_t test_size) { OMStatus status = Ok; OMRuntimeGraph &forward_main_graph = _graphs.at(0); @@ -234,7 +234,7 @@ OMStatus OMTrainingRuntimeModule::evaluateMetric(OMMetrics metric, void *metric_ // b. Run forward graph { - status = run(); + status = run(config); assert(status == Ok); if (status != Ok) return status; diff --git a/onert-micro/onert-micro/src/execute/kernels/GRU.cpp b/onert-micro/onert-micro/src/execute/kernels/GRU.cpp new file mode 100644 index 00000000000..0a86e6c188e --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/GRU.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "OMStatus.h" + +#include "core/OMUtils.h" +#include "core/OMKernelData.h" +#include "core/memory/OMMemoryManager.h" + +#include "execute/OMKernelExecutionBuilder.h" +#include "execute/OMUtils.h" +#include "execute/OMRuntimeKernel.h" + +#include "PALGRU.h" + +using namespace onert_micro; +using namespace onert_micro::core; +using namespace onert_micro::execute; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t hiddenHiddenTensorIdx = 1; +constexpr uint32_t hiddenHiddenBiasTensorIdx = 2; +constexpr uint32_t hiddenInputTensorIdx = 3; +constexpr uint32_t hiddenInputBiasTensorIdx = 4; +constexpr uint32_t stateTensorIdx = 5; + +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +// NOTE: doesnt currently support dynamic shapes +OMStatus onert_micro::execute::execute_kernel_CircleGRU(const OMExecuteArgs &execute_args) +{ + core::OMRuntimeContext &runtime_context = execute_args.runtime_context; + core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; + uint16_t op_index = execute_args.kernel_index; + + const circle::Tensor *input; + const circle::Tensor *hidden_hidden; + const circle::Tensor *hidden_hidden_bias; + const circle::Tensor *hidden_input; + const circle::Tensor *hidden_input_bias; + const circle::Tensor *state; + + const circle::Tensor *output; + + uint8_t *input_data; + uint8_t *hidden_hidden_data; + uint8_t *hidden_hidden_bias_data; + uint8_t *hidden_input_data; + uint8_t *hidden_input_bias_data; + uint8_t *state_data; + uint8_t *output_data; + + uint16_t state_tensor_index = 0; + + // Read kernel + { + execute::OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + input = runtime_kernel.inputs[inputTensorIdx]; + hidden_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx]; + hidden_hidden_bias = runtime_kernel.inputs[hiddenHiddenBiasTensorIdx]; + hidden_input = runtime_kernel.inputs[hiddenInputTensorIdx]; + hidden_input_bias = runtime_kernel.inputs[hiddenInputBiasTensorIdx]; + state = runtime_kernel.inputs[stateTensorIdx]; + + output = runtime_kernel.outputs[outputTensorIdx]; + assert(input != nullptr); + assert(hidden_hidden != nullptr); + assert(hidden_input != nullptr); + assert(state != nullptr); + // Biases can be nullptr + assert(output != nullptr); + + runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); + + input_data = runtime_kernel.inputs_data[inputTensorIdx]; + hidden_hidden_data = runtime_kernel.inputs_data[hiddenHiddenTensorIdx]; + hidden_hidden_bias_data = runtime_kernel.inputs_data[hiddenHiddenBiasTensorIdx]; + hidden_input_data = runtime_kernel.inputs_data[hiddenInputTensorIdx]; + hidden_input_bias_data = runtime_kernel.inputs_data[hiddenInputBiasTensorIdx]; + state_data = runtime_kernel.inputs_data[stateTensorIdx]; + + output_data = runtime_kernel.outputs_data[outputTensorIdx]; + assert(input_data != nullptr); + assert(hidden_hidden_data != nullptr); + assert(hidden_input_data != nullptr); + assert(state_data != nullptr); + // Bias can be nullptr + assert(output_data != nullptr); + + state_tensor_index = runtime_kernel.inputs_index[stateTensorIdx]; + } + + OMStatus status; + + uint8_t *output_hidden_data; + uint8_t *output_input_data; + + status = + core::memory::OMMemoryManager::allocateMemory(core::OMRuntimeShape(hidden_hidden).flatSize() * + sizeof(core::OMDataType(hidden_hidden->type())), + &output_hidden_data); + if (status != Ok) + return status; + status = core::memory::OMMemoryManager::allocateMemory( + core::OMRuntimeShape(hidden_input).flatSize() * sizeof(core::OMDataType(hidden_input->type())), + &output_input_data); + if (status != Ok) + return status; + + // If train mode need to allocate memory for internal intermediate tensors for calculation + // gradients further Number of intermediate tensors + const int32_t num_of_intermediate_tensors = 9; + // Note: size of the intermediate is equal to output size (should be checked during import phase) + const int32_t size_of_intermediate_tensors = core::OMRuntimeShape(output).flatSize(); + assert(size_of_intermediate_tensors > 0); + if (size_of_intermediate_tensors == 0) + return UnknownError; + + const int32_t input_size = core::OMRuntimeShape(input).flatSize(); + const int32_t output_size = size_of_intermediate_tensors; + + // Allocate buffer with following schema: + // times * [output_size * sizeof(data_type), + // num_of_intermediate_tensors * size_of_intermediate_tensors * sizeof(data_type)] + // Note: need to save all necessary intermediate data to calculate gradients + // Deallocation should perform train/GRU kernel + const size_t data_type_size = sizeof(core::OMDataType(input->type())); + const int32_t time = OMRuntimeShape(input).dims(0); + size_t intermediate_buffer_size = 0; + uint8_t *intermediate_buffer = nullptr; + if (execute_args.is_train_mode) + { + const auto num_operators = runtime_context.getCircleOperators()->size(); + + uint32_t num_train_layers = + execute_args.num_train_layers == 0 ? num_operators : execute_args.num_train_layers; + uint32_t last_node_pos = std::min(num_operators, num_train_layers); + uint32_t last_train_op_index = num_operators - last_node_pos; + + if (execute_args.kernel_index >= last_train_op_index) + { + intermediate_buffer_size = num_of_intermediate_tensors * size_of_intermediate_tensors; + + status = core::memory::OMMemoryManager::allocateMemory( + time * intermediate_buffer_size * data_type_size, &intermediate_buffer); + if (status != Ok) + return status; + + // Save its buffer to state tensor index + runtime_storage.saveDataToTensorIndex(intermediate_buffer, state_tensor_index); + } + } + + switch (input->type()) + { +#ifndef DIS_FLOAT + case circle::TensorType_FLOAT32: + { + status = + pal::GRU(core::utils::castInputData(input_data), + core::utils::castInputData(hidden_input_data), + core::utils::castInputData(hidden_hidden_data), + core::utils::castInputData(hidden_input_bias_data), + core::utils::castInputData(hidden_hidden_bias_data), + core::utils::castInputData(state_data), + core::utils::castOutputData(output_data), + core::utils::castOutputData(output_input_data), + core::utils::castOutputData(output_hidden_data), + core::OMRuntimeShape(input), core::OMRuntimeShape(output), + core::OMRuntimeShape(hidden_input), core::OMRuntimeShape(hidden_hidden), + intermediate_buffer_size, core::utils::castOutputData(intermediate_buffer)); + } + break; +#endif // DIS_FLOAT + default: + { + status = UnsupportedType; + assert(false && "Unsupported type."); + } + } + + core::memory::OMMemoryManager::deallocateMemory(output_input_data); + core::memory::OMMemoryManager::deallocateMemory(output_hidden_data); + + return status; +} diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp new file mode 100644 index 00000000000..d9d49621947 --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/tests/GRU.test.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "execute/OMTestUtils.h" +#include "test_models/gru/FloatGRUKernel.h" + +namespace onert_micro +{ +namespace execute +{ +namespace testing +{ + +using namespace testing; + +class GRUTest : public ::testing::Test +{ + // Do nothing +}; + +TEST_F(GRUTest, Float_P) +{ + onert_micro::test_model::TestDataFloatGRU test_data_kernel; + std::vector output_data_vector = + onert_micro::execute::testing::checkKernel(1, &test_data_kernel); + EXPECT_THAT(output_data_vector, + FloatArrayNear(test_data_kernel.get_output_data_by_index(0), 0.0001f)); +} + +} // namespace testing +} // namespace execute +} // namespace onert_micro diff --git a/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp b/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp index a83efa06fa1..5aad12f25a2 100644 --- a/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp +++ b/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp @@ -34,6 +34,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode) { case circle::BuiltinOperator_FULLY_CONNECTED: case circle::BuiltinOperator_CONV_2D: + case circle::BuiltinOperator_GRU: return true; default: return false; @@ -255,8 +256,9 @@ OMStatus OMExecutionPlanCreator::createForwardExecutionPlan( // of the graph) bool need_to_save_input_data = (index >= last_train_op_indx) and - ((trainable_ops_config.find(index) != trainable_ops_config.end() and - trainable_ops_config[index] != ONLY_BIAS) or + ((trainable_ops_config.empty() or + trainable_ops_config.find(index) != trainable_ops_config.end() and + trainable_ops_config[index] != ONLY_BIAS) or isOpNeedSaveInputData(opcode)); // Flag to determine is current operation needed to save output data (is this op in training diff --git a/onert-micro/onert-micro/src/import/kernels/GRU.cpp b/onert-micro/onert-micro/src/import/kernels/GRU.cpp new file mode 100644 index 00000000000..2c1167c98b5 --- /dev/null +++ b/onert-micro/onert-micro/src/import/kernels/GRU.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OMStatus.h" + +#include "import/OMKernelConfigureBuilder.h" + +#include "core/OMUtils.h" +#include "core/OMKernelData.h" + +#include "execute/OMRuntimeKernel.h" + +using namespace onert_micro; +using namespace onert_micro::core; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t hiddenHiddenTensorIdx = 1; +constexpr uint32_t hiddenHiddenBiasTensorIdx = 2; +constexpr uint32_t hiddenInputTensorIdx = 3; +constexpr uint32_t hiddenInputBiasTensorIdx = 4; +constexpr uint32_t stateTensorIdx = 5; + +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +OMStatus onert_micro::import::configure_kernel_CircleGRU(const OMConfigureArgs &config_args) +{ + core::OMRuntimeContext &runtime_context = config_args.runtime_context; + uint16_t op_index = config_args.kernel_index; + + const circle::Tensor *input; + const circle::Tensor *hidden_hidden; + const circle::Tensor *hidden_hidden_bias; + const circle::Tensor *hidden_input; + const circle::Tensor *hidden_input_bias; + const circle::Tensor *state; + + const circle::Tensor *output; + + // Read kernel + execute::OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + input = runtime_kernel.inputs[inputTensorIdx]; + hidden_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx]; + hidden_hidden_bias = runtime_kernel.inputs[hiddenHiddenBiasTensorIdx]; + hidden_input = runtime_kernel.inputs[hiddenInputTensorIdx]; + hidden_input_bias = runtime_kernel.inputs[hiddenInputBiasTensorIdx]; + state = runtime_kernel.inputs[stateTensorIdx]; + + output = runtime_kernel.outputs[outputTensorIdx]; + assert(input != nullptr); + assert(hidden_hidden != nullptr); + assert(hidden_input != nullptr); + assert(state != nullptr); + // Biases can be nullptr + assert(output != nullptr); + + OMStatus status = Ok; + + OMRuntimeShape hidden_hidden_shape(hidden_hidden); + OMRuntimeShape hidden_input_shape(hidden_input); + OMRuntimeShape output_shape(output); + OMRuntimeShape state_shape(state); + + status = utils::checkCondition(hidden_hidden_shape.dims(0) == hidden_input_shape.dims(0)); + if (status != Ok) + return status; + + const int32_t div_factor = 3; + status = + utils::checkCondition(hidden_hidden_shape.dims(0) == + (div_factor * output_shape.dims(output_shape.dimensionsCount() - 1))); + if (status != Ok) + return status; + + status = utils::checkCondition(output_shape.dims(output_shape.dimensionsCount() - 1) == + state_shape.dims(state_shape.dimensionsCount() - 1)); + if (status != Ok) + return status; + + status = utils::checkCondition(input->type() == output->type()); + if (status != Ok) + return status; + + return status; +}