Skip to content

Commit

Permalink
[onert] Add wrapping CAPIs for training (#14524)
Browse files Browse the repository at this point in the history
This commit adds wrapping CAPIs for training
  - Add the include guard for nnfw_api_wrapper.h
  - Introduce the namespace onert::api::python
  - Wrap CAPIs for training

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jan 6, 2025
1 parent d539375 commit 765d7b7
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 0 deletions.
69 changes: 69 additions & 0 deletions runtime/onert/api/python/include/nnfw_api_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,22 @@
* limitations under the License.
*/

#ifndef __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__
#define __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__

#include "nnfw.h"
#include "nnfw_experimental.h"

#include <pybind11/stl.h>
#include <pybind11/numpy.h>

namespace onert
{
namespace api
{
namespace python
{

namespace py = pybind11;

/**
Expand Down Expand Up @@ -159,4 +170,62 @@ class NNFW_SESSION
void set_output_layout(uint32_t index, const char *layout);
tensorinfo input_tensorinfo(uint32_t index);
tensorinfo output_tensorinfo(uint32_t index);

//////////////////////////////////////////////
// Experimental APIs for training
//////////////////////////////////////////////
nnfw_train_info train_get_traininfo();
void train_set_traininfo(const nnfw_train_info *info);

template <typename T> void train_set_input(uint32_t index, py::array_t<T> &buffer)
{
nnfw_tensorinfo tensor_info;
nnfw_input_tensorinfo(this->session, index, &tensor_info);

py::buffer_info buf_info = buffer.request();
const auto buf_shape = buf_info.shape;
assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0);
tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0));

ensure_status(nnfw_train_set_input(this->session, index, buffer.request().ptr, &tensor_info));
}
template <typename T> void train_set_expected(uint32_t index, py::array_t<T> &buffer)
{
nnfw_tensorinfo tensor_info;
nnfw_output_tensorinfo(this->session, index, &tensor_info);

py::buffer_info buf_info = buffer.request();
const auto buf_shape = buf_info.shape;
assert(tensor_info.rank == static_cast<int32_t>(buf_shape.size()) && buf_shape.size() > 0);
tensor_info.dims[0] = static_cast<int32_t>(buf_shape.at(0));

ensure_status(
nnfw_train_set_expected(this->session, index, buffer.request().ptr, &tensor_info));
}
template <typename T> void train_set_output(uint32_t index, py::array_t<T> &buffer)
{
nnfw_tensorinfo tensor_info;
nnfw_output_tensorinfo(this->session, index, &tensor_info);
NNFW_TYPE type = tensor_info.dtype;
uint32_t output_elements = num_elems(&tensor_info);
size_t length = sizeof(T) * output_elements;

ensure_status(nnfw_train_set_output(session, index, type, buffer.request().ptr, length));
}

void train_prepare();
void train(bool update_weights);
float train_get_loss(uint32_t index);

void train_export_circle(const py::str &path);
void train_import_checkpoint(const py::str &path);
void train_export_checkpoint(const py::str &path);

// TODO Add other apis
};

} // namespace python
} // namespace api
} // namespace onert

#endif // __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__
60 changes: 60 additions & 0 deletions runtime/onert/api/python/src/nnfw_api_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@

#include <iostream>

namespace onert
{
namespace api
{
namespace python
{

namespace py = pybind11;

void ensure_status(NNFW_STATUS status)
{
switch (status)
Expand Down Expand Up @@ -243,3 +252,54 @@ tensorinfo NNFW_SESSION::output_tensorinfo(uint32_t index)
}
return ti;
}

//////////////////////////////////////////////
// Experimental APIs for training
//////////////////////////////////////////////
nnfw_train_info NNFW_SESSION::train_get_traininfo()
{
nnfw_train_info train_info = nnfw_train_info();
ensure_status(nnfw_train_get_traininfo(session, &train_info));
return train_info;
}

void NNFW_SESSION::train_set_traininfo(const nnfw_train_info *info)
{
ensure_status(nnfw_train_set_traininfo(session, info));
}

void NNFW_SESSION::train_prepare() { ensure_status(nnfw_train_prepare(session)); }

void NNFW_SESSION::train(bool update_weights)
{
ensure_status(nnfw_train(session, update_weights));
}

float NNFW_SESSION::train_get_loss(uint32_t index)
{
float loss = 0.f;
ensure_status(nnfw_train_get_loss(session, index, &loss));
return loss;
}

void NNFW_SESSION::train_export_circle(const py::str &path)
{
const char *c_str_path = path.cast<std::string>().c_str();
ensure_status(nnfw_train_export_circle(session, c_str_path));
}

void NNFW_SESSION::train_import_checkpoint(const py::str &path)
{
const char *c_str_path = path.cast<std::string>().c_str();
ensure_status(nnfw_train_import_checkpoint(session, c_str_path));
}

void NNFW_SESSION::train_export_checkpoint(const py::str &path)
{
const char *c_str_path = path.cast<std::string>().c_str();
ensure_status(nnfw_train_export_checkpoint(session, c_str_path));
}

} // namespace python
} // namespace api
} // namespace onert
2 changes: 2 additions & 0 deletions runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

namespace py = pybind11;

using namespace onert::api::python;

PYBIND11_MODULE(libnnfw_api_pybind, m)
{
m.doc() = "nnfw python plugin";
Expand Down

0 comments on commit 765d7b7

Please sign in to comment.