diff --git a/onert-micro/onert-micro/include/pal/common/PALLogistic.h b/onert-micro/onert-micro/include/pal/common/PALLogistic.h new file mode 100644 index 00000000000..de0f4f94fdd --- /dev/null +++ b/onert-micro/onert-micro/include/pal/common/PALLogistic.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_EXECUTE_PAL_LOGISTIC_H +#define ONERT_MICRO_EXECUTE_PAL_LOGISTIC_H + +#include "PALUtils.h" + +#include + +namespace onert_micro +{ +namespace execute +{ +namespace pal +{ + +OMStatus inline Logistic(const int flat_size, const float *input_data, float *output_data) +{ + const float cutoff_upper = 16.619047164916992188f; + const float cutoff_lower = -9.f; + + // Rational for using approximation in reference kernel. + // 0. This approximation gives enough precision for float. + // 1. This works around an issue on an embedded chipset where exp() does not + // return correctly as expected - exp(x) should return inf when overflown + // not 1.701417 IEEE 754 defines representation for inf. + // 2. This will speed up calculation and is matching the behavior in the + // optimized kernels. (check the definition of scalar_logistic_op) + + for (int i = 0; i < flat_size; i++) + { + float val = input_data[i]; + float result; + if (val > cutoff_upper) + { + result = 1.0f; + } + else if (val < cutoff_lower) + { + result = std::exp(val); + } + else + { + result = 1.f / (1.f + std::exp(-val)); + } + output_data[i] = result; + } + return Ok; +} + +OMStatus inline Logistic(const int flat_size, const int8_t *input_data, float input_scale, + int input_zero_point, int8_t *output_data, float output_scale, + int output_zero_point) +{ + const float cutoff_upper = 16.619047164916992188f; + const float cutoff_lower = -9.f; + + // Rational for using approximation in reference kernel. + // 0. This approximation gives enough precision for float. + // 1. This works around an issue on an embedded chipset where exp() does not + // return correctly as expected - exp(x) should return inf when overflown + // not 1.701417 IEEE 754 defines representation for inf. + // 2. This will speed up calculation and is matching the behavior in the + // optimized kernels. (check the definition of scalar_logistic_op) + + for (int i = 0; i < flat_size; i++) + { + // Dequantize. + float val = static_cast((input_data[i] - input_zero_point) * input_scale); + float result; + if (val > cutoff_upper) + { + result = 1.0f; + } + else if (val < cutoff_lower) + { + result = std::exp(val); + } + else + { + result = 1.f / (1.f + std::exp(-val)); + } + // Requantize + int8_t output = static_cast(result / output_scale + output_zero_point); + output_data[i] = output; + } + return Ok; +} + +} // namespace pal +} // namespace execute +} // namespace onert_micro + +#endif // ONERT_MICRO_EXECUTE_PAL_LOGISTIC_H diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst index 2573c352c40..5bc3135af80 100644 --- a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst +++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst @@ -1,3 +1,4 @@ REGISTER_KERNEL(ABS, Abs) REGISTER_KERNEL(ADD, Add) +REGISTER_KERNEL(LOGISTIC, Logistic) REGISTER_KERNEL(CONCATENATION, Concatenation) diff --git a/onert-micro/onert-micro/include/test_models/logistic/FloatLogisticKernel.h b/onert-micro/onert-micro/include/test_models/logistic/FloatLogisticKernel.h new file mode 100644 index 00000000000..88a1f124caa --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/logistic/FloatLogisticKernel.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_FLOAT_LOGISTIC_KERNEL_H +#define ONERT_MICRO_TEST_MODELS_FLOAT_LOGISTIC_KERNEL_H + +#include "TestDataLogisticBase.h" + +namespace onert_micro +{ +namespace test_model +{ +namespace logistic_float +{ +/* + * Logistic Kernel: + * + * Input(1, 3, 3, 2) + * | + * Logistic + * | + * Output(1, 3, 3, 2) + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x14, 0x01, 0x00, 0x00, 0x30, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff, + 0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xd4, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, + 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; + +const std::vector input_data = {29.353455, 12.060211, 11.372606, -9.009369, 3.0267563, + 5.1447716, 21.289762, 19.976126, 8.726238, 4.8797092, + 3.64571, 34.80062, -6.9072685, -2.2714958, -16.44065, + 0.334301, -20.372694, 4.1522675}; + +const std::vector reference_output_data = { + 1.0, 0.99999416, 0.99998844, 0.00012225899, 0.9537683, 0.994204, + 1.0, 1.0, 0.99983776, 0.9924581, 0.97456115, 1.0, + 0.0009994869, 0.093511336, 7.2429586e-08, 0.5828055, 1.4198792e-09, 0.98451483}; + +} // namespace logistic_float + +class TestDataFloatLogistic : public TestDataLogisticBase +{ +public: + TestDataFloatLogistic() + { + _input_data = logistic_float::input_data; + _reference_output_data = logistic_float::reference_output_data; + _test_kernel_model_circle = logistic_float::test_kernel_model_circle; + } + + ~TestDataFloatLogistic() override = default; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_FLOAT_LOGISTIC_KERNEL_H diff --git a/onert-micro/onert-micro/include/test_models/logistic/NegLogisticKernel.h b/onert-micro/onert-micro/include/test_models/logistic/NegLogisticKernel.h new file mode 100644 index 00000000000..430c9b37fc2 --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/logistic/NegLogisticKernel.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_NEG_LOGISTIC_KERNEL_H +#define ONERT_MICRO_TEST_MODELS_NEG_LOGISTIC_KERNEL_H + +#include "TestDataLogisticBase.h" + +namespace onert_micro +{ +namespace test_model +{ +namespace neg_logistic_input_output_type_mismatch +{ +/* + * Logistic Kernel with input output type mismatch (should be equal): + * + * Input(1, 3, 3, 2) - Float + * | + * Logistic + * | + * Output(1, 3, 3, 2) - Int + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x24, 0x01, 0x00, 0x00, 0x40, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff, + 0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, + 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; + +} // namespace neg_logistic_input_output_type_mismatch + +namespace neg_logistic_no_quant_params +{ +/* + * Logistic Kernel with UINT8 type and without quant params: + * + * Input(1, 3, 3, 2) - UINT8 + * | + * Logistic (no quant params) + * | + * Output(1, 3, 3, 2) - UINT8 + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x1c, 0x01, 0x00, 0x00, 0x38, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff, + 0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xd0, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x03, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, + 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, + 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; + +} // namespace neg_logistic_no_quant_params + +class NegTestDataInputOutputTypeMismatchLogisticKernel : public NegTestDataBase +{ +public: + NegTestDataInputOutputTypeMismatchLogisticKernel() + { + _test_kernel_model_circle = neg_logistic_input_output_type_mismatch::test_kernel_model_circle; + } + + ~NegTestDataInputOutputTypeMismatchLogisticKernel() override = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + +protected: + const unsigned char *_test_kernel_model_circle; +}; + +class NegTestDataNoQuantParamsLogisticKernel : public NegTestDataBase +{ +public: + NegTestDataNoQuantParamsLogisticKernel() + { + _test_kernel_model_circle = neg_logistic_no_quant_params::test_kernel_model_circle; + } + + ~NegTestDataNoQuantParamsLogisticKernel() override = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + +protected: + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_NEG_LOGISTIC_KERNEL_H diff --git a/onert-micro/onert-micro/include/test_models/logistic/TestDataLogisticBase.h b/onert-micro/onert-micro/include/test_models/logistic/TestDataLogisticBase.h new file mode 100644 index 00000000000..d972ff86364 --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/logistic/TestDataLogisticBase.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_LOGISTIC_KERNEL_BASE_H +#define ONERT_MICRO_TEST_MODELS_LOGISTIC_KERNEL_BASE_H + +#include "test_models/TestDataBase.h" + +namespace onert_micro +{ +namespace test_model +{ + +template class TestDataLogisticBase : public TestDataBase +{ +public: + TestDataLogisticBase() = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + + const std::vector &get_input_data_by_index(int i) override final + { + switch (i) + { + case 0: + return _input_data; + default: + assert(false && "Wrong input index"); + } + } + + const std::vector &get_output_data_by_index(int i) override final + { + assert(i == 0); + return _reference_output_data; + } + +protected: + std::vector _input_data; + std::vector _reference_output_data; + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_LOGISTIC_KERNEL_BASE_H diff --git a/onert-micro/onert-micro/src/execute/kernels/Logistic.cpp b/onert-micro/onert-micro/src/execute/kernels/Logistic.cpp new file mode 100644 index 00000000000..adf05d2b219 --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/Logistic.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OMStatus.h" + +#include "core/OMUtils.h" + +#include "execute/OMKernelExecutionBuilder.h" +#include "execute/OMRuntimeKernel.h" + +#include "PALLogistic.h" + +using namespace onert_micro; +using namespace onert_micro::execute; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +// NOTE: doesnt currently support dynamic shapes +OMStatus onert_micro::execute::execute_kernel_CircleLogistic(const OMExecuteArgs &execute_args) +{ + core::OMRuntimeContext &runtime_context = execute_args.runtime_context; + core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; + uint16_t op_index = execute_args.kernel_index; + + const circle::Tensor *input = nullptr; + const circle::Tensor *output = nullptr; + + uint8_t *input_data = nullptr; + uint8_t *output_data = nullptr; + + OMStatus status = Ok; + + { + OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + input = runtime_kernel.inputs[inputTensorIdx]; + output = runtime_kernel.outputs[outputTensorIdx]; + + assert(input != nullptr); + assert(output != nullptr); + + status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); + if (status != Ok) + return status; + + input_data = runtime_kernel.inputs_data[inputTensorIdx]; + output_data = runtime_kernel.outputs_data[outputTensorIdx]; + } + + assert(input_data != nullptr); + assert(output_data != nullptr); + + switch (input->type()) + { +#ifndef DIS_FLOAT + case circle::TensorType_FLOAT32: + { + status = pal::Logistic(core::OMRuntimeShape(input).flatSize(), + core::utils::castInputData(input_data), + core::utils::castOutputData(output_data)); + } + break; +#endif // DIS_FLOAT +#ifndef DIS_QUANT + case circle::TensorType_INT8: + { + assert(input->quantization() != nullptr); + assert(input->quantization()->scale() != nullptr); + assert(input->quantization()->scale()->size() == 1); + assert(input->quantization()->zero_point() != nullptr); + assert(input->quantization()->zero_point()->size() == 1); + + assert(output->quantization() != nullptr); + assert(output->quantization()->scale() != nullptr); + assert(output->quantization()->scale()->size() == 1); + assert(output->quantization()->zero_point() != nullptr); + assert(output->quantization()->zero_point()->size() == 1); + + auto input_scale = *input->quantization()->scale()->begin(); + auto input_zero_point = *input->quantization()->zero_point()->begin(); + auto output_scale = *input->quantization()->scale()->begin(); + auto output_zero_point = *input->quantization()->zero_point()->begin(); + + status = pal::Logistic(core::OMRuntimeShape(input).flatSize(), + core::utils::castInputData(input_data), input_scale, + input_zero_point, core::utils::castOutputData(output_data), + output_scale, output_zero_point); + } + break; +#endif // DIS_QUANT + default: + { + status = UnsupportedType; + assert(false && "Unsupported type."); + } + } + + return status; +} diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/Logistic.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/Logistic.test.cpp new file mode 100644 index 00000000000..9b82cadae50 --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/tests/Logistic.test.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "execute/OMTestUtils.h" +#include "test_models/logistic/FloatLogisticKernel.h" +#include "test_models/logistic/NegLogisticKernel.h" + +namespace onert_micro +{ +namespace execute +{ +namespace testing +{ + +using namespace testing; + +class LogisticTest : public ::testing::Test +{ + // Do nothing +}; + +TEST_F(LogisticTest, Float_P) +{ + onert_micro::test_model::TestDataFloatLogistic test_data_kernel; + std::vector output_data_vector = + onert_micro::execute::testing::checkKernel(1, &test_data_kernel); + EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0)); +} + +TEST_F(LogisticTest, Input_output_type_mismatch_NEG) +{ + onert_micro::test_model::NegTestDataInputOutputTypeMismatchLogisticKernel test_data_kernel; + + EXPECT_DEATH(checkNEGSISOKernel(&test_data_kernel), ""); +} + +TEST_F(LogisticTest, No_quant_params_NEG) +{ + onert_micro::test_model::NegTestDataNoQuantParamsLogisticKernel test_data_kernel; + + EXPECT_DEATH(checkNEGSISOKernel(&test_data_kernel), ""); +} + +// TODO: Add S8 test + +} // namespace testing +} // namespace execute +} // namespace onert_micro diff --git a/onert-micro/onert-micro/src/import/kernels/Logistic.cpp b/onert-micro/onert-micro/src/import/kernels/Logistic.cpp new file mode 100644 index 00000000000..5392a41bc83 --- /dev/null +++ b/onert-micro/onert-micro/src/import/kernels/Logistic.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OMStatus.h" + +#include "import/OMKernelConfigureBuilder.h" +#include "core/OMUtils.h" +#include "execute/OMRuntimeKernel.h" + +using namespace onert_micro; +using namespace onert_micro::core; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +OMStatus onert_micro::import::configure_kernel_CircleLogistic(const OMConfigureArgs &config_args) +{ + OMRuntimeContext &runtime_context = config_args.runtime_context; + uint16_t op_index = config_args.kernel_index; + + onert_micro::execute::OMRuntimeKernel runtime_kernel; + + OMStatus status = runtime_kernel.readKernel(op_index, runtime_context); + if (status != Ok) + return status; + + const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx]; + const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx]; + + assert(input != nullptr); + assert(output != nullptr); + + status = utils::checkCondition(input->type() == output->type()); + if (status != Ok) + return status; + + OMRuntimeShape input_shape(input); + OMRuntimeShape output_shape(output); + + status = utils::checkCondition(input_shape.dimensionsCount() == output_shape.dimensionsCount()); + if (status != Ok) + return status; + + status = utils::checkCondition(input_shape.flatSize() == output_shape.flatSize()); + if (status != Ok) + return status; + + if (input->type() != circle::TensorType_INT8 and input->type() != circle::TensorType_INT16) + return status; + + // Check quantized version + if (input->quantization() == nullptr or output->quantization() == nullptr) + return NoQuantization; + + if (output->quantization()->scale() == nullptr or output->quantization()->scale()->size() != 1) + return UnsupportedQuantizationType; + + if (input->quantization()->scale() == nullptr or input->quantization()->scale()->size() != 1) + return UnsupportedQuantizationType; + + return status; +}