From 1322bd077c0a6cae73f60ac50f190ece44f33dff Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Mon, 17 Jun 2024 13:48:42 +0300 Subject: [PATCH] [onert-micro] Extend OMRuntimeGraph api This pr extends OMRuntimeGraph api with tow new methods - getInputDataTypeSize and getOutputDataTypeSize. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- .../onert-micro/include/core/OMRuntimeGraph.h | 3 +++ .../onert-micro/src/core/OMRuntimeGraph.cpp | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/onert-micro/onert-micro/include/core/OMRuntimeGraph.h b/onert-micro/onert-micro/include/core/OMRuntimeGraph.h index af9a6520b56..410f5c78b99 100644 --- a/onert-micro/onert-micro/include/core/OMRuntimeGraph.h +++ b/onert-micro/onert-micro/include/core/OMRuntimeGraph.h @@ -54,6 +54,9 @@ class OMRuntimeGraph void *getInputDataAt(uint32_t position); void *getOutputDataAt(uint32_t position); + size_t getInputDataTypeSize(uint32_t position); + size_t getOutputDataTypeSize(uint32_t position); + OMStatus allocateGraphInputs(); OMStatus reset(); diff --git a/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp b/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp index 0cee20714ca..1d7fb200d24 100644 --- a/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp +++ b/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp @@ -15,6 +15,7 @@ */ #include "core/OMRuntimeGraph.h" +#include "core/OMDataType.h" #include "OMStatus.h" using namespace onert_micro::core; @@ -69,6 +70,25 @@ uint32_t OMRuntimeGraph::getOutputSizeAt(uint32_t position) return shape.flatSize(); } +size_t OMRuntimeGraph::getInputDataTypeSize(uint32_t position) +{ + const auto input_index = _context.getGraphInputTensorIndex(position); + const circle::Tensor *input_tensor = _context.getTensorByIndex(static_cast(input_index)); + + auto type = input_tensor->type(); + return sizeof(OMDataType(type)); +} + +size_t OMRuntimeGraph::getOutputDataTypeSize(uint32_t position) +{ + const auto output_index = _context.getGraphOutputTensorIndex(position); + const circle::Tensor *output_tensor = + _context.getTensorByIndex(static_cast(output_index)); + + auto type = output_tensor->type(); + return sizeof(OMDataType(type)); +} + uint32_t OMRuntimeGraph::getInputSizeAt(uint32_t position) { const auto input_index = _context.getGraphInputTensorIndex(position);