Skip to content

Commit

Permalink
[onert-micro] Extend OMRuntimeGraph api
Browse files Browse the repository at this point in the history
This pr extends OMRuntimeGraph api with tow new methods - getInputDataTypeSize and getOutputDataTypeSize.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Jun 17, 2024
1 parent 6482288 commit 1322bd0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions onert-micro/onert-micro/include/core/OMRuntimeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class OMRuntimeGraph
void *getInputDataAt(uint32_t position);
void *getOutputDataAt(uint32_t position);

size_t getInputDataTypeSize(uint32_t position);
size_t getOutputDataTypeSize(uint32_t position);

OMStatus allocateGraphInputs();

OMStatus reset();
Expand Down
20 changes: 20 additions & 0 deletions onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "core/OMRuntimeGraph.h"
#include "core/OMDataType.h"
#include "OMStatus.h"

using namespace onert_micro::core;
Expand Down Expand Up @@ -69,6 +70,25 @@ uint32_t OMRuntimeGraph::getOutputSizeAt(uint32_t position)
return shape.flatSize();
}

size_t OMRuntimeGraph::getInputDataTypeSize(uint32_t position)
{
const auto input_index = _context.getGraphInputTensorIndex(position);
const circle::Tensor *input_tensor = _context.getTensorByIndex(static_cast<int32_t>(input_index));

auto type = input_tensor->type();
return sizeof(OMDataType(type));
}

size_t OMRuntimeGraph::getOutputDataTypeSize(uint32_t position)
{
const auto output_index = _context.getGraphOutputTensorIndex(position);
const circle::Tensor *output_tensor =
_context.getTensorByIndex(static_cast<int32_t>(output_index));

auto type = output_tensor->type();
return sizeof(OMDataType(type));
}

uint32_t OMRuntimeGraph::getInputSizeAt(uint32_t position)
{
const auto input_index = _context.getGraphInputTensorIndex(position);
Expand Down

0 comments on commit 1322bd0

Please sign in to comment.