Skip to content

Commit

Permalink
[onert-micro] Reduce import Math ops code duplication (#13309)
Browse files Browse the repository at this point in the history
This pr reduces import code duplication for Math ops.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>

Co-authored-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem and Artem Balyshev authored Jun 27, 2024
1 parent 6482d89 commit e234798
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 252 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,43 @@ constexpr uint32_t outputTensorIdx = 0;

OMStatus onert_micro::import::helpers::configure_SISO_kernel(const OMConfigureArgs &config_args)
{
{
OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;
OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;

onert_micro::execute::OMRuntimeKernel runtime_kernel;
onert_micro::execute::OMRuntimeKernel runtime_kernel;

OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
if (status != Ok)
return status;
OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
if (status != Ok)
return status;

const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
assert(output != nullptr);
assert(input != nullptr);
assert(output != nullptr);

status = utils::checkCondition(input->type() == output->type());
if (status != Ok)
return status;
status = utils::checkCondition(input->type() == output->type());
if (status != Ok)
return status;

OMRuntimeShape input_shape(input);
OMRuntimeShape output_shape(output);
OMRuntimeShape input_shape(input);
OMRuntimeShape output_shape(output);

status = utils::checkCondition(input_shape == output_shape);
status = utils::checkCondition(input_shape == output_shape);

if (input->type() != circle::TensorType_INT8 and input->type() != circle::TensorType_INT16)
return status;
}

// Check quantized version
if (input->quantization() == nullptr or output->quantization() == nullptr)
return NoQuantization;

if (output->quantization()->scale() == nullptr or output->quantization()->scale()->size() != 1)
return UnsupportedQuantizationType;

if (input->quantization()->zero_point() == nullptr or
input->quantization()->zero_point()->size() != 1)
return UnsupportedQuantizationType;

return status;
}
43 changes: 2 additions & 41 deletions onert-micro/onert-micro/src/import/kernels/Abs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,51 +14,12 @@
* limitations under the License.
*/

#include "import/OMKernelConfigureBuilder.h"
#include "core/OMUtils.h"
#include "OMStatus.h"
#include "execute/OMRuntimeKernel.h"
#include "import/helpers/OMConfigureSISOKernel.h"

using namespace onert_micro;
using namespace onert_micro::core;

namespace
{

constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

OMStatus onert_micro::import::configure_kernel_CircleAbs(const OMConfigureArgs &config_args)
{
OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;

onert_micro::execute::OMRuntimeKernel runtime_kernel;

OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
if (status != Ok)
return status;

const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
assert(output != nullptr);

status = utils::checkCondition(input->type() == output->type());
if (status != Ok)
return status;

OMRuntimeShape input_shape(input);
OMRuntimeShape output_shape(output);

status = utils::checkCondition(input_shape.dimensionsCount() == output_shape.dimensionsCount());
if (status != Ok)
return status;

status = utils::checkCondition(input_shape.flatSize() == output_shape.flatSize());

return status;
return onert_micro::import::helpers::configure_SISO_kernel(config_args);
}
55 changes: 2 additions & 53 deletions onert-micro/onert-micro/src/import/kernels/Floor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,63 +14,12 @@
* limitations under the License.
*/

#include "import/OMKernelConfigureBuilder.h"
#include "core/OMUtils.h"
#include "OMStatus.h"
#include "execute/OMRuntimeKernel.h"
#include "import/helpers/OMConfigureSISOKernel.h"

using namespace onert_micro;
using namespace onert_micro::core;

namespace
{

constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

OMStatus onert_micro::import::configure_kernel_CircleFloor(const OMConfigureArgs &config_args)
{
OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;

onert_micro::execute::OMRuntimeKernel runtime_kernel;

OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
if (status != Ok)
return status;

const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
assert(output != nullptr);

status = utils::checkCondition(input->type() == output->type());
if (status != Ok)
return status;

OMRuntimeShape input_shape(input);
OMRuntimeShape output_shape(output);

// check that input and output dimensions are equal
int N = input_shape.dimensionsCount();
status = utils::checkCondition(N == output_shape.dimensionsCount());
if (status != Ok)
return status;

status = utils::checkCondition(input_shape.flatSize() == output_shape.flatSize());
if (status != Ok)
return status;

// check that sizes of all dimensions are equal
for (int i = 0; i < N; ++i)
{
status = utils::checkCondition(input_shape.dims(i) == output_shape.dims(i));
if (status != Ok)
return status;
}

return status;
return onert_micro::import::helpers::configure_SISO_kernel(config_args);
}
48 changes: 2 additions & 46 deletions onert-micro/onert-micro/src/import/kernels/LogSoftmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,12 @@
* limitations under the License.
*/

#include "OMStatus.h"

#include "core/OMUtils.h"

#include "import/OMKernelConfigureBuilder.h"

#include "execute/OMRuntimeKernel.h"
#include "import/helpers/OMConfigureSISOKernel.h"

using namespace onert_micro;
using namespace onert_micro::core;

namespace
{

constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

OMStatus onert_micro::import::configure_kernel_CircleLogSoftmax(const OMConfigureArgs &config_args)
{
OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;

onert_micro::execute::OMRuntimeKernel runtime_kernel;

OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
if (status != Ok)
return status;

const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
assert(output != nullptr);

status = utils::checkCondition(input->type() == output->type());
if (status != Ok)
return status;

OMRuntimeShape input_shape(input);
OMRuntimeShape output_shape(output);

status = utils::checkCondition(input_shape.flatSize() == output_shape.flatSize());
if (status != Ok)
return status;

status = utils::checkCondition(input_shape.dimensionsCount() == output_shape.dimensionsCount());
if (status != Ok)
return status;

return status;
return onert_micro::import::helpers::configure_SISO_kernel(config_args);
}
59 changes: 2 additions & 57 deletions onert-micro/onert-micro/src/import/kernels/Logistic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,67 +14,12 @@
* limitations under the License.
*/

#include "OMStatus.h"

#include "import/OMKernelConfigureBuilder.h"
#include "core/OMUtils.h"
#include "execute/OMRuntimeKernel.h"
#include "import/helpers/OMConfigureSISOKernel.h"

using namespace onert_micro;
using namespace onert_micro::core;

namespace
{

constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

OMStatus onert_micro::import::configure_kernel_CircleLogistic(const OMConfigureArgs &config_args)
{
OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;

onert_micro::execute::OMRuntimeKernel runtime_kernel;

OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
if (status != Ok)
return status;

const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
assert(output != nullptr);

status = utils::checkCondition(input->type() == output->type());
if (status != Ok)
return status;

OMRuntimeShape input_shape(input);
OMRuntimeShape output_shape(output);

status = utils::checkCondition(input_shape.dimensionsCount() == output_shape.dimensionsCount());
if (status != Ok)
return status;

status = utils::checkCondition(input_shape.flatSize() == output_shape.flatSize());
if (status != Ok)
return status;

if (input->type() != circle::TensorType_INT8 and input->type() != circle::TensorType_INT16)
return status;

// Check quantized version
if (input->quantization() == nullptr or output->quantization() == nullptr)
return NoQuantization;

if (output->quantization()->scale() == nullptr or output->quantization()->scale()->size() != 1)
return UnsupportedQuantizationType;

if (input->quantization()->scale() == nullptr or input->quantization()->scale()->size() != 1)
return UnsupportedQuantizationType;

return status;
return onert_micro::import::helpers::configure_SISO_kernel(config_args);
}
39 changes: 2 additions & 37 deletions onert-micro/onert-micro/src/import/kernels/Sin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,12 @@
* limitations under the License.
*/

#include "import/OMKernelConfigureBuilder.h"
#include "core/OMUtils.h"
#include "OMStatus.h"
#include "execute/OMRuntimeKernel.h"
#include "import/helpers/OMConfigureSISOKernel.h"

using namespace onert_micro;
using namespace onert_micro::core;

namespace
{

constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

OMStatus onert_micro::import::configure_kernel_CircleSin(const OMConfigureArgs &config_args)
{
OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;

onert_micro::execute::OMRuntimeKernel runtime_kernel;

OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
if (status != Ok)
return status;

const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
assert(output != nullptr);

status = utils::checkCondition(input->type() == output->type());
if (status != Ok)
return status;

OMRuntimeShape input_shape(input);
OMRuntimeShape output_shape(output);

status = utils::checkCondition(input_shape == output_shape);

return status;
return onert_micro::import::helpers::configure_SISO_kernel(config_args);
}

0 comments on commit e234798

Please sign in to comment.