From b7d96143d93f3356959d1b5638a8f02c3dd80a81 Mon Sep 17 00:00:00 2001 From: sseung Date: Tue, 16 Jan 2024 16:54:56 +0900 Subject: [PATCH] apply code review --- compute/cker/include/cker/train/operation/ReLU6.h | 12 ++++++------ compute/cker/src/train/Relu6.test.cc | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/compute/cker/include/cker/train/operation/ReLU6.h b/compute/cker/include/cker/train/operation/ReLU6.h index c615a6d69cb..3016a031a42 100644 --- a/compute/cker/include/cker/train/operation/ReLU6.h +++ b/compute/cker/include/cker/train/operation/ReLU6.h @@ -36,12 +36,12 @@ inline void ReLU6Grad(const Shape &output_shape, const float *output_data, const auto incoming_map = MapAsVector(incoming_data, incoming_shape); auto grad_map = MapAsVector(grad_data, grad_shape); - if (output_shape == incoming_shape && output_shape == grad_shape) - grad_map.array() = - incoming_map.array() * - (0.0f < output_map.array() && output_map.array() < 6.0f).template cast(); - else - throw std::runtime_error("cker::ReLU6Grad: Unsupported shape"); + if (output_shape != incoming_shape || output_shape != grad_shape) + throw std::runtime_error{"cker::ReLU6Grad: Unsupported shape"}; + + grad_map.array() = + incoming_map.array() * + (0.0f < output_map.array() && output_map.array() < 6.0f).template cast(); } } // namespace train diff --git a/compute/cker/src/train/Relu6.test.cc b/compute/cker/src/train/Relu6.test.cc index 3eb9f042f7b..ca8a3fc226e 100644 --- a/compute/cker/src/train/Relu6.test.cc +++ b/compute/cker/src/train/Relu6.test.cc @@ -53,7 +53,7 @@ TEST(CKer_Operation, ReLU6) // clang-format off // std::vector input_fwd = {-2.0, -1.0, 2.0, 3.0, 6.0, 7.0}; - std::vector output_fwd = { 0.0, 0.0, 2.0, 3.0, 6.0, 7.0}; + std::vector output_fwd = { 0.0, 0.0, 2.0, 3.0, 6.0, 6.0}; std::vector input_bwd = {-0.1, -0.2, 0.3, 0.4, -0.1, 0.5}; std::vector expected_output_bwd = { 0.0, 0.0, 0.3, 0.4, 0.0, 0.0}; // clang-format on