[onert] Apply ReLU6Grad to ElementwiseActivationLayer
This PR applies ReLU6Grad to ElementwiseActivationLayer.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn <[email protected]>
zetwhite committed Jan 17, 2024
1 parent 2a4d8d6 commit e74df47
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc
@@ -19,6 +19,7 @@
 #include "OperationUtils.h"
 
 #include <cker/train/operation/ReLU.h>
+#include <cker/train/operation/ReLU6.h>
 
 namespace onert
 {
@@ -54,22 +55,31 @@ void ElementwiseActivationLayer::configure(const IPortableTensor *input, IPortab
     case ElementwiseActivationType::kReLU:
       if (input->data_type() == OperandType::FLOAT32)
       {
-        if (alpha == std::numeric_limits<float>::infinity() && beta == 0.f)
+        if ((alpha == std::numeric_limits<float>::infinity() || alpha == 6.0f) && beta == 0.f)
         {
           cpu::ops::ElementwiseActivationLayer::configure(
             input, output, alpha, beta, cpu::ops::ElementwiseActivationType::kReLU);
 
-          _backward_kernel = [](const IPortableTensor *output, const IPortableTensor *incoming,
-                                IPortableTensor *outgoing) {
-            nnfw::cker::train::ReLUGrad(getShape(output), getBuffer<float>(output),
-                                        getShape(incoming), getBuffer<float>(incoming),
-                                        getShape(outgoing), getBuffer<float>(outgoing));
+          auto relu_cker = [&alpha]() {
+            if (alpha == std::numeric_limits<float>::infinity())
+              return nnfw::cker::train::ReLUGrad;
+            else if (alpha == 6.0f)
+              return nnfw::cker::train::ReLU6Grad;
+            else
+              throw std::runtime_error{"no supported relu kernel"};
+          }();
+
+          _backward_kernel = [relu_cker](const IPortableTensor *output,
+                                         const IPortableTensor *incoming,
+                                         IPortableTensor *outgoing) {
+            relu_cker(getShape(output), getBuffer<float>(output), getShape(incoming),
+                      getBuffer<float>(incoming), getShape(outgoing), getBuffer<float>(outgoing));
           };
         }
         else
         {
           throw std::runtime_error("train ElementwiseActivationLayer : This layer does not "
-                                   "suppport other ReLU except for ReLU(0-inf)");
+                                   "suppport other ReLU except for ReLU(0-inf) and ReLU6(0-6)");
         }
       }
       else
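For context: ReLU6 clamps activations to the range [0, 6], so its backward pass propagates the incoming gradient only where the forward output lies strictly between 0 and 6, and zeroes it where the activation was clamped. The standalone sketch below illustrates that rule; the function name, flat-buffer signature, and loop are illustrative assumptions and do not reproduce the Shape/buffer-based signature of nnfw::cker::train::ReLU6Grad that the diff dispatches to.

#include <cstddef>

// Sketch only: elementwise backward rule assumed for ReLU6.
// outgoing = incoming where 0 < output < 6 (not clamped), otherwise 0.
void relu6_grad_sketch(const float *output, const float *incoming, float *outgoing,
                       std::size_t size)
{
  for (std::size_t i = 0; i < size; ++i)
  {
    const bool not_clamped = output[i] > 0.0f && output[i] < 6.0f;
    outgoing[i] = not_clamped ? incoming[i] : 0.0f;
  }
}

Because ReLUGrad and ReLU6Grad share the same Shape/buffer parameter list in the diff, the immediately invoked lambda can return either kernel as a plain function pointer, and the stored _backward_kernel invokes whichever one matches alpha.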
