diff --git a/compute/cker/include/cker/train/operation/ReLU6.h b/compute/cker/include/cker/train/operation/ReLU6.h new file mode 100644 index 00000000000..3016a031a42 --- /dev/null +++ b/compute/cker/include/cker/train/operation/ReLU6.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_CKER_TRAIN_OPERATION_RELU6_H__ +#define __NNFW_CKER_TRAIN_OPERATION_RELU6_H__ + +#include "cker/Shape.h" +#include "cker/eigen/Utils.h" +#include + +namespace nnfw +{ +namespace cker +{ +namespace train +{ + +inline void ReLU6Grad(const Shape &output_shape, const float *output_data, + const Shape &incoming_shape, const float *incoming_data, + const Shape &grad_shape, float *grad_data) +{ + const auto output_map = MapAsVector(output_data, output_shape); + const auto incoming_map = MapAsVector(incoming_data, incoming_shape); + auto grad_map = MapAsVector(grad_data, grad_shape); + + if (output_shape != incoming_shape || output_shape != grad_shape) + throw std::runtime_error{"cker::ReLU6Grad: Unsupported shape"}; + + grad_map.array() = + incoming_map.array() * + (0.0f < output_map.array() && output_map.array() < 6.0f).template cast(); +} + +} // namespace train +} // namespace cker +} // namespace nnfw + +#endif // __NNFW_CKER_TRAIN_OPERATION_RELU6_H__ diff --git a/compute/cker/src/train/Relu6.test.cc b/compute/cker/src/train/Relu6.test.cc new file mode 100644 index 00000000000..ca8a3fc226e --- /dev/null +++ b/compute/cker/src/train/Relu6.test.cc @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include + +namespace +{ + +using namespace nnfw::cker; + +template class Relu6OpVerifier +{ +public: + void verifyBackward(const std::vector &output, const std::vector &input_bwd, + const std::vector &expected_output_bwd, bool expect_eq = true) + { + std::vector calc_output_bwd(input_bwd.size()); // calculated output backward + train::ReLU6Grad(Shape{static_cast(output.size())}, output.data(), + Shape{static_cast(input_bwd.size())}, input_bwd.data(), + Shape{static_cast(calc_output_bwd.size())}, calc_output_bwd.data()); + + if (expect_eq) + EXPECT_EQ(expected_output_bwd, calc_output_bwd); + else + EXPECT_NE(expected_output_bwd, calc_output_bwd); + } +}; + +} // namespace + +TEST(CKer_Operation, ReLU6) +{ + { + Relu6OpVerifier verifier; + + // clang-format off + // std::vector input_fwd = {-2.0, -1.0, 2.0, 3.0, 6.0, 7.0}; + std::vector output_fwd = { 0.0, 0.0, 2.0, 3.0, 6.0, 6.0}; + std::vector input_bwd = {-0.1, -0.2, 0.3, 0.4, -0.1, 0.5}; + std::vector expected_output_bwd = { 0.0, 0.0, 0.3, 0.4, 0.0, 0.0}; + // clang-format on + + verifier.verifyBackward(output_fwd, input_bwd, expected_output_bwd); + } + + { + Relu6OpVerifier verifier; + + // clang-format off + // std::vector input_fwd = { 7.0, 8.0, 4.0, -4.0, -5.0, 10.0}; + std::vector output_fwd = { 6.0, 6.0, 4.0, 0.0, 0.0, 6.0}; + std::vector input_bwd = {-6.1, -3.3, 7.0, 8.4, -9.2, 0.0}; + std::vector expected_output_bwd = { 0.0, 0.0, 7.0, 0.0, 0.0, 0.0}; + // clang-format on + + verifier.verifyBackward(output_fwd, input_bwd, expected_output_bwd); + } +} + +TEST(CKer_Operation, neg_ReLU6) +{ + { + Relu6OpVerifier verifier; + + // clang-format off + // std::vector input_fwd = { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0}; + std::vector output_fwd = { 0.0, 2.0, 4.0, 6.0, 6.0, 6.0}; + std::vector input_bwd = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector expected_output_bwd = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; // wrong value + // clang-format on + + verifier.verifyBackward(output_fwd, input_bwd, expected_output_bwd, false); + } + + { + Relu6OpVerifier verifier; + + // clang-format off + // std::vector input_fwd = { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0}; + std::vector output_fwd = { 0.0, 2.0, 4.0, 6.0, 6.0, 6.0}; + std::vector input_bwd = { 0.1, 0.2, 0.3, 0.4}; // size mismatch + std::vector expected_output_bwd = { 0.0, 0.2, 0.3, 0.4}; + // clang-format on + + EXPECT_ANY_THROW(verifier.verifyBackward(output_fwd, input_bwd, expected_output_bwd, false)); + } +}