Skip to content

Commit

Permalink
[onert-micro] Support SquaredDifference float kernel
Browse files Browse the repository at this point in the history
This commit supports SquaredDifference float kernel.

ONE-DCO-1.0-Signed-off-by: Vyacheslav Bazhenov <[email protected]>
  • Loading branch information
Vyacheslav Bazhenov committed Jun 6, 2024
1 parent a8f50cc commit 0cb400d
Show file tree
Hide file tree
Showing 13 changed files with 677 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_TISO_KERNEL_H
#define ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_TISO_KERNEL_H

#include "import/OMKernelConfigureBuilder.h"
#include "core/OMUtils.h"
#include "OMStatus.h"
#include "execute/OMRuntimeKernel.h"

namespace onert_micro
{
namespace import
{
namespace helpers
{
OMStatus configure_TISO_kernel(const OMConfigureArgs &config_args);

} // namespace helpers
} // namespace import
} // namespace onert_micro

#endif // ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_TISO_KERNEL_H
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ template <typename T> struct DivFn
{
T operator()(T lhs, T rhs) { return lhs / rhs; }
};

template <typename T> struct SquaredDifferenceFn
{
T operator()(T lhs, T rhs) { return (lhs - rhs) * (lhs - rhs); }
};
template <typename T, typename Fn>
OMStatus ArithmeticOp(const core::BinaryArithmeticBroadcastParams &params, const int flat_size,
const T *input1_data, const T *input2_data, T *output_data)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_SQUARED_DIFFERENCE_COMMON_H
#define ONERT_MICRO_EXECUTE_PAL_SQUARED_DIFFERENCE_COMMON_H

#include "PALArithmeticOpCommon.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{
template <typename T>
OMStatus SquaredDifference(const core::BinaryArithmeticBroadcastParams &params, const int flat_size,
const T *input1_data, const T *input2_data, T *output_data)
{
ArithmeticOp<T, SquaredDifferenceFn<T>>(params, flat_size, input1_data, input2_data, output_data);
return Ok;
}

template <typename T>
OMStatus
BroadcastSquaredDifference4DSlow(const core::BinaryArithmeticBroadcastParams &params,
const core::OMRuntimeShape &input1_shape, const T *input1_data,
const core::OMRuntimeShape &input2_shape, const T *input2_data,
const core::OMRuntimeShape &output_shape, T *output_data)
{
BroadcastArithmeticOp4DSlow<T, SquaredDifferenceFn<T>>(
params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data);
return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_SQUARED_DIFFERENCE_COMMON_H
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/pal/mcu/KernelsToBuild.lst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)
#/*REGISTER_KERNEL(SHAPE, Shape)*/
REGISTER_KERNEL(NOT_EQUAL, NotEqual)
REGISTER_KERNEL(SIN, Sin)
#/*REGISTER_KERNEL(SQUARED_DIFFERENCE, SquaredDifference)*/
REGISTER_KERNEL(SQUARED_DIFFERENCE, SquaredDifference)
#/*REGISTER_KERNEL(SLICE, Slice)*/
REGISTER_KERNEL(SUB, Sub)
#/*REGISTER_KERNEL(SPLIT, Split)*/
Expand Down
33 changes: 33 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/PALSquaredDifference.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_SQUARED_DIFFERENCE_H
#define ONERT_MICRO_EXECUTE_PAL_SQUARED_DIFFERENCE_H

#include "PALSquaredDifferenceCommon.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_SQUARED_DIFFERENCE_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TEST_MODELS_SQUARED_DIFFERENCE_KERNEL_FLOAT_H
#define ONERT_MICRO_TEST_MODELS_SQUARED_DIFFERENCE_KERNEL_FLOAT_H

#include "TestDataSquaredDifferenceBase.h"

namespace onert_micro
{
namespace test_model
{
namespace squared_difference_float_with_broadcasting
{

/*
* SquaredDifference Kernel:
*
* Input_1(1, 4, 4, 3) Input_2(1, 4, 4, 3)
* \ /
* SquaredDifference(with broadcast)
* |
* Output(1, 4, 4, 3)
*/
const unsigned char test_kernel_model_circle[] = {
0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x30, 0x00, 0x00, 0x00, 0x6c, 0x01, 0x00, 0x00, 0x88, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x88, 0xff, 0xff, 0xff, 0x8c, 0xff, 0xff, 0xff, 0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff,
0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00,
0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00,
0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x10, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00,
0x34, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa4, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0xd0, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x32, 0x00, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x31, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00,
0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63,
0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63,
0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};

const std::vector<float> input1_data = {
33.44, 44.74834, 42.31354, 17.271736, 46.138657, 42.50795, 9.75354, 11.343542,
25.705894, 36.687202, 17.357473, 29.17235, 34.9576, 34.23016, 22.528538, 25.484097,
38.542297, 32.78322, 31.368523, 52.47518, 29.052322, 35.70844, 19.942907, 30.840899,
45.22654, 22.581013, 24.37784, -7.1113133, 44.58411, 42.722954, 20.67068, 47.006798,
40.08999, 25.972889, 51.532227, 15.329674, 22.725258, 37.83095, 24.72808, 32.937607,
39.836365, 11.756248, 8.707649, 42.69889, 22.493517, 26.656416, 19.543903, 38.58926};
const std::vector<float> input2_data = {
45.496315, 47.277058, 43.62587, 25.520176, 26.66666, 21.242783, 38.55497, 29.928194,
32.247902, 5.110588, 35.261402, 35.692963, 30.808405, 30.916706, 38.445484, 32.367344,
20.506172, 21.414955, 39.93972, 34.694054, 36.724403, 17.250431, 34.863686, 32.42676,
31.634842, 44.39825, 31.116629, 18.85633, 19.150063, 33.974716, 26.233631, 29.054287,
29.618658, 38.05911, 54.44181, 28.360674, 15.703876, 39.576363, 29.899979, 26.719788,
56.03285, 50.236435, 41.24733, 31.759392, 47.402603, 34.884254, 38.37292, 18.049467};
const std::vector<float> reference_output_data = {
145.35477, 6.3944097, 1.7222056, 68.03676, 379.15863, 452.20734, 829.52234, 345.38928,
42.797863, 997.0826, 320.55066, 42.51839, 17.215816, 10.978975, 253.34918, 47.379093,
325.30182, 129.23741, 73.46542, 316.16852, 58.86083, 340.69806, 222.62962, 2.5149617,
184.73425, 475.99188, 45.41127, 674.31854, 646.8908, 76.53166, 30.946436, 322.29263,
109.648766, 146.0767, 8.465679, 169.80696, 49.29981, 3.0464592, 26.748528, 38.661274,
262.32608, 1480.7247, 1058.8308, 119.67264, 620.4626, 67.69733, 354.5319, 421.8831};

} // namespace squared_difference_float_with_broadcasting

class TestDataFloatSquaredDifference : public TestDataSquaredDifferenceBase<float>
{
public:
TestDataFloatSquaredDifference()
{

_input1_data = squared_difference_float_with_broadcasting::input1_data;
_input2_data = squared_difference_float_with_broadcasting::input2_data;
_reference_output_data = squared_difference_float_with_broadcasting::reference_output_data;
_test_kernel_model_circle =
squared_difference_float_with_broadcasting::test_kernel_model_circle;
}

~TestDataFloatSquaredDifference() override = default;
};

} // namespace test_model
} // namespace onert_micro

#endif // ONERT_MICRO_TEST_MODELS_SQUARED_DIFFERENCE_KERNEL_FLOAT_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TEST_MODELS_NEG_SQUARED_DIFFERENCE_KERNEL_H
#define ONERT_MICRO_TEST_MODELS_NEG_SQUARED_DIFFERENCE_KERNEL_H

#include "TestDataSquaredDifferenceBase.h"

namespace onert_micro
{
namespace test_model
{
namespace inputs_type_mismatch
{

/*
* SquaredDifference Kernel with input type mismatch:
*
* Input_1(2, 5) - Float Input_2(2, 1) - Int
* \ /
* SquaredDifference
* |
* Output(2, 5)
*/
const unsigned char test_kernel_model_circle[] = {
0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x30, 0x00, 0x00, 0x00, 0x7c, 0x01, 0x00, 0x00, 0x98, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x88, 0xff, 0xff, 0xff, 0x8c, 0xff, 0xff, 0xff, 0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff,
0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00,
0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00,
0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x10, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00,
0x40, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x94, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x32, 0x00, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x31, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00,
0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63,
0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63,
0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};

} // namespace inputs_type_mismatch

class NegTestDataInputsTypeMismatchSquaredDifference : public NegTestDataBase
{
public:
NegTestDataInputsTypeMismatchSquaredDifference()
{
_test_kernel_model_circle = inputs_type_mismatch::test_kernel_model_circle;
}

~NegTestDataInputsTypeMismatchSquaredDifference() override = default;

const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }

protected:
const unsigned char *_test_kernel_model_circle;
};

} // namespace test_model
} // namespace onert_micro

#endif // ONERT_MICRO_TEST_MODELS_NEG_SQUARED_DIFFERENCE_KERNEL_H
Loading

0 comments on commit 0cb400d

Please sign in to comment.