diff --git a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst index bf9154f126b..0072ced1346 100644 --- a/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst +++ b/onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst @@ -53,7 +53,7 @@ REGISTER_KERNEL(LOG_SOFTMAX, LogSoftmax) REGISTER_KERNEL(MUL, Mul) #/*REGISTER_KERNEL(MIRROR_PAD, MirrorPad)*/ REGISTER_KERNEL(MAXIMUM, Maximum) -#/*REGISTER_KERNEL(MEAN, Mean)*/ +REGISTER_KERNEL(MEAN, Mean) REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D) REGISTER_KERNEL(MINIMUM, Minimum) REGISTER_KERNEL(SHAPE, Shape) diff --git a/onert-micro/onert-micro/include/pal/mcu/PALReduceCommon.h b/onert-micro/onert-micro/include/pal/mcu/PALReduceCommon.h index 75e48b78ac5..05098e41eea 100644 --- a/onert-micro/onert-micro/include/pal/mcu/PALReduceCommon.h +++ b/onert-micro/onert-micro/include/pal/mcu/PALReduceCommon.h @@ -29,10 +29,9 @@ namespace pal // This method parses the input 'axis' to remove duplicates and handle negative // values, and returns a valid 'out_axis' -inline bool resolveAxis(const int num_dims, const int *axis, const int64_t num_axis, +inline bool resolveAxis(const int num_dims, const int *axis, const int64_t num_axis, int *out_axis, int *out_num_axis) { - int out_axis[2]; *out_num_axis = 0; // Just in case. // Short-circuit axis resolution for scalars; the axis will go unused. if (num_dims == 0) @@ -75,7 +74,7 @@ inline bool resolveAxis(const int num_dims, const int *axis, const int64_t num_a // Computes the generic value (i.e., sum/max/min/prod) of elements across // dimensions given in axis. It needs to pass in init_value and reducer. template -inline void ReduceGeneric(const T *input_data, const int *input_dims, const int input_num_dims, +inline bool ReduceGeneric(const T *input_data, const int *input_dims, const int input_num_dims, T *output_data, const int *axis, const int64_t num_axis_dimensions, T init_value, const int output_flat_size, T reducer(const T, const T)) { @@ -83,7 +82,7 @@ inline void ReduceGeneric(const T *input_data, const int *input_dims, const int for (int i = 0; i < input_num_dims; ++i) { if (input_dims[i] == 0) - return; + return false; } for (size_t idx = 0; idx < output_flat_size; ++idx) @@ -93,9 +92,11 @@ inline void ReduceGeneric(const T *input_data, const int *input_dims, const int // Resolve axis. int num_resolved_axis = 0; - if (!resolveAxis(input_num_dims, axis, num_axis_dimensions, &num_resolved_axis)) + int resolved_axis[2]; + + if (!resolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, &num_resolved_axis)) { - return; + return false; } int temp_index[5]; @@ -112,6 +113,62 @@ inline void ReduceGeneric(const T *input_data, const int *input_dims, const int reducedOutputOffset(input_num_dims, input_dims, temp_index, num_resolved_axis, axis); output_data[output_offset] = reducer(output_data[output_offset], input_data[input_offset]); } while (nextIndex(input_num_dims, input_dims, temp_index)); + + return true; +} + +// This method expects that output_data has been initialized. +template +inline bool reduceSumImpl(const T *input_data, const int *input_dims, const int input_num_dims, + T *output_data, const int *axis, const int num_axis, + const int num_outputs) +{ + return ReduceGeneric(input_data, input_dims, input_num_dims, output_data, axis, num_axis, + static_cast(0), num_outputs, + [](const T current, const T in) -> T { return in + current; }); +} + +template +inline bool Mean(const int *input_dims, const T *input_data, const int input_num_dims, + T *output_data, const int num_outputs, const int *axis, + const int num_axis_dimensions) +{ + if (!reduceSumImpl(input_data, input_dims, input_num_dims, output_data, axis, + num_axis_dimensions, num_outputs)) + { + return false; + } + + // Resolve axis again for computing mean + int num_resolved_axis = 0; + int resolved_axis[2]; + + if (!resolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, &num_resolved_axis)) + { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + size_t num_elements_in_axis = 1; + for (int idx = 0; idx < num_resolved_axis; ++idx) + { + size_t current = static_cast(input_dims[resolved_axis[idx]]); + // Overflow prevention. + if (current > (std::numeric_limits::max() / num_elements_in_axis)) + { + return false; + } + num_elements_in_axis *= current; + } + + if (num_elements_in_axis > 0) + { + for (size_t idx = 0; idx < num_outputs; ++idx) + { + output_data[idx] = static_cast(output_data[idx] / static_cast(num_elements_in_axis)); + } + } + return true; } } // namespace pal diff --git a/onert-micro/onert-micro/include/test_models/mean/FloatMeanKernel.h b/onert-micro/onert-micro/include/test_models/mean/FloatMeanKernel.h new file mode 100644 index 00000000000..b09d1463d75 --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/mean/FloatMeanKernel.h @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_FLOAT_MEAN_KERNEL_H +#define ONERT_MICRO_TEST_MODELS_FLOAT_MEAN_KERNEL_H + +#include "TestDataMeanBase.h" + +namespace onert_micro +{ +namespace test_model +{ +namespace mean_float +{ +/* + * Mean Kernel: + * + * Input(1, 8, 8, 4) + * | + * Mean + * | + * Output(1, 8, 8, 1) + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x94, 0x01, 0x00, 0x00, 0xb0, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff, 0xfc, 0xff, 0xff, 0xff, + 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x6c, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x1b, 0x14, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x94, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x1c, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x72, 0x65, 0x64, 0x75, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e, + 0x64, 0x69, 0x63, 0x65, 0x73, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x28, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, + 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; + +const std::vector input_data = { + -73.19745, -62.66789, -15.868883, -69.99245, -86.77558, -47.07158, -59.42521, 5.4639907, + -15.482954, 58.430527, 30.962307, -8.479264, 64.87171, 67.23879, 54.92413, -75.001656, + 4.095402, -11.012883, 1.7135352, -13.673498, 87.62411, 88.27154, 86.84994, 61.68961, + -67.81691, -36.073383, 54.346165, -83.79197, 35.099308, -23.05919, 26.401726, 20.99549, + -68.63421, -93.027596, 20.0895, -16.020033, 57.642673, 8.66057, 39.191364, 29.198711, + -5.9334397, 11.010835, 82.77485, -34.213863, -38.869553, 16.539444, 51.105484, 25.632273, + -55.436813, -26.42026, 77.96095, -59.019154, -82.52756, -94.416176, -83.77591, 46.43875, + 0.7686069, 57.346397, -89.24597, -8.594538, -98.168755, -33.18969, -41.993664, 13.660449, + 50.10378, 9.801906, -4.2520585, 27.210102, 48.8715, -19.44194, 38.652195, 23.77053, + -82.0674, -93.96652, 99.148094, 22.794533, 0.5715625, 0.84766275, 87.92019, 37.35077, + -32.265865, 67.46462, -24.098558, 87.36311, 90.409134, 33.023712, -15.923093, 40.05901, + -12.006578, 31.039108, -63.882004, -73.78517, -24.940235, 30.9098, 31.745, -89.77378, + -46.777866, 58.79768, -24.669464, 96.29413, 61.62126, 45.743416, 38.30191, 71.805405, + -31.20969, 33.56755, -1.926614, 72.13441, -22.292011, -16.355177, 21.689945, 87.95895, + -98.04168, 93.35264, -12.684541, -18.105795, 30.574284, 42.890903, -94.390366, -47.013157, + -98.465126, 28.63009, -83.54015, 86.82799, 0.6768988, 6.070787, 43.308678, 1.8557712, + -73.0521, -90.86948, 43.77232, 68.301056, 66.867775, 97.34002, -59.342876, -51.359367, + 17.27793, 52.223003, -3.9915564, 29.598532, 34.474148, -80.920456, -30.45005, -17.469683, + -67.02992, -34.23075, -35.53944, 61.557327, -66.91338, -94.03176, -45.88021, 97.36409, + 96.45681, -32.885677, 72.40823, -62.28857, 20.948895, 1.259363, -84.97583, 60.83626, + -94.692535, -15.315798, -99.92936, 40.56625, -8.6356325, -7.3984733, 56.255993, -31.700819, + 62.08311, 52.800938, 32.27374, -99.46793, -40.924038, 24.67266, -58.954403, 42.263252, + -72.13501, -58.40316, 14.619292, -43.400642, -82.13468, -47.54976, -42.642033, -8.409653, + 74.90983, 97.76474, -71.152916, 83.61312, -37.22972, 21.405357, -56.848846, 90.63024, + -70.21143, -29.522697, 94.9647, 74.74478, 37.564766, -40.22343, -63.337795, -65.86191, + -48.546135, -58.20052, 36.73888, 67.78194, -43.096832, 94.7046, 9.798892, -79.97487, + -15.868657, -84.753975, 4.8745494, -18.346195, 54.9818, 75.854, 41.797707, -5.673281, + -36.31264, -73.4931, -41.090492, 6.3805137, -73.66098, 85.20992, 91.28027, -73.26658, + -92.18044, 41.29011, 5.5041995, -73.70062, -16.678818, 30.614132, 92.100555, 11.274231, + -37.915485, 34.91591, 36.32971, -37.70164, -23.708878, 19.026278, -41.71216, 67.325356, + 78.23511, -43.154037, 22.667723, 30.742237, -6.086414, 17.191307, 65.828896, -40.83338, + -18.61725, 23.976517, 80.2347, -92.53064, 71.6477, -38.28841, -60.853157, 24.402542}; + +const std::vector reference_output_data = { + -55.431667, -46.952095, 16.357655, 28.008245, -4.7193613, 81.108795, -33.334023, 14.859333, + -39.398083, 33.673332, 13.409595, 13.601912, -15.728818, -53.57022, -9.9313755, -39.922916, + 20.71593, 22.963072, -13.522823, 31.672546, 24.615828, 36.89219, -29.65866, -13.014804, + 20.91112, 54.368, 18.141413, 17.750427, -8.869844, -16.984585, -16.636799, 12.978033, + -12.962048, 13.376387, 23.776978, -23.59151, -18.810696, -27.365314, 18.422699, -0.4828272, + -42.342857, 2.1302667, 11.922464, -8.235632, -39.82988, -45.184032, 46.28369, 4.489258, + 17.493837, -32.964592, -0.55646133, -4.6420527, -28.523571, 41.74006, -36.128933, 7.3906593, + -29.771688, 29.327526, -1.0928774, 5.232649, 22.122757, 9.025103, -1.7341671, -0.7728319}; + +} // namespace mean_float + +class TestDataFloatMean : public TestDataMeanBase +{ +public: + TestDataFloatMean() + { + _input_data = mean_float::input_data; + _reference_output_data = mean_float::reference_output_data; + _test_kernel_model_circle = mean_float::test_kernel_model_circle; + } + + ~TestDataFloatMean() override = default; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_FLOAT_MEAN_KERNEL_H diff --git a/onert-micro/onert-micro/include/test_models/mean/NegMeanKernel.h b/onert-micro/onert-micro/include/test_models/mean/NegMeanKernel.h new file mode 100644 index 00000000000..409c302015d --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/mean/NegMeanKernel.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_NEG_MEAN_KERNEL_H +#define ONERT_MICRO_TEST_MODELS_NEG_MEAN_KERNEL_H + +#include "TestDataMeanBase.h" + +namespace onert_micro +{ +namespace test_model +{ +namespace neg_input_output_type_mismatch_mean_kernel +{ +/* + * Mean Kernel with input output type mismatch: + * + * Input(1, 8, 8, 4) - Float32 + * | + * Mean + * | + * Output(1, 8, 8, 1) - Int32 + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x98, 0x01, 0x00, 0x00, 0xb4, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff, 0xfc, 0xff, 0xff, 0xff, + 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x6c, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x1b, 0x14, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xd0, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x02, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x1c, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x72, 0x65, 0x64, 0x75, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x5f, 0x69, 0x6e, 0x64, 0x69, 0x63, 0x65, 0x73, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x28, 0x11, 0x00, 0x00, 0x00, + 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, + 0x65, 0x00, 0x00, 0x00}; +} // namespace neg_input_output_type_mismatch_mean_kernel + +class NegTestDataInputOutputTypeMismatchMeanKernel : public NegTestDataBase +{ +public: + NegTestDataInputOutputTypeMismatchMeanKernel() + { + _test_kernel_model_circle = + neg_input_output_type_mismatch_mean_kernel::test_kernel_model_circle; + } + + ~NegTestDataInputOutputTypeMismatchMeanKernel() override = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + +protected: + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_NEG_LOG_KERNEL_H diff --git a/onert-micro/onert-micro/include/test_models/mean/TestDataMeanBase.h b/onert-micro/onert-micro/include/test_models/mean/TestDataMeanBase.h new file mode 100644 index 00000000000..890b165125a --- /dev/null +++ b/onert-micro/onert-micro/include/test_models/mean/TestDataMeanBase.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_TEST_MODELS_MEAN_KERNEL_BASE_H +#define ONERT_MICRO_TEST_MODELS_MEAN_KERNEL_BASE_H + +#include "test_models/TestDataBase.h" + +namespace onert_micro +{ +namespace test_model +{ + +template class TestDataMeanBase : public TestDataBase +{ +public: + TestDataMeanBase() = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + + const std::vector &get_input_data_by_index(int i) override final + { + switch (i) + { + case 0: + return _input_data; + default: + assert(false && "Wrong input index"); + } + } + + const std::vector &get_output_data_by_index(int i) override final + { + assert(i == 0); + return _reference_output_data; + } + +protected: + std::vector _input_data; + std::vector _reference_output_data; + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_model +} // namespace onert_micro + +#endif // ONERT_MICRO_TEST_MODELS_MEAN_KERNEL_BASE_H diff --git a/onert-micro/onert-micro/src/execute/kernels/Mean.cpp b/onert-micro/onert-micro/src/execute/kernels/Mean.cpp new file mode 100644 index 00000000000..849a88de382 --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/Mean.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "execute/OMUtils.h" +#include "execute/OMKernelExecutionBuilder.h" +#include "OMStatus.h" +#include "execute/OMRuntimeKernel.h" +#include "core/OMUtils.h" + +#include "core/OMRuntimeShape.h" +#include "PALReduceCommon.h" + +using namespace onert_micro; +using namespace onert_micro::execute; + +namespace +{ + +constexpr uint32_t input1TensorIdx = 0; +constexpr uint32_t input2TensorIdx = 1; +constexpr uint32_t outputTensorIdx = 0; + +template +void reduceMeanGeneric(core::OMRuntimeShape &input_shape, const T *input_data, + core::OMRuntimeShape &axis_shape, const int *axis_data, + core::OMRuntimeShape &output_shape, T *output_data, bool keep_dims) +{ + onert_micro::execute::pal::ReduceGeneric( + input_data, input_shape.dimsData(), input_shape.dimensionsCount(), output_data, axis_data, + axis_shape.dimensionsCount(), + /*init_value=*/T(0), output_shape.flatSize(), + [](const T current, const T in) -> T { return in + current; }); +} + +} // namespace + +namespace onert_micro +{ +namespace execute +{ + +OMStatus execute_kernel_CircleMean(const OMExecuteArgs &execute_args) +{ + core::OMRuntimeContext &runtime_context = execute_args.runtime_context; + core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; + uint16_t op_index = execute_args.kernel_index; + + const circle::Tensor *input; + const circle::Tensor *axis; + const circle::Tensor *output; + + uint8_t *input_data; + uint8_t *axis_data; + uint8_t *output_data; + + uint16_t input_index = 0; + uint16_t axis_index = 0; + + const circle::ReducerOptions *options; + // Read kernel + { + execute::OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + input = runtime_kernel.inputs[input1TensorIdx]; + axis = runtime_kernel.inputs[input2TensorIdx]; + output = runtime_kernel.outputs[outputTensorIdx]; + assert(input != nullptr); + assert(axis != nullptr); + assert(output != nullptr); + + runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); + + input_data = runtime_kernel.inputs_data[input1TensorIdx]; + axis_data = runtime_kernel.inputs_data[input2TensorIdx]; + output_data = runtime_kernel.outputs_data[outputTensorIdx]; + assert(input_data != nullptr); + assert(axis_data != nullptr); + assert(output_data != nullptr); + + options = runtime_kernel.first_operator->builtin_options_as_ReducerOptions(); + + input_index = runtime_kernel.inputs_index[input1TensorIdx]; + axis_index = runtime_kernel.inputs_index[input2TensorIdx]; + } + + OMStatus status; + + core::OMRuntimeShape input_shape(input); + core::OMRuntimeShape axis_shape(axis); + core::OMRuntimeShape output_shape(output); + + switch (input->type()) + { +#ifndef DIS_FLOAT + case circle::TensorType_FLOAT32: + onert_micro::execute::pal::Mean( + input_shape.dimsData(), core::utils::castInputData(input_data), + input_shape.dimensionsCount(), core::utils::castOutputData(output_data), + output_shape.flatSize(), core::utils::castInputData(axis_data), + axis_shape.dimensionsCount()); + + break; +#endif // DIS_FLOAT + case circle::TensorType_INT32: + break; + case circle::TensorType_INT64: + break; + default: + assert(false && "Unsupported type"); + } + + return status; +} + +} // namespace execute +} // namespace onert_micro diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/Mean.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/Mean.test.cpp new file mode 100644 index 00000000000..024153b95c0 --- /dev/null +++ b/onert-micro/onert-micro/src/execute/kernels/tests/Mean.test.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "execute/OMTestUtils.h" +#include "test_models/mean/FloatMeanKernel.h" +#include "test_models/mean/NegMeanKernel.h" + +namespace onert_micro +{ +namespace execute +{ +namespace testing +{ + +using namespace testing; + +class MeanTest : public ::testing::Test +{ + // Do nothing +}; + +template +std::vector checkMeanKernel(test_model::TestDataMeanBase *test_data_base) +{ + onert_micro::OMInterpreter interpreter; + onert_micro::OMConfig config; + + interpreter.importModel(reinterpret_cast(test_data_base->get_model_ptr()), config); + + interpreter.reset(); + interpreter.allocateInputs(); + + T *input_data = reinterpret_cast(interpreter.getInputDataAt(0)); + + std::copy(test_data_base->get_input_data_by_index(0).begin(), + test_data_base->get_input_data_by_index(0).end(), input_data); + interpreter.run(config); + + T *output_data = reinterpret_cast(interpreter.getOutputDataAt(0)); + const size_t num_elements = interpreter.getOutputSizeAt(0); + std::vector output_data_vector(output_data, output_data + num_elements); + return output_data_vector; +} + +TEST_F(MeanTest, Float_P) +{ + test_model::TestDataFloatMean test_data_kernel; + std::vector output_data_vector = checkMeanKernel(&test_data_kernel); + EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0)); +} + +TEST_F(MeanTest, Input_output_type_mismatch_NEG) +{ + test_model::NegTestDataInputOutputTypeMismatchMeanKernel test_data_kernel; + EXPECT_DEATH(checkNEGSISOKernel(&test_data_kernel), ""); +} + +} // namespace testing +} // namespace execute +} // namespace onert_micro diff --git a/onert-micro/onert-micro/src/import/kernels/Mean.cpp b/onert-micro/onert-micro/src/import/kernels/Mean.cpp new file mode 100644 index 00000000000..218b2b6c4c5 --- /dev/null +++ b/onert-micro/onert-micro/src/import/kernels/Mean.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "import/helpers/OMConfigureTISOKernel.h" + +using namespace onert_micro; +using namespace onert_micro::core; + +namespace +{ + +constexpr uint32_t inputTensorIdx = 0; +constexpr uint32_t axisTensorIdx = 1; +constexpr uint32_t outputTensorIdx = 0; + +} // namespace + +namespace onert_micro +{ +namespace import +{ + +OMStatus configure_kernel_CircleMean(const OMConfigureArgs &config_args) +{ + OMRuntimeContext &runtime_context = config_args.runtime_context; + uint16_t op_index = config_args.kernel_index; + + onert_micro::execute::OMRuntimeKernel runtime_kernel; + + OMStatus status = runtime_kernel.readKernel(op_index, runtime_context); + if (status != Ok) + return status; + + const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx]; + const circle::Tensor *axis = runtime_kernel.inputs[axisTensorIdx]; + const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx]; + + assert(input != nullptr); + assert(axis != nullptr); + assert(output != nullptr); + + status = utils::checkCondition(axis->type() == circle::TensorType_INT32); + status = utils::checkCondition(input->type() == output->type()); + + return status; +} + +} // namespace import +} // namespace onert_micro