diff --git a/compute/cker/include/cker/train/operation/ReLU6.h b/compute/cker/include/cker/train/operation/ReLU6.h index c615a6d69cb..a66b0c4ff97 100644 --- a/compute/cker/include/cker/train/operation/ReLU6.h +++ b/compute/cker/include/cker/train/operation/ReLU6.h @@ -32,16 +32,16 @@ inline void ReLU6Grad(const Shape &output_shape, const float *output_data, const Shape &incoming_shape, const float *incoming_data, const Shape &grad_shape, float *grad_data) { - const auto output_map = MapAsVector(output_data, output_shape); - const auto incoming_map = MapAsVector(incoming_data, incoming_shape); - auto grad_map = MapAsVector(grad_data, grad_shape); - - if (output_shape == incoming_shape && output_shape == grad_shape) - grad_map.array() = - incoming_map.array() * - (0.0f < output_map.array() && output_map.array() < 6.0f).template cast(); - else - throw std::runtime_error("cker::ReLU6Grad: Unsupported shape"); + const auto &output_map = MapAsVector(output_data, output_shape); + const auto &incoming_map = MapAsVector(incoming_data, incoming_shape); + auto &&grad_map = MapAsVector(grad_data, grad_shape); + + if (output_shape != incoming_shape || output_shape != grad_shape) + throw std::runtime_error{"cker::ReLU6Grad: Unsupported shape"}; + + grad_map.array() = + incoming_map.array() * + (0.0f < output_map.array() && output_map.array() < 6.0f).template cast(); } } // namespace train diff --git a/compute/cker/src/train/Relu6.test.cc b/compute/cker/src/train/Relu6.test.cc index 3eb9f042f7b..ca8a3fc226e 100644 --- a/compute/cker/src/train/Relu6.test.cc +++ b/compute/cker/src/train/Relu6.test.cc @@ -53,7 +53,7 @@ TEST(CKer_Operation, ReLU6) // clang-format off // std::vector input_fwd = {-2.0, -1.0, 2.0, 3.0, 6.0, 7.0}; - std::vector output_fwd = { 0.0, 0.0, 2.0, 3.0, 6.0, 7.0}; + std::vector output_fwd = { 0.0, 0.0, 2.0, 3.0, 6.0, 6.0}; std::vector input_bwd = {-0.1, -0.2, 0.3, 0.4, -0.1, 0.5}; std::vector expected_output_bwd = { 0.0, 0.0, 0.3, 0.4, 0.0, 0.0}; // clang-format on