diff --git a/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc
index a7a4d412576..5015e6959b5 100644
--- a/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc
+++ b/runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc
@@ -19,6 +19,7 @@
 #include "OperationUtils.h"
 
 #include <cker/train/operation/ReLU.h>
+#include <cker/train/operation/ReLU6.h>
 
 namespace onert
 {
@@ -54,22 +55,31 @@ void ElementwiseActivationLayer::configure(const IPortableTensor *input, IPortab
     case ElementwiseActivationType::kReLU:
       if (input->data_type() == OperandType::FLOAT32)
       {
-        if (alpha == std::numeric_limits<float>::infinity() && beta == 0.f)
+        if ((alpha == std::numeric_limits<float>::infinity() || alpha == 6.0f) && beta == 0.f)
         {
           cpu::ops::ElementwiseActivationLayer::configure(
             input, output, alpha, beta, cpu::ops::ElementwiseActivationType::kReLU);
 
-          _backward_kernel = [](const IPortableTensor *output, const IPortableTensor *incoming,
-                                IPortableTensor *outgoing) {
-            nnfw::cker::train::ReLUGrad(getShape(output), getBuffer(output),
-                                        getShape(incoming), getBuffer(incoming),
-                                        getShape(outgoing), getBuffer(outgoing));
+          auto relu_cker = [&alpha]() {
+            if (alpha == std::numeric_limits<float>::infinity())
+              return nnfw::cker::train::ReLUGrad;
+            else if (alpha == 6.0f)
+              return nnfw::cker::train::ReLU6Grad;
+            else
+              throw std::runtime_error{"no supported relu kernel"};
+          }();
+
+          _backward_kernel = [relu_cker](const IPortableTensor *output,
+                                         const IPortableTensor *incoming,
+                                         IPortableTensor *outgoing) {
+            relu_cker(getShape(output), getBuffer(output), getShape(incoming),
+                      getBuffer(incoming), getShape(outgoing), getBuffer(outgoing));
           };
         }
         else
         {
           throw std::runtime_error("train ElementwiseActivationLayer : This layer does not "
-                                   "suppport other ReLU except for ReLU(0-inf)");
+                                   "support other ReLU except for ReLU(0-inf) and ReLU6(0-6)");
         }
       }
       else
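
For context (not part of the patch): below is a minimal, self-contained sketch of the element-wise rule a ReLU6 backward kernel such as nnfw::cker::train::ReLU6Grad is expected to follow — the upstream gradient passes through only where the forward output lies strictly inside (0, 6) and is zeroed at both saturation ends. The function name ReLU6GradSketch, the raw-pointer signature, and the boundary handling at exactly 0 and 6 are illustrative assumptions, not the actual cker API, which takes Shape objects and typed buffers as in the diff above.

```cpp
#include <cstddef>
#include <iostream>

// outgoing[i] = incoming[i] where the forward output was strictly between 0 and 6,
// and 0 elsewhere (the ReLU6 clamp zeroes the gradient at both saturation points).
void ReLU6GradSketch(const float *output, const float *incoming, float *outgoing,
                     std::size_t size)
{
  for (std::size_t i = 0; i < size; ++i)
    outgoing[i] = (output[i] > 0.0f && output[i] < 6.0f) ? incoming[i] : 0.0f;
}

int main()
{
  const float fwd_output[4] = {0.0f, 3.0f, 6.0f, 5.5f}; // forward ReLU6 outputs
  const float incoming[4] = {1.0f, 1.0f, 1.0f, 2.0f};   // upstream gradients
  float outgoing[4] = {};

  ReLU6GradSketch(fwd_output, incoming, outgoing, 4);

  // Expected: 0 (clamped at 0), 1 (pass-through), 0 (clamped at 6), 2 (pass-through)
  for (float g : outgoing)
    std::cout << g << ' ';
  std::cout << '\n';
}
```

The dispatch in the patch reuses the same forward configuration (cpu::ops ReLU) and only swaps the backward function pointer, picking ReLUGrad for alpha == infinity (plain ReLU) and ReLU6Grad for alpha == 6.0f.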