Skip to content

Commit

Permalink
[onert-micro] Support S8 MaxPool2D (#13186)
Browse files Browse the repository at this point in the history
This pr adds supporting of s8 + cmsis_nn.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem authored Jun 14, 2024
1 parent ff12cae commit 2873aaf
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D)
#/*REGISTER_KERNEL(MIRROR_PAD, MirrorPad)*/
#/*REGISTER_KERNEL(MAXIMUM, Maximum)*/
#/*REGISTER_KERNEL(MEAN, Mean)*/
#/*REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)*/
REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)
#/*REGISTER_KERNEL(MINIMUM, Minimum)*/
#/*REGISTER_KERNEL(SHAPE, Shape)*/
#/*REGISTER_KERNEL(NOT_EQUAL, NotEqual)*/
Expand Down
81 changes: 81 additions & 0 deletions onert-micro/onert-micro/include/pal/cmsisnn/PALMaxPool2D.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_MAX_POOL_2D_H
#define ONERT_MICRO_EXECUTE_PAL_MAX_POOL_2D_H

#include "PALMaxPool2DCommon.h"

#include <arm_nnfunctions.h>

namespace onert_micro
{
namespace execute
{
namespace pal
{

OMStatus MaxPool(const core::Pool2DParams &params, const core::OMRuntimeShape &input_shape,
const int8_t *input_data, const core::OMRuntimeShape &output_shape,
int8_t *output_data)
{
cmsis_nn_dims input_dims;
cmsis_nn_dims output_dims;
cmsis_nn_pool_params pool_params;
cmsis_nn_dims filter_dims;
cmsis_nn_context ctx;

const int depth = input_shape.dims(3);
const int output_width = output_shape.dims(2);

input_dims.n = 1;
input_dims.h = input_shape.dims(1);
input_dims.w = input_shape.dims(2);
input_dims.c = depth;

output_dims.n = 1;
output_dims.h = output_shape.dims(1);
output_dims.w = output_width;
output_dims.c = depth;

pool_params.stride.h = params.stride_h;
pool_params.stride.w = params.stride_w;
pool_params.padding.h = params.pad_h;
pool_params.padding.w = params.pad_w;
pool_params.activation.min = params.quantized_activation_min;
pool_params.activation.max = params.quantized_activation_max;

filter_dims.n = 1;
filter_dims.h = params.filter_h;
filter_dims.w = params.filter_w;
filter_dims.c = 1;

auto res = arm_max_pool_s8(&ctx, &pool_params, &input_dims, input_data, &filter_dims,
&output_dims, output_data);

assert(res == ARM_CMSIS_NN_SUCCESS);
if (res != ARM_CMSIS_NN_SUCCESS)
return CmsisNNError;

return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_MAX_POOL_2D_H
57 changes: 53 additions & 4 deletions onert-micro/onert-micro/include/pal/mcu/PALMaxPool2D.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,60 @@ namespace execute
namespace pal
{

OMStatus MaxPool(const core::Pool2DParams &, const core::OMRuntimeShape &, const uint8_t *,
const core::OMRuntimeShape &, uint8_t *, circle::TensorType)
OMStatus MaxPool(const core::Pool2DParams &params, const core::OMRuntimeShape &input_shape,
const int8_t *input_data, const core::OMRuntimeShape &output_shape,
int8_t *output_data)
{
assert(false && "Not impl yet");
return UnsupportedType;
assert(input_shape.dimensionsCount() == 4);
assert(output_shape.dimensionsCount() == 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int input_height = input_shape.dims(1);
const int input_width = input_shape.dims(2);
const int output_height = output_shape.dims(1);
const int output_width = output_shape.dims(2);
const int stride_height = params.stride_h;
const int stride_width = params.stride_w;
const int pad_w = params.pad_w;
const int pad_h = params.pad_h;
const int filter_h = params.filter_h;
const int filter_w = params.filter_w;
for (int batch = 0; batch < batches; ++batch)
{
for (int out_y = 0; out_y < output_height; ++out_y)
{
for (int out_x = 0; out_x < output_width; ++out_x)
{
for (int channel = 0; channel < depth; ++channel)
{
const int in_x_origin = (out_x * stride_width) - pad_w;
const int in_y_origin = (out_y * stride_height) - pad_h;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end = std::min(filter_w, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end = std::min(filter_h, input_height - in_y_origin);
int8_t max = std::numeric_limits<int8_t>::lowest();
for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y)
{
for (int filter_x = filter_x_start; filter_x < filter_x_end; ++filter_x)
{
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
max = std::max(
max, input_data[offset(input_shape.dimsData(), batch, in_y, in_x, channel)]);
}
}
max = std::max<int8_t>(max, params.quantized_activation_min);
max = std::min<int8_t>(max, params.quantized_activation_max);
output_data[offset(output_shape.dimsData(), batch, out_y, out_x, channel)] =
static_cast<int8_t>(max);
}
}
}
}
return Ok;
}

} // namespace pal
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TEST_MODELS_QUANT_MAX_POOL_KERNEL_H
#define ONERT_MICRO_TEST_MODELS_QUANT_MAX_POOL_KERNEL_H

#include "TestDataMaxPool2DBase.h"

namespace onert_micro
{
namespace test_model
{
namespace s8_max_pool
{

/*
* S8 MaxPool2D Kernel:
*
* Input(1, 8, 8, 1) - Int8
* |
* MaxPool2D - Int8
* |
* Output(1, 7, 7, 1) - Int8
*/

const unsigned char test_kernel_model_circle[] = {
0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x2c, 0x00, 0x00, 0x00, 0xdc, 0x01, 0x00, 0x00, 0xf8, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff,
0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00,
0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
0x7c, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x00, 0x00, 0x0e, 0x00, 0x16, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x08, 0x00,
0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x1c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00,
0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x17, 0x00, 0x10, 0x00, 0x0c, 0x00,
0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x84, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x92, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00,
0x48, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x44, 0x00, 0x00, 0x00,
0x84, 0xff, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x6f, 0x12, 0x83, 0x3b, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00,
0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x13, 0x00, 0x0c, 0x00,
0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x54, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00,
0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00,
0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x6f, 0x12, 0x83, 0x3b, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00,
0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11,
0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63,
0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};

const std::vector<int8_t> input_data = {
-10, -5, -5, -6, -8, -5, 6, -10, -4, 5, -4, -3, 5, 0, -8, -6, 0, 7, 1, -5, -10, 8,
0, -2, -9, 5, -1, -7, -3, 3, 6, -1, -8, -9, -1, 8, -2, -3, 1, -8, -1, 3, 2, 5,
1, -8, -1, -6, 6, -6, 5, 9, -4, 5, -10, -8, -3, -9, 9, 4, 3, 0, 0, -10};

const std::vector<int8_t> reference_output_data = {
5, 5, -3, 5, 5, 6, 6, 7, 7, 1, 5, 8, 8, 0, 7, 7, 1, -3, 8, 8, 6, 5, 5, 8, 8,
3, 6, 6, 3, 3, 8, 8, 1, 1, 1, 6, 5, 9, 9, 5, 5, -1, 6, 9, 9, 9, 5, 5, 0};

} // namespace s8_max_pool

class TestDataS8MaxPool2D : public TestDataMaxPool2DBase<int8_t>
{
public:
TestDataS8MaxPool2D()
{
_input_data = s8_max_pool::input_data;
_reference_output_data = s8_max_pool::reference_output_data;
_test_kernel_model_circle = s8_max_pool::test_kernel_model_circle;
}

~TestDataS8MaxPool2D() override = default;
};

} // namespace test_model
} // namespace onert_micro

#endif // ONERT_MICRO_TEST_MODELS_QUANT_MAX_POOL_KERNEL_H
21 changes: 21 additions & 0 deletions onert-micro/onert-micro/src/execute/kernels/MaxPool2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,27 @@ OMStatus onert_micro::execute::execute_kernel_CircleMaxPool2D(const OMExecuteArg
}
break;
#endif // DIS_FLOAT
#ifndef DIS_QUANT
case circle::TensorType_INT8:
{
assert(output->quantization() != nullptr);
assert(output->quantization()->scale() != nullptr);
assert(output->quantization()->scale()->size() == 1);
const auto output_scale = output->quantization()->scale()->operator[](0);

assert(output->quantization()->zero_point() != nullptr);
assert(output->quantization()->zero_point()->size() == 1);
const auto output_zp = output->quantization()->zero_point()->operator[](0);

calculateActivationRangeQuantized(
options->fused_activation_function(), output_zp, output_scale, output->type(),
&params.quantized_activation_min, &params.quantized_activation_max);
status = pal::MaxPool(params, input_shape, core::utils::castInputData<int8_t>(input_data),
core::OMRuntimeShape(output),
core::utils::castOutputData<int8_t>(output_data));
}
break;
#endif // DIS_QUANT
default:
{
status = UnsupportedType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "execute/OMTestUtils.h"
#include "test_models/maxpool2d/FloatMaxPool2DKernel.h"
#include "test_models/maxpool2d/NegMaxPool2DKernel.h"
#include "test_models/maxpool2d/QuantMaxPool2DKernel.h"

namespace onert_micro
{
Expand All @@ -40,6 +41,14 @@ TEST_F(MaxPool2DTest, Float_P)
EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0));
}

TEST_F(MaxPool2DTest, S8_P)
{
onert_micro::test_model::TestDataS8MaxPool2D test_data_kernel;
std::vector<int8_t> output_data_vector =
onert_micro::execute::testing::checkKernel<int8_t>(1, &test_data_kernel);
EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0));
}

TEST_F(MaxPool2DTest, Input_output_type_mismatch_NEG)
{
onert_micro::test_model::NegTestDataInputOutputTypeMismatchMaxPool2DKernel test_data_kernel;
Expand Down
11 changes: 11 additions & 0 deletions onert-micro/onert-micro/src/import/kernels/MaxPool2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,16 @@ OMStatus onert_micro::import::configure_kernel_CircleMaxPool2D(const OMConfigure
return UnsupportedType;
}

// Check quantization params
if (output->quantization() == nullptr)
{
return NoQuantization;
}

if (output->quantization()->scale()->size() != 1)
{
return UnsupportedType;
}

return status;
}

0 comments on commit 2873aaf

Please sign in to comment.