Skip to content

Commit

Permalink
[onert-micro] Use CMSIS_NN for Relu, Relu6
Browse files Browse the repository at this point in the history
Enable CMSIS_NN for Relu, Relu6

ONE-DCO-1.0-Signed-off-by: Chunseok Lee <[email protected]>
  • Loading branch information
chunseoklee committed Aug 16, 2024
1 parent 6fafda7 commit afc0667
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,5 @@ REGISTER_KERNEL(SOFTMAX, Softmax)
#/*REGISTER_KERNEL(ZEROS_LIKE, ZerosLike)*/
#/*REGISTER_KERNEL(SQUEEZE, Squeeze)*/
#/*REGISTER_KERNEL(UNPACK, Unpack)*/
REGISTER_KERNEL(RELU, Relu)
REGISTER_KERNEL(RELU6, Relu6)
65 changes: 65 additions & 0 deletions onert-micro/onert-micro/include/pal/cmsisnn/PALRelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_RELU_H
#define ONERT_MICRO_EXECUTE_PAL_RELU_H

#include "PALReluCommon.h"

#include <arm_nnfunctions.h>

namespace onert_micro
{
namespace execute
{
namespace pal
{

// int8_t specialization of ReLUCommon backed by CMSIS-NN kernels.
//
// flat_size   : number of int8 elements in both input_data and output_data.
// input_data  : source buffer (left untouched).
// output_data : destination buffer; receives the activated values.
// alpha       : leaky slope. When it is non-zero the CMSIS-NN kernels are not
//               used and a scalar loop is run instead (is_relu_6 is ignored
//               on that path).
// is_relu_6   : with alpha == 0, selects arm_relu6_s8 instead of arm_relu_q7.
//
// Always returns Ok. A non-positive flat_size is treated as "nothing to do".
//
// The specialization is marked inline because it is defined in a header: an
// explicit full specialization is an ordinary function and would otherwise be
// defined once per translation unit that includes this file.
template <>
inline OMStatus ReLUCommon<int8_t>(const int flat_size, const int8_t *input_data,
                                   int8_t *output_data, const float alpha, const bool is_relu_6)
{
  // Guard against empty/negative sizes; a negative value would turn into a
  // huge size_t when handed to memcpy below.
  if (flat_size <= 0)
  {
    return Ok;
  }

  // Leaky_Relu is not supported by cmsis_nn: fall back to a scalar loop.
  if (alpha != 0)
  {
    for (int i = 0; i < flat_size; i++)
    {
      const int8_t val = input_data[i];
      int8_t result = val > 0 ? val : val * alpha;
      output_data[i] = result;
    }
    return Ok;
  }

  // The CMSIS-NN activation kernels work in place, so seed the output buffer
  // with the input values first.
  memcpy(output_data, input_data, flat_size);

  // arm_relu_q7 / arm_relu6_s8 take the element count as a 16-bit value, so a
  // larger tensor has to be fed in slices; passing flat_size directly would
  // silently truncate the count and leave the tail of the buffer unprocessed.
  constexpr int max_chunk = 0xFFFF;
  int8_t *cursor = output_data;
  int remaining = flat_size;
  while (remaining > 0)
  {
    const int chunk = remaining < max_chunk ? remaining : max_chunk;
    if (is_relu_6)
    {
      // Relu6
      arm_relu6_s8(cursor, static_cast<uint16_t>(chunk));
    }
    else
    {
      // Relu
      arm_relu_q7(cursor, static_cast<uint16_t>(chunk));
    }
    cursor += chunk;
    remaining -= chunk;
  }

  return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro
#endif // ONERT_MICRO_EXECUTE_PAL_RELU_H
11 changes: 6 additions & 5 deletions onert-micro/onert-micro/include/pal/common/PALReluCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ namespace execute
namespace pal
{

inline OMStatus ReLUCommon(const int flat_size, const float *input_data, float *output_data,
const float alpha, const bool is_relu_6)
template <typename Type>
OMStatus ReLUCommon(const int flat_size, const Type *input_data, Type *output_data,
const float alpha, const bool is_relu_6)
{
const float relu_6_value = 6.0f;
const Type relu_6_value = 6.0f;
for (int i = 0; i < flat_size; i++)
{
const float val = input_data[i];
float result = val > 0 ? val : val * alpha;
const Type val = input_data[i];
Type result = val > 0 ? val : val * alpha;
result = is_relu_6 ? (result > relu_6_value ? relu_6_value : result) : result;
output_data[i] = result;
}
Expand Down
24 changes: 24 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/PALRelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_RELU_H
#define ONERT_MICRO_EXECUTE_PAL_RELU_H

#include "PALReluCommon.h"
#include "PALUtils.h"

#endif // ONERT_MICRO_EXECUTE_PAL_RELU_H
19 changes: 18 additions & 1 deletion onert-micro/onert-micro/src/execute/kernels/ReluCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

#include "execute/kernels/ReluCommon.h"
#include "PALReluCommon.h"
#include "PALRelu.h"

using namespace onert_micro;
using namespace onert_micro::execute;
Expand Down Expand Up @@ -86,6 +86,23 @@ OMStatus onert_micro::execute::execute_relu_common(const OMExecuteArgs &execute_
}
break;
#endif // DIS_FLOAT
#ifndef DIS_QUANT
case circle::TensorType_INT8:
{
core::OMRuntimeShape input_shape(input);
core::OMRuntimeShape output_shape(output);

const auto *input_data_int8 = core::utils::castInputData<int8_t>(input_data);
auto *output_data_int8 = core::utils::castOutputData<int8_t>(output_data);

assert(output_data_int8);
const int flat_size = input_shape.flatSize();

status = pal::ReLUCommon(flat_size, input_data_int8, output_data_int8, alpha, is_relu_6);
}
break;
#endif // DIS_QUANT

default:
{
status = UnsupportedType;
Expand Down

0 comments on commit afc0667

Please sign in to comment.