From 765d7b71b598f9014a0335bd4278afb5a1ed208f Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Mon, 6 Jan 2025 19:18:36 +0900 Subject: [PATCH] [onert] Add wrapping CAPIs for training (#14524) This commit adds wrapping CAPIs for training - Add the include guard for nnfw_api_wrapper.h - Introduce the namespace onert::api::python - Wrap CAPIs for training ONE-DCO-1.0-Signed-off-by: ragmani --- .../api/python/include/nnfw_api_wrapper.h | 69 +++++++++++++++++++ .../onert/api/python/src/nnfw_api_wrapper.cc | 60 ++++++++++++++++ .../api/python/src/nnfw_api_wrapper_pybind.cc | 2 + 3 files changed, 131 insertions(+) diff --git a/runtime/onert/api/python/include/nnfw_api_wrapper.h b/runtime/onert/api/python/include/nnfw_api_wrapper.h index 23e76b5ce85..6191d49e1af 100644 --- a/runtime/onert/api/python/include/nnfw_api_wrapper.h +++ b/runtime/onert/api/python/include/nnfw_api_wrapper.h @@ -14,11 +14,22 @@ * limitations under the License. */ +#ifndef __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ +#define __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ + #include "nnfw.h" +#include "nnfw_experimental.h" #include #include +namespace onert +{ +namespace api +{ +namespace python +{ + namespace py = pybind11; /** @@ -159,4 +170,62 @@ class NNFW_SESSION void set_output_layout(uint32_t index, const char *layout); tensorinfo input_tensorinfo(uint32_t index); tensorinfo output_tensorinfo(uint32_t index); + + ////////////////////////////////////////////// + // Experimental APIs for training + ////////////////////////////////////////////// + nnfw_train_info train_get_traininfo(); + void train_set_traininfo(const nnfw_train_info *info); + + template void train_set_input(uint32_t index, py::array_t &buffer) + { + nnfw_tensorinfo tensor_info; + nnfw_input_tensorinfo(this->session, index, &tensor_info); + + py::buffer_info buf_info = buffer.request(); + const auto buf_shape = buf_info.shape; + assert(tensor_info.rank == static_cast(buf_shape.size()) && buf_shape.size() > 0); + tensor_info.dims[0] = static_cast(buf_shape.at(0)); + + ensure_status(nnfw_train_set_input(this->session, index, buffer.request().ptr, &tensor_info)); + } + template void train_set_expected(uint32_t index, py::array_t &buffer) + { + nnfw_tensorinfo tensor_info; + nnfw_output_tensorinfo(this->session, index, &tensor_info); + + py::buffer_info buf_info = buffer.request(); + const auto buf_shape = buf_info.shape; + assert(tensor_info.rank == static_cast(buf_shape.size()) && buf_shape.size() > 0); + tensor_info.dims[0] = static_cast(buf_shape.at(0)); + + ensure_status( + nnfw_train_set_expected(this->session, index, buffer.request().ptr, &tensor_info)); + } + template void train_set_output(uint32_t index, py::array_t &buffer) + { + nnfw_tensorinfo tensor_info; + nnfw_output_tensorinfo(this->session, index, &tensor_info); + NNFW_TYPE type = tensor_info.dtype; + uint32_t output_elements = num_elems(&tensor_info); + size_t length = sizeof(T) * output_elements; + + ensure_status(nnfw_train_set_output(session, index, type, buffer.request().ptr, length)); + } + + void train_prepare(); + void train(bool update_weights); + float train_get_loss(uint32_t index); + + void train_export_circle(const py::str &path); + void train_import_checkpoint(const py::str &path); + void train_export_checkpoint(const py::str &path); + + // TODO Add other apis }; + +} // namespace python +} // namespace api +} // namespace onert + +#endif // __ONERT_API_PYTHON_NNFW_API_WRAPPER_H__ diff --git a/runtime/onert/api/python/src/nnfw_api_wrapper.cc b/runtime/onert/api/python/src/nnfw_api_wrapper.cc index 513311fd367..d3f46faf865 100644 --- a/runtime/onert/api/python/src/nnfw_api_wrapper.cc +++ b/runtime/onert/api/python/src/nnfw_api_wrapper.cc @@ -18,6 +18,15 @@ #include +namespace onert +{ +namespace api +{ +namespace python +{ + +namespace py = pybind11; + void ensure_status(NNFW_STATUS status) { switch (status) @@ -243,3 +252,54 @@ tensorinfo NNFW_SESSION::output_tensorinfo(uint32_t index) } return ti; } + +////////////////////////////////////////////// +// Experimental APIs for training +////////////////////////////////////////////// +nnfw_train_info NNFW_SESSION::train_get_traininfo() +{ + nnfw_train_info train_info = nnfw_train_info(); + ensure_status(nnfw_train_get_traininfo(session, &train_info)); + return train_info; +} + +void NNFW_SESSION::train_set_traininfo(const nnfw_train_info *info) +{ + ensure_status(nnfw_train_set_traininfo(session, info)); +} + +void NNFW_SESSION::train_prepare() { ensure_status(nnfw_train_prepare(session)); } + +void NNFW_SESSION::train(bool update_weights) +{ + ensure_status(nnfw_train(session, update_weights)); +} + +float NNFW_SESSION::train_get_loss(uint32_t index) +{ + float loss = 0.f; + ensure_status(nnfw_train_get_loss(session, index, &loss)); + return loss; +} + +void NNFW_SESSION::train_export_circle(const py::str &path) +{ + const char *c_str_path = path.cast().c_str(); + ensure_status(nnfw_train_export_circle(session, c_str_path)); +} + +void NNFW_SESSION::train_import_checkpoint(const py::str &path) +{ + const char *c_str_path = path.cast().c_str(); + ensure_status(nnfw_train_import_checkpoint(session, c_str_path)); +} + +void NNFW_SESSION::train_export_checkpoint(const py::str &path) +{ + const char *c_str_path = path.cast().c_str(); + ensure_status(nnfw_train_export_checkpoint(session, c_str_path)); +} + +} // namespace python +} // namespace api +} // namespace onert diff --git a/runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc b/runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc index 9737f0b58ce..bdea8270c16 100644 --- a/runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc +++ b/runtime/onert/api/python/src/nnfw_api_wrapper_pybind.cc @@ -18,6 +18,8 @@ namespace py = pybind11; +using namespace onert::api::python; + PYBIND11_MODULE(libnnfw_api_pybind, m) { m.doc() = "nnfw python plugin";