diff --git a/onert-micro/eval-driver/TrainingDriver.cpp b/onert-micro/eval-driver/TrainingDriver.cpp index 706ad7ecf88..11d2d399b91 100644 --- a/onert-micro/eval-driver/TrainingDriver.cpp +++ b/onert-micro/eval-driver/TrainingDriver.cpp @@ -179,12 +179,10 @@ int entry(int argc, char **argv) config.wof_ptr = nullptr; // Set user defined training settings - const uint32_t training_epochs = 30; + const uint32_t training_epochs = 50; const float lambda = 0.001f; - const uint32_t BATCH_SIZE = 32; - const uint32_t INPUT_SIZE = 180; - const uint32_t OUTPUT_SIZE = 4; - const uint32_t num_train_layers = 10; + const uint32_t BATCH_SIZE = 64; + const uint32_t num_train_layers = 4; const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY; const onert_micro::OMTrainOptimizer train_optim = onert_micro::ADAM; const float beta = 0.9; @@ -211,6 +209,9 @@ int entry(int argc, char **argv) onert_micro::OMTrainingInterpreter train_interpreter; train_interpreter.importTrainModel(circle_model.data(), config); + const uint32_t OUTPUT_SIZE = train_interpreter.getOutputSizeAt(0); + const uint32_t INPUT_SIZE = train_interpreter.getInputSizeAt(0); + // Temporary buffer to read input data from file using BATCH_SIZE float training_input[BATCH_SIZE * INPUT_SIZE]; float training_target[BATCH_SIZE * OUTPUT_SIZE]; @@ -263,11 +264,11 @@ int entry(int argc, char **argv) if (CLASSIFICATION_TASK) { // Evaluate cross_entropy and accuracy metrics - train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS, + train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS, reinterpret_cast(&cross_entropy_metric), cur_batch_size); - train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast(&accuracy), - cur_batch_size); + train_interpreter.evaluateMetric(config, onert_micro::ACCURACY, + reinterpret_cast(&accuracy), cur_batch_size); // Save them into vectors accuracy_v.push_back(accuracy); @@ -276,10 +277,10 @@ int entry(int argc, char **argv) else { // Evaluate mse and mae metrics - train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast(&mse), - cur_batch_size); - train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast(&mae), - cur_batch_size); + train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS, + reinterpret_cast(&mse), cur_batch_size); + train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS, + reinterpret_cast(&mae), cur_batch_size); // Save them into vectors mse_v.push_back(mse); @@ -335,11 +336,11 @@ int entry(int argc, char **argv) if (CLASSIFICATION_TASK) { // Evaluate cross_entropy and accuracy metrics - train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS, + train_interpreter.evaluateMetric(config, onert_micro::CROSS_ENTROPY_METRICS, reinterpret_cast(&cross_entropy_metric), cur_batch_size); - train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast(&accuracy), - cur_batch_size); + train_interpreter.evaluateMetric(config, onert_micro::ACCURACY, + reinterpret_cast(&accuracy), cur_batch_size); // Save them into vectors accuracy_v.push_back(accuracy); @@ -348,10 +349,10 @@ int entry(int argc, char **argv) else { // Evaluate mse and mae metrics - train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast(&mse), - cur_batch_size); - train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast(&mae), - cur_batch_size); + train_interpreter.evaluateMetric(config, onert_micro::MSE_METRICS, + reinterpret_cast(&mse), cur_batch_size); + train_interpreter.evaluateMetric(config, onert_micro::MAE_METRICS, + reinterpret_cast(&mae), cur_batch_size); // Save them into vectors mse_v.push_back(mse); diff --git a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h index 5f134f0e0df..f8a2e2ca9fd 100644 --- a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h +++ b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h @@ -30,6 +30,7 @@ namespace train namespace pal { +// Note: dloss_dweight_data should be initialized void inline FullyConnectedWeightGrad( const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *input_data, const core::OMRuntimeShape &input_shape, float *dloss_dweight_data, @@ -48,7 +49,7 @@ void inline FullyConnectedWeightGrad( float cur_dloss_doutput = dloss_doutput_data[o + depth_bounds.first]; for (uint32_t i = 0; i < accum_depth; ++i) { - dloss_dweight_data[i + o * accum_depth] = cur_dloss_doutput * input_data[i]; + dloss_dweight_data[i + o * accum_depth] += cur_dloss_doutput * input_data[i]; } } diff --git a/onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h b/onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h new file mode 100644 index 00000000000..1f3cd750c29 --- /dev/null +++ b/onert-micro/onert-micro/include/pal/common/PALGRUWeightGrad.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H +#define ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H + +#include "OMStatus.h" +#include "core/OMRuntimeShape.h" +#include "core/OMKernelType.h" + +#include "PALUtils.h" +#include "ProcessBroadcastShapes.h" +#include "PALFullyConnectedWeightGrad.h" + +namespace onert_micro +{ +namespace train +{ +namespace pal +{ +namespace +{ + +void calculateGRUWeightGrads( + const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data, + const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data, + float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data, + const float *input_data, float *input_grad_data, float *state_grad_data, + const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_fc_shape, + const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape, + const core::OMRuntimeShape &weight_hidden_shape, float *output_data, float *left_logistic_data, + float *left_mul_data, float *right_logistic_data, const float *right_mul_left_input_data, + const float *right_mul_right_input_data, float *tanh_data, const float *middle_mul_left_input, + const float *middle_mul_right_input, float *left_fc_output_grad_buffer, + float *right_fc_output_grad_buffer) +{ + int num_elements = output_shape.flatSize(); + for (int i = 0; i < num_elements; ++i) + { + // Middle Mul left input grad + float left_middle_mul = output_grad_data[i]; + left_middle_mul *= middle_mul_right_input[i]; + + // Middle Mul right input grad + float right_middle_mul = output_grad_data[i]; + right_middle_mul *= middle_mul_left_input[i]; + + // Tanh` = 1 / (cos(x) ^ 2) + float tanh_grad_value; + { + float tanh = std::tanh(tanh_data[i]); + tanh_grad_value = (1 - tanh * tanh) * right_middle_mul; + } + + // Left mul + float left_mul_grad_value = output_grad_data[i] * output_data[i]; + + // Sub` = -1 + // Left Logistic: Logistic` = (exp(-x) * (1 / (1 + exp(-x))) ^ 2) + float left_logistic_grad_value; + { + float log_value = (1 / (1 + std::exp(-left_logistic_data[i]))); + left_logistic_grad_value = + log_value * (1 - log_value) * (left_middle_mul + left_mul_grad_value); + } + + // Right mul left input + float right_mul_left_input = tanh_grad_value; + right_mul_left_input *= right_mul_right_input_data[i]; + + // Right mul right input + float right_mul_right_input = tanh_grad_value; + right_mul_right_input *= right_mul_left_input_data[i]; + + // Right logistic + float right_logistic_grad_value; + { + float log_value = (1 / (1 + std::exp(-right_logistic_data[i]))); + right_logistic_grad_value = log_value * (1 - log_value) * right_mul_left_input; + } + + // Left concatenation + left_fc_output_grad_buffer[i] = left_logistic_grad_value; + left_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value; + left_fc_output_grad_buffer[i + 2 * num_elements] = right_mul_right_input; + + // Right concatenation + right_fc_output_grad_buffer[i] = left_logistic_grad_value; + right_fc_output_grad_buffer[i + num_elements] = right_logistic_grad_value; + right_fc_output_grad_buffer[i + 2 * num_elements] = tanh_grad_value; + } + + // Left fc weight grad + FullyConnectedWeightGrad(left_fc_output_grad_buffer, output_fc_shape, output_data, output_shape, + weight_input_grad_data, weight_input_shape, + core::OpTrainableRankType::ALL); + // Right fc weight grad + FullyConnectedWeightGrad(right_fc_output_grad_buffer, output_fc_shape, input_data, input_shape, + weight_hidden_grad_data, weight_hidden_shape, + core::OpTrainableRankType::ALL); + + // Set state grad to zero + std::memset(state_grad_data, 0, output_shape.flatSize() * sizeof(float)); +} + +} // namespace + +OMStatus GRUWeightGrads( + const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data, + const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data, + float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data, + const float *input_data, float *input_grad_data, float *state_grad_data, + const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, + const core::OMRuntimeShape &weight_input_shape, const core::OMRuntimeShape &weight_hidden_shape, + const core::OMRuntimeShape &output_shape_fc, float *intermediate_buffer, + float *left_fc_output_grad_buffer, float *right_fc_output_grad_buffer) +{ + const int32_t time = input_shape.dims(0); + + // Init pointers to intermediate values + size_t offset = output_shape.flatSize(); + + size_t data_type_size = sizeof(float); + const int32_t num_of_intermediate_tensors = 9; + size_t time_offset = num_of_intermediate_tensors * offset; + + core::OMRuntimeShape two_dim_input_shape(2); + auto dim_count = input_shape.dimensionsCount(); + if (dim_count < 2) + return UnsupportedType; + two_dim_input_shape.setDim(0, input_shape.dims(dim_count - 2)); + two_dim_input_shape.setDim(1, input_shape.dims(dim_count - 1)); + + core::OMRuntimeShape two_dim_output_shape(2); + dim_count = output_shape.dimensionsCount(); + if (dim_count < 2) + return UnsupportedType; + two_dim_output_shape.setDim(0, output_shape.dims(dim_count - 2)); + two_dim_output_shape.setDim(1, output_shape.dims(dim_count - 1)); + + std::memset(weight_input_grad_data, 0, output_shape.flatSize() * sizeof(float) * time); + std::memset(weight_hidden_grad_data, 0, input_shape.dims(2) * sizeof(float) * time); + + for (int i = 0; i < time; ++i) + { + float *output_data = intermediate_buffer; + float *left_logistic_data = output_data + offset; + float *left_mul_data = left_logistic_data + offset; + float *right_logistic_data = left_mul_data + offset; + float *right_mul_left_input_data = right_logistic_data + offset; + float *right_mul_right_input_data = right_mul_left_input_data + offset; + float *tanh_data = right_mul_right_input_data + offset; + float *middle_mul_left_input = tanh_data + offset; + float *middle_mul_right_input = middle_mul_left_input + offset; + + calculateGRUWeightGrads( + output_grad_data, weight_input_data, weight_input_grad_data, weight_hidden_data, + weight_hidden_grad_data, bias_input_data, bias_input_grad_data, bias_hidden_data, + bias_hidden_grad_data, input_data, input_grad_data, state_grad_data, two_dim_input_shape, + output_shape_fc, two_dim_output_shape, weight_input_shape, weight_hidden_shape, output_data, + left_logistic_data, left_mul_data, right_logistic_data, right_mul_left_input_data, + right_mul_right_input_data, tanh_data, middle_mul_left_input, middle_mul_right_input, + left_fc_output_grad_buffer, right_fc_output_grad_buffer); + input_data += input_shape.dims(2); + intermediate_buffer += time_offset; + } + return Ok; +} + +} // namespace pal +} // namespace train +} // namespace onert_micro + +#endif // ONERT_MICRO_EXECUTE_PAL_GRU_WEIGHT_GRAD_COMMON_H diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst index 26476df574f..508cd246dbe 100644 --- a/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst +++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToTrain.lst @@ -3,3 +3,5 @@ REGISTER_TRAIN_KERNEL(SOFTMAX, Softmax) REGISTER_TRAIN_KERNEL(RESHAPE, Reshape) REGISTER_TRAIN_KERNEL(CONV_2D, Conv2D) REGISTER_TRAIN_KERNEL(MAX_POOL_2D, MaxPool2D) +REGISTER_TRAIN_KERNEL(GRU, GRU) +REGISTER_TRAIN_KERNEL(STRIDED_SLICE, StridedSlice) diff --git a/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp b/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp index db61df5a2b1..35c130e74ba 100644 --- a/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp +++ b/onert-micro/onert-micro/src/core/train/OMCheckpointLoader.cpp @@ -49,6 +49,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode) { case circle::BuiltinOperator_FULLY_CONNECTED: case circle::BuiltinOperator_CONV_2D: + case circle::BuiltinOperator_GRU: return true; default: return false; diff --git a/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp b/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp index ca32c5eeec6..6d3f781d851 100644 --- a/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp +++ b/onert-micro/onert-micro/src/core/train/OMCheckpointSaver.cpp @@ -34,6 +34,7 @@ bool isTrainableWeights(const circle::OperatorCode *opcode) { case circle::BuiltinOperator_FULLY_CONNECTED: case circle::BuiltinOperator_CONV_2D: + case circle::BuiltinOperator_GRU: return true; default: return false; diff --git a/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp b/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp index cd37f875a8c..a8f0dd3ac9f 100644 --- a/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp +++ b/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp @@ -93,7 +93,12 @@ OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecut args.is_last_layer = false; } - if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end()) + if (trainable_ops_config.empty()) + { + args.is_trainable_layer = true; + args.train_rank_type = core::OpTrainableRankType::ALL; + } + else if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end()) { args.is_trainable_layer = true; args.train_rank_type = core::OpTrainableRankType(trainable_ops_config[cur_op_index]); diff --git a/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp b/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp index 7190422c3d8..266d5c4b51b 100644 --- a/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp +++ b/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp @@ -152,6 +152,9 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE weight_shape = dynamic_shapes; // 2. Calculate weight gradient + // Init weight grads with zeros + std::memset(dloss_dweight_data, 0, + output_shape.dims(1) * input_shape.dims(1) * sizeof(float)); pal::FullyConnectedWeightGrad( core::utils::castInputData(dloss_doutput_data), output_shape, core::utils::castInputData(input_data), input_shape, diff --git a/onert-micro/onert-micro/src/train/kernels/GRU.cpp b/onert-micro/onert-micro/src/train/kernels/GRU.cpp new file mode 100644 index 00000000000..6e274d996cc --- /dev/null +++ b/onert-micro/onert-micro/src/train/kernels/GRU.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OMStatus.h" + +#include "core/OMUtils.h" +#include "core/OMDataType.h" +#include "core/memory/OMMemoryManager.h" + +#include "train/OMBackpropExecutionBuilder.h" +#include "execute/OMRuntimeKernel.h" + +#include "PALGRUWeightGrad.h" + +using namespace onert_micro; +using namespace onert_micro::train; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t hiddenHiddenTensorIdx = 1; +constexpr uint32_t hiddenHiddenBiasTensorIdx = 2; +constexpr uint32_t hiddenInputTensorIdx = 3; +constexpr uint32_t hiddenInputBiasTensorIdx = 4; +constexpr uint32_t stateTensorIdx = 5; + +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +OMStatus onert_micro::train::train_kernel_CircleGRU(const OMBackpropExecuteArgs &args) +{ + // Check is it last layer for training + core::OMRuntimeContext &runtime_context = args.backward_context; + core::OMRuntimeStorage &backward_storage = args.backward_storage; + core::OMRuntimeStorage &forward_storage = args.forward_storage; + uint16_t op_index = args.kernel_index; + + execute::OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx]; + const circle::Tensor *weight_input = runtime_kernel.inputs[hiddenInputTensorIdx]; + const circle::Tensor *weight_hidden = runtime_kernel.inputs[hiddenHiddenTensorIdx]; + const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx]; + + assert(input != nullptr); + assert(output != nullptr); + + OMStatus status = Ok; + + // Read forward + status = runtime_kernel.getDataFromStorage(op_index, forward_storage, runtime_context); + if (status != Ok) + return status; + uint8_t *input_data = runtime_kernel.inputs_data[inputTensorIdx]; + uint8_t *weight_input_data = runtime_kernel.inputs_data[hiddenInputTensorIdx]; + uint8_t *weight_hidden_data = runtime_kernel.inputs_data[hiddenHiddenTensorIdx]; + uint8_t *bias_input_data = runtime_kernel.inputs_data[hiddenInputBiasTensorIdx]; + uint8_t *bias_hidden_data = runtime_kernel.inputs_data[hiddenHiddenBiasTensorIdx]; + uint8_t *intermediate_buffer = runtime_kernel.inputs_data[stateTensorIdx]; + // Bias_data can be nullptr + assert(input_data != nullptr); + assert(weight_input_data != nullptr); + assert(weight_hidden_data != nullptr); + assert(intermediate_buffer != nullptr); + + // Read backward + status = runtime_kernel.getDataFromStorage(op_index, backward_storage, runtime_context); + uint8_t *output_grad_data = runtime_kernel.outputs_data[outputTensorIdx]; + uint8_t *weight_input_grad_data = runtime_kernel.inputs_data[hiddenInputTensorIdx]; + uint8_t *weight_hidden_grad_data = runtime_kernel.inputs_data[hiddenHiddenTensorIdx]; + uint8_t *bias_input_grad_data = runtime_kernel.inputs_data[hiddenInputBiasTensorIdx]; + uint8_t *bias_hidden_grad_data = runtime_kernel.inputs_data[hiddenHiddenBiasTensorIdx]; + uint8_t *state_grad_data = runtime_kernel.inputs_data[stateTensorIdx]; + uint8_t *input_grad_data = runtime_kernel.inputs_data[inputTensorIdx]; + // Bias_data and input_grad_data can be nullptr + // Note: input_grad_data can be nullptr due to it can be last trainable node + assert(output_grad_data != nullptr); + assert(weight_input_grad_data != nullptr); + assert(weight_hidden_grad_data != nullptr); + assert(state_grad_data != nullptr); + + // Obtain shapes + core::OMRuntimeShape input_shape(input); + core::OMRuntimeShape output_shape(output); + core::OMRuntimeShape weight_input_shape(weight_input); + core::OMRuntimeShape weight_hidden_shape(weight_hidden); + + // Init output shape for FullyConnected layers + core::OMRuntimeShape output_shape_fc(2); + output_shape_fc.setDim(0, 1); + output_shape_fc.setDim(1, weight_hidden_shape.dims(0)); + + // Allocate memory for outputs temporary gradients for FullyConnected layers + uint8_t *left_fc_output_grad_buffer; + uint8_t *right_fc_output_grad_buffer; + // Checking during import + assert(weight_hidden_shape.dims(0) == weight_input_shape.dims(0)); + size_t allocation_size = sizeof(core::OMDataType(input->type())) * weight_hidden_shape.dims(0); + status = + core::memory::OMMemoryManager::allocateMemory(allocation_size, &left_fc_output_grad_buffer); + if (status != Ok) + return status; + status = + core::memory::OMMemoryManager::allocateMemory(allocation_size, &right_fc_output_grad_buffer); + if (status != Ok) + return status; + + assert(left_fc_output_grad_buffer != nullptr and right_fc_output_grad_buffer != nullptr); + + // Currently support only float training + if (input->type() != circle::TensorType_FLOAT32) + return UnsupportedType; + + status = + pal::GRUWeightGrads(core::utils::castInputData(output_grad_data), + core::utils::castInputData(weight_input_data), + core::utils::castOutputData(weight_input_grad_data), + core::utils::castInputData(weight_hidden_data), + core::utils::castOutputData(weight_hidden_grad_data), + core::utils::castInputData(bias_input_data), + core::utils::castOutputData(bias_input_grad_data), + core::utils::castInputData(bias_hidden_data), + core::utils::castOutputData(bias_hidden_grad_data), + core::utils::castInputData(input_data), + core::utils::castOutputData(input_grad_data), + core::utils::castOutputData(state_grad_data), input_shape, + output_shape, weight_input_shape, weight_hidden_shape, output_shape_fc, + core::utils::castOutputData(intermediate_buffer), + core::utils::castOutputData(left_fc_output_grad_buffer), + core::utils::castOutputData(right_fc_output_grad_buffer)); + + // TODO: add input grads calculation + + // Deallocate + core::memory::OMMemoryManager::deallocateMemory(intermediate_buffer); + core::memory::OMMemoryManager::deallocateMemory(left_fc_output_grad_buffer); + core::memory::OMMemoryManager::deallocateMemory(right_fc_output_grad_buffer); + + forward_storage.removeTensorFromTensorIndexToData(runtime_kernel.inputs_index[stateTensorIdx]); + + return status; +} diff --git a/onert-micro/onert-micro/src/train/kernels/StridedSlice.cpp b/onert-micro/onert-micro/src/train/kernels/StridedSlice.cpp new file mode 100644 index 00000000000..b2a8ff4fa14 --- /dev/null +++ b/onert-micro/onert-micro/src/train/kernels/StridedSlice.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OMStatus.h" + +#include "core/OMUtils.h" +#include "core/OMDataType.h" + +#include "train/OMBackpropExecutionBuilder.h" +#include "execute/OMRuntimeKernel.h" + +using namespace onert_micro; +using namespace onert_micro::train; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +/* + * - Calculate input gradient - Optional (not required if it is last op) + * Note: now support when it is just reshape, number of output tensor is one and flat size of the + * output tensor is equal to input + */ +// TODO: support more general part +OMStatus onert_micro::train::train_kernel_CircleStridedSlice(const OMBackpropExecuteArgs &args) +{ + // Check is it last layer for training + if (args.is_last_layer) + return Ok; + + core::OMRuntimeContext &runtime_context = args.backward_context; + core::OMRuntimeStorage &runtime_storage = args.backward_storage; + uint16_t op_index = args.kernel_index; + + execute::OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx]; + const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx]; + + assert(input != nullptr); + assert(output != nullptr); + + // Note: now support when it is just reshape, number of output tensor is one and flat size of the + // output tensor is equal to input + assert(runtime_kernel.outputs_num == 1); + const core::OMRuntimeShape shape(input); + const core::OMRuntimeShape output_shape(input); + assert(shape.flatSize() == output_shape.flatSize()); + if (runtime_kernel.outputs_num > 1 or shape.flatSize() != output_shape.flatSize()) + return UnsupportedType; + + OMStatus status = Ok; + + status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); + if (status != Ok) + return status; + + uint8_t *input_data = runtime_kernel.inputs_data[inputTensorIdx]; + uint8_t *output_data = runtime_kernel.outputs_data[outputTensorIdx]; + + assert(input_data != nullptr); + assert(output_data != nullptr); + + // Check is it inplace kernel + if (input_data == output_data) + return Ok; + + const size_t element_size = + static_cast(getOMDataTypeSize(core::onertMicroDatatype(input->type()))); + const int32_t num_elements = shape.flatSize(); + std::memcpy(input_data, output_data, num_elements * element_size); + + return status; +}