[onert-micro] Support S8 Mul (#13204)
This PR adds support for the S8 Mul kernel using CMSIS-NN; a short sketch of the per-element quantized math follows the changed-files summary below.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
BalyshevArtem authored Jun 14, 2024
1 parent 0fd8669 commit 6f1d1d3
Showing 11 changed files with 555 additions and 9 deletions.
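For context, both the CMSIS-NN path and the portable reference path added in this commit perform the same per-element S8 quantized multiply: offset both inputs, multiply, requantize, add the output offset, and clamp to the activation range. A minimal standalone sketch of that math follows; the quantization parameters are hypothetical and the requantization is simplified to floating point, whereas the real kernels use the fixed-point multiplyByQuantizedMultiplier helper.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical quantization parameters, for illustration only.
constexpr int32_t kInput1Offset = 128;  // negated zero point of input 1
constexpr int32_t kInput2Offset = 128;  // negated zero point of input 2
constexpr int32_t kOutputOffset = -128; // zero point of the output
constexpr double kOutputScale = 0.003;  // input1_scale * input2_scale / output_scale
constexpr int32_t kActMin = -128;
constexpr int32_t kActMax = 127;

int8_t quantizedMul(int8_t in1, int8_t in2)
{
  const int32_t v1 = kInput1Offset + in1;
  const int32_t v2 = kInput2Offset + in2;
  // Requantize the 32-bit product back to S8 (simplified here to floating point).
  const int32_t raw =
    kOutputOffset + static_cast<int32_t>(std::lround(v1 * v2 * kOutputScale));
  return static_cast<int8_t>(std::min(kActMax, std::max(kActMin, raw)));
}

int main()
{
  std::printf("%d\n", quantizedMul(-100, -90)); // prints -125 with these parameters
}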
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/core/OMRuntimeShape.h
@@ -28,7 +28,7 @@ namespace onert_micro
namespace core
{

-static constexpr int maxTensorShapeSize = 5;
+static constexpr int maxTensorShapeSize = 6;

class OMRuntimeShape
{
@@ -49,7 +49,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D)
#/*REGISTER_KERNEL(LOGICAL_OR, LogicalOr)*/
#/*REGISTER_KERNEL(LEAKY_RELU, LeakyRelu)*/
#/*REGISTER_KERNEL(LOG_SOFTMAX, LogSoftmax)*/
-#/*REGISTER_KERNEL(MUL, Mul)*/
+REGISTER_KERNEL(MUL, Mul)
#/*REGISTER_KERNEL(MIRROR_PAD, MirrorPad)*/
#/*REGISTER_KERNEL(MAXIMUM, Maximum)*/
#/*REGISTER_KERNEL(MEAN, Mean)*/
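The newly enabled REGISTER_KERNEL(MUL, Mul) entry above adds Mul to the backend's kernel build list. Such .lst files are typically consumed as X-macros: the including code defines REGISTER_KERNEL and then includes the list. The sketch below only illustrates that pattern; the macro usage and list contents here are assumptions, not code from this repository.

#include <cstdio>

// The list is inlined as an object-like macro to keep the example
// self-contained; in the real build it would live in a separate .lst file.
#define KERNEL_LIST                \
  REGISTER_KERNEL(CONV_2D, Conv2D) \
  REGISTER_KERNEL(MUL, Mul)

// One possible expansion: collect the printable name of every built kernel.
#define REGISTER_KERNEL(opcode, name) #name,
const char *kBuiltKernels[] = {KERNEL_LIST};
#undef REGISTER_KERNEL

int main()
{
  for (const char *name : kBuiltKernels)
    std::printf("kernel built: %s\n", name); // prints Conv2D, then Mul
}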
52 changes: 52 additions & 0 deletions onert-micro/onert-micro/include/pal/cmsisnn/PALMul.h
@@ -0,0 +1,52 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_MUL_H
#define ONERT_MICRO_EXECUTE_PAL_MUL_H

#include "PALMulCommon.h"
#include "PALUtils.h"

#include "arm_nnfunctions.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{

OMStatus Mul(const core::ArithmeticQuantParams &params, const uint32_t flat_size,
const int8_t *input1_data, const int8_t *input2_data, int8_t *output_data)
{
auto status = arm_elementwise_mul_s8(
input1_data, input2_data, params.input1_offset, params.input2_offset, output_data,
params.output_offset, params.output_multiplier, params.output_shift,
params.quantized_activation_min, params.quantized_activation_max, flat_size);
assert(status == ARM_CMSIS_NN_SUCCESS);

if (status != ARM_CMSIS_NN_SUCCESS)
return UnknownError;

return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_MUL_H
93 changes: 93 additions & 0 deletions onert-micro/onert-micro/include/pal/common/PALMulCommon.h
@@ -26,6 +26,11 @@ namespace execute
{
namespace pal
{
namespace
{
// Maximum dimension supported by the broadcast mul operation.
constexpr int kMaxMulBroadcastDim = 6;
} // namespace

template <typename T>
OMStatus Mul(const core::BinaryArithmeticBroadcastParams &params, const int flat_size,
@@ -46,6 +51,94 @@ OMStatus BroadcastMul4DSlow(const core::BinaryArithmeticBroadcastParams &params,
return Ok;
}

template <typename T>
OMStatus BroadcastMul6DSlow(const core::ArithmeticQuantParams &params,
const core::OMRuntimeShape &input1_shape, const T *input1_data,
const core::OMRuntimeShape &input2_shape, const T *input2_data,
const core::OMRuntimeShape &output_shape, T *output_data)
{
NdArrayDesc<kMaxMulBroadcastDim> desc1{};
NdArrayDesc<kMaxMulBroadcastDim> desc2{};
// The input shapes are extended as part of NdArrayDesc initialization.
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
const core::OMRuntimeShape extended_output_shape =
core::OMRuntimeShape::extendedShape(kMaxMulBroadcastDim, output_shape);
// Cache output shape dimensions.
int32_t extended_output_shape_dims[kMaxMulBroadcastDim];
std::memcpy(extended_output_shape_dims, extended_output_shape.dimsData(),
sizeof(extended_output_shape_dims));

size_t input1_offset_a = 0;
size_t input2_offset_a = 0;
size_t output_offset_a = 0;
for (int a = 0; a < extended_output_shape_dims[0]; ++a)
{
size_t input1_offset_d = input1_offset_a;
size_t input2_offset_d = input2_offset_a;
size_t output_offset_d = output_offset_a;
for (int d = 0; d < extended_output_shape_dims[1]; ++d)
{
size_t input1_offset_b = input1_offset_d;
size_t input2_offset_b = input2_offset_d;
size_t output_offset_b = output_offset_d;
for (int b = 0; b < extended_output_shape_dims[2]; ++b)
{
size_t input1_offset_y = input1_offset_b;
size_t input2_offset_y = input2_offset_b;
size_t output_offset_y = output_offset_b;
for (int y = 0; y < extended_output_shape_dims[3]; ++y)
{
size_t input1_offset_x = input1_offset_y;
size_t input2_offset_x = input2_offset_y;
size_t output_offset_x = output_offset_y;
for (int x = 0; x < extended_output_shape_dims[4]; ++x)
{
size_t input1_offset_c = input1_offset_x;
size_t input2_offset_c = input2_offset_x;
size_t output_offset_c = output_offset_x;
for (int c = 0; c < extended_output_shape_dims[5]; ++c)
{
const int32_t input1_val = params.input1_offset + input1_data[input1_offset_c];
const int32_t input2_val = params.input2_offset + input2_data[input2_offset_c];
const int32_t unclamped_result =
params.output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
params.output_multiplier,
params.output_shift);
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, unclamped_result));
output_data[output_offset_c] = static_cast<T>(clamped_output);
input1_offset_c += desc1.strides[5];
input2_offset_c += desc2.strides[5];
++output_offset_c;
}
input1_offset_x += desc1.strides[4];
input2_offset_x += desc2.strides[4];
output_offset_x += extended_output_shape_dims[5];
}
input1_offset_y += desc1.strides[3];
input2_offset_y += desc2.strides[3];
output_offset_y += extended_output_shape_dims[4] * extended_output_shape_dims[5];
}
input1_offset_b += desc1.strides[2];
input2_offset_b += desc2.strides[2];
output_offset_b += extended_output_shape_dims[3] * extended_output_shape_dims[4] *
extended_output_shape_dims[5];
}
input1_offset_d += desc1.strides[1];
input2_offset_d += desc2.strides[1];
output_offset_d += extended_output_shape_dims[2] * extended_output_shape_dims[3] *
extended_output_shape_dims[4] * extended_output_shape_dims[5];
}
input1_offset_a += desc1.strides[0];
input2_offset_a += desc2.strides[0];
output_offset_a += extended_output_shape_dims[1] * extended_output_shape_dims[2] *
extended_output_shape_dims[3] * extended_output_shape_dims[4] *
extended_output_shape_dims[5];
}
return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro
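BroadcastMul6DSlow above walks all six output dimensions with nested loops while advancing each input offset by that input's own stride; a broadcast dimension (extent 1 in the input) gets a zero stride, so the same input element is reused across the whole output extent. A tiny standalone sketch of the same idea in two dimensions, with made-up shapes and values:

#include <cstdio>

int main()
{
  // input1 has shape (2, 3); input2 has shape (2, 1) and is broadcast along
  // the last axis, so its stride for that axis is 0.
  const int input1[] = {1, 2, 3, 4, 5, 6};
  const int input2[] = {10, 100};
  const int strides1[2] = {3, 1};
  const int strides2[2] = {1, 0}; // zero stride on the broadcast axis
  int output[6];

  for (int a = 0; a < 2; ++a)
    for (int b = 0; b < 3; ++b)
      output[a * 3 + b] = input1[a * strides1[0] + b * strides1[1]] *
                          input2[a * strides2[0] + b * strides2[1]];

  std::printf("%d %d %d  %d %d %d\n", output[0], output[1], output[2],
              output[3], output[4], output[5]); // 10 20 30  400 500 600
}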
30 changes: 30 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/PALMul.h
@@ -20,4 +20,34 @@

#include "PALMulCommon.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{

template <typename InputType, typename OutputType>
OMStatus Mul(const core::ArithmeticQuantParams &params, uint32_t size, const InputType *input1_data,
const InputType *input2_data, OutputType *output_data)
{
for (uint32_t i = 0; i < size; ++i)
{
const int32_t input1_val = params.input1_offset + input1_data[i];
const int32_t input2_val = params.input2_offset + input2_data[i];
const int32_t unclamped_result =
params.output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
params.output_multiplier,
params.output_shift);
const int32_t clamped_output = std::min(
params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
output_data[i] = static_cast<OutputType>(clamped_output);
}
return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_MUL_H
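Both the broadcast and the flat Mul paths requantize the 32-bit product through multiplyByQuantizedMultiplier(value, multiplier, shift). The sketch below shows one common way such a helper is implemented (a rounding high multiply followed by a rounding arithmetic right shift, in the style of TFLite's reference code); this is an assumed implementation for illustration, not the helper from this repository.

#include <cstdint>
#include <cstdio>

// Assumed TFLite-style requantization: returns approximately
// value * (quantized_multiplier / 2^31) * 2^shift.
int32_t multiplyByQuantizedMultiplierSketch(int32_t value, int32_t quantized_multiplier, int shift)
{
  const int left_shift = shift > 0 ? shift : 0;
  const int right_shift = shift > 0 ? 0 : -shift;

  // Rounding "high" multiply: keep the top 32 bits of the 64-bit product.
  const int64_t product =
    static_cast<int64_t>(value) * (int64_t{1} << left_shift) * quantized_multiplier;
  const int32_t high = static_cast<int32_t>((product + (int64_t{1} << 30)) >> 31);

  // Rounding arithmetic right shift (positive operands assumed for brevity).
  const int32_t mask = (1 << right_shift) - 1;
  const int32_t remainder = high & mask;
  int32_t result = high >> right_shift;
  if (remainder > (mask >> 1))
    ++result;
  return result;
}

int main()
{
  // Rescale 1064 by ~0.003, expressed as multiplier 1649267442 with shift -8.
  std::printf("%d\n", multiplyByQuantizedMultiplierSketch(1064, 1649267442, -8)); // prints 3
}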
@@ -96,7 +96,7 @@ namespace add_int8_with_broadcasting
/*
* Add Kernel:
*
- * Input_1(1, 4, 4, 3)              Input_2(1, 4, 4, 3)
+ * Input_1(1, 4, 4, 3)              Input_2(1, 4, 4, 1)
* \ /
* Add(with broadcast)
* |
71 changes: 71 additions & 0 deletions onert-micro/onert-micro/include/test_models/mul/NegMulKernel.h
@@ -67,6 +67,61 @@ const unsigned char test_kernel_model_circle[] = {

} // namespace input_1_wrong_type

namespace neg_mul_no_scale_output
{
/*
* Quantized Mul Kernel whose output tensor has no scale:
*
* Input_1(1, 4, 4, 3) - Int8 Input_2(1, 4, 4, 3) - Int8
* \ /
*       Mul(no broadcast)
* |
* Output(1, 4, 4, 3) - no scale and zero_point
*/
const unsigned char test_kernel_model_circle[] = {
0x1c, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x30, 0x02, 0x00, 0x00, 0x3c, 0x02, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0xe4, 0xfd, 0xff, 0xff, 0xe8, 0xfd, 0xff, 0xff, 0xec, 0xfd, 0xff, 0xff,
0xf0, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00,
0x68, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x10, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0b,
0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x5c, 0xfe, 0xff, 0xff,
0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x00, 0x00,
0x68, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2a, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00,
0x3c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x38, 0x00, 0x00, 0x00,
0x1c, 0xff, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x43, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc3, 0x03, 0x00, 0x00, 0x00,
0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x8a, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00,
0x4c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x4c, 0x00, 0x00, 0x00,
0x7c, 0xff, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x58, 0x39, 0xb4, 0x3c, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x43, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc3, 0x04, 0x00, 0x00, 0x00,
0x69, 0x66, 0x6d, 0x32, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x18, 0x00, 0x14, 0x00, 0x13, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x20, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09,
0x54, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0xf4, 0xfd, 0x54, 0x3c, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x43,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc3, 0x04, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x31,
0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d,
0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};

} // namespace neg_mul_no_scale_output

namespace input_2_wrong_type
{

@@ -203,6 +258,22 @@ class NegTestDataInt16TypeMul : public NegTestDataBase
const unsigned char *_test_kernel_model_circle;
};

class NegTestQuantMulNoScaleKernel : public NegTestDataBase
{
public:
NegTestQuantMulNoScaleKernel()
{
_test_kernel_model_circle = neg_mul_no_scale_output::test_kernel_model_circle;
}

~NegTestQuantMulNoScaleKernel() override = default;

const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }

protected:
const unsigned char *_test_kernel_model_circle;
};

} // namespace test_model
} // namespace onert_micro
