diff --git a/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc index 5015e6959b5..bd4734f7ae2 100644 --- a/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc +++ b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc @@ -78,8 +78,8 @@ void ElementwiseActivationLayer::configure(const IPortableTensor *input, IPortab } else { - throw std::runtime_error("train ElementwiseActivationLayer : This layer does not " - "suppport other ReLU except for ReLU(0-inf) and ReLU6(0-6)"); + throw std::runtime_error( + "train ElementwiseActivationLayer : Unsupported ReLU activation type"); } } else diff --git a/runtime/onert/backend/train/ops/OperationUtils.cc b/runtime/onert/backend/train/ops/OperationUtils.cc new file mode 100644 index 00000000000..47335974615 --- /dev/null +++ b/runtime/onert/backend/train/ops/OperationUtils.cc @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OperationUtils.h" + +#include +#include + +namespace onert +{ +namespace backend +{ +namespace train +{ +namespace ops +{ + +const IPortableTensor *backpropActivation(const ir::Activation &activation, + const IPortableTensor *output, + const IPortableTensor *input_backprop, + IPortableTensor *output_backprop) +{ + if(activation == ir::Activation::NONE) + assert(output_backprop != nullptr); + + const IPortableTensor *res; + switch (activation) + { + case ir::Activation::NONE: + res = input_backprop; + break; + case ir::Activation::RELU: + nnfw::cker::train::ReLUGrad(getShape(output), getBuffer(output), + getShape(input_backprop), getBuffer(input_backprop), + getShape(output_backprop), getBuffer(output_backprop)); + res = output_backprop; + break; + case ir::Activation::RELU6: + nnfw::cker::train::ReLU6Grad(getShape(output), getBuffer(output), + getShape(input_backprop), getBuffer(input_backprop), + getShape(output_backprop), getBuffer(output_backprop)); + res = output_backprop; + break; + default: + throw std::runtime_error("Unsupported activation type yet"); + } + + return res; +} + +} // namespace ops +} // namespace train +} // namespace backend +} // namespace onert + + \ No newline at end of file