From a2b956104a93f11ce7f4fd96a2cd2047a43ce877 Mon Sep 17 00:00:00 2001
From: SeungHui Youn <61981457+zetwhite@users.noreply.github.com>
Date: Wed, 17 Jan 2024 13:22:59 +0900
Subject: [PATCH] [onert] Introduce backpropActivation to OperationUtils

This PR introduces a backpropActivation function to OperationUtils.
This function calls the proper cker kernel according to ir::Activation.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn

---
 .../onert/backend/train/ops/OperationUtils.cc | 66 +++++++++++++++++++
 .../onert/backend/train/ops/OperationUtils.h  | 25 +++++++
 2 files changed, 91 insertions(+)
 create mode 100644 runtime/onert/backend/train/ops/OperationUtils.cc

diff --git a/runtime/onert/backend/train/ops/OperationUtils.cc b/runtime/onert/backend/train/ops/OperationUtils.cc
new file mode 100644
index 00000000000..086b57d1ae5
--- /dev/null
+++ b/runtime/onert/backend/train/ops/OperationUtils.cc
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OperationUtils.h"
+
+#include <cker/train/operation/ReLU.h>
+#include <cker/train/operation/ReLU6.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace train
+{
+namespace ops
+{
+
+const IPortableTensor *backpropActivation(const ir::Activation &activation,
+                                          const IPortableTensor *output,
+                                          const IPortableTensor *input_backprop,
+                                          IPortableTensor *output_backprop)
+{
+  // handle NONE
+  if (activation == ir::Activation::NONE)
+  {
+    // just propagate the incoming gradient
+    return input_backprop;
+  }
+
+  // handle other activations
+  assert(output_backprop != nullptr);
+  switch (activation)
+  {
+    case ir::Activation::RELU:
+      nnfw::cker::train::ReLUGrad(getShape(output), getBuffer<float>(output),
+                                  getShape(input_backprop), getBuffer<float>(input_backprop),
+                                  getShape(output_backprop), getBuffer<float>(output_backprop));
+      break;
+    case ir::Activation::RELU6:
+      nnfw::cker::train::ReLU6Grad(getShape(output), getBuffer<float>(output),
+                                   getShape(input_backprop), getBuffer<float>(input_backprop),
+                                   getShape(output_backprop), getBuffer<float>(output_backprop));
+      break;
+    default:
+      throw std::runtime_error("Unsupported activation type");
+  }
+  return output_backprop;
+}
+
+} // namespace ops
+} // namespace train
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/backend/train/ops/OperationUtils.h b/runtime/onert/backend/train/ops/OperationUtils.h
index 470b0fd91c9..2866a1c4b98 100644
--- a/runtime/onert/backend/train/ops/OperationUtils.h
+++ b/runtime/onert/backend/train/ops/OperationUtils.h
@@ -36,6 +36,31 @@ using cpu::ops::getNumberOfDimensions;
 using cpu::ops::getNumberOfElements;
 using cpu::ops::getSizeOfDimension;
 
+/**
+ * @brief backpropagate activation
+ *
+ *        -- forward direction -->
+ *
+ *   [ current layer ]  ----  [ next layer ]
+ *   [ op    |    act ]
+ *
+ *        <-- backward direction --
+ *
+ * @param activation     activation of the current layer
+ * @param output         forward direction's output of the current layer
+ * @param input_backprop backward direction's output of the next layer;
+ *                       in other words, the incoming gradient to the current layer
+ * @param output_backprop backward direction's output of the activation;
+ *                        in other words, the outgoing gradient of the current layer's activation.
+ *                        If activation is NONE, this parameter is not needed.
+ * @return tensor that holds the backpropagation result of the activation;
+ *         if activation is NONE, input_backprop is returned as-is
+ */
+const IPortableTensor *backpropActivation(const ir::Activation &activation,
+                                          const IPortableTensor *output,
+                                          const IPortableTensor *input_backprop,
+                                          IPortableTensor *output_backprop = nullptr);
+
 } // namespace ops
 } // namespace train
 } // namespace backend
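
Reviewer note (not part of the patch): below is a minimal caller sketch of how a layer's
backward pass might use this helper. The function applyActivationBackward and the names
act, fwd_output, incoming_grad, and act_grad are illustrative assumptions for this note
only; none of them are introduced by the patch.

// Sketch only; assumes the header added by this patch is on the include path.
#include "OperationUtils.h"

namespace
{
using namespace onert::backend::train::ops;

// Hypothetical backward step of a layer whose op is fused with an activation.
const onert::backend::IPortableTensor *
applyActivationBackward(const onert::ir::Activation act,
                        const onert::backend::IPortableTensor *fwd_output,
                        const onert::backend::IPortableTensor *incoming_grad,
                        onert::backend::IPortableTensor *act_grad)
{
  // For NONE the helper returns incoming_grad unchanged and never touches
  // output_backprop, so relying on the defaulted nullptr is safe here.
  if (act == onert::ir::Activation::NONE)
    return backpropActivation(act, fwd_output, incoming_grad);

  // For RELU/RELU6 the gradient is masked using the forward output and
  // written into act_grad, which is also the tensor that gets returned.
  return backpropActivation(act, fwd_output, incoming_grad, act_grad);
}
} // namespace

Whatever pointer comes back is what the op's own gradient computation should consume;
returning input_backprop directly in the NONE case avoids an extra tensor copy.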