From dc8d8c5a33d54c26748a46fccd020a820bf94535 Mon Sep 17 00:00:00 2001 From: ragmani Date: Wed, 8 Jan 2025 06:04:50 +0000 Subject: [PATCH] [onert/python] bind training APIs This commit binds training APIs. - Bind trainings APIs related to session - Bind trainings APIs related to traininfo ONE-DCO-1.0-Signed-off-by: ragmani --- .../python/include/nnfw_session_bindings.h | 3 + .../python/include/nnfw_traininfo_bindings.h | 34 +++++++++ .../src/bindings/nnfw_api_wrapper_pybind.cc | 17 +++++ .../src/bindings/nnfw_session_bindings.cc | 41 +++++++++++ .../src/bindings/nnfw_traininfo_bindings.cc | 73 +++++++++++++++++++ 5 files changed, 168 insertions(+) create mode 100644 runtime/onert/api/python/include/nnfw_traininfo_bindings.h create mode 100644 runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc diff --git a/runtime/onert/api/python/include/nnfw_session_bindings.h b/runtime/onert/api/python/include/nnfw_session_bindings.h index 5052300880c..f5a36c19d64 100644 --- a/runtime/onert/api/python/include/nnfw_session_bindings.h +++ b/runtime/onert/api/python/include/nnfw_session_bindings.h @@ -29,6 +29,9 @@ namespace python // Declare binding common functions void bind_nnfw_session(pybind11::module_ &m); +// Declare binding experimental functions +void bind_experimental_nnfw_session(pybind11::module_ &m); + } // namespace python } // namespace api } // namespace onert diff --git a/runtime/onert/api/python/include/nnfw_traininfo_bindings.h b/runtime/onert/api/python/include/nnfw_traininfo_bindings.h new file mode 100644 index 00000000000..a25e9b81dc5 --- /dev/null +++ b/runtime/onert/api/python/include/nnfw_traininfo_bindings.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__ +#define __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__ + +#include +#include + +namespace py = pybind11; + +// Declare binding train enums +void bind_nnfw_train_enums(py::module_ &m); + +// Declare binding loss info +void bind_nnfw_loss_info(py::module_ &m); + +// Declare binding train info +void bind_nnfw_train_info(py::module_ &m); + +#endif // __ONERT_API_PYTHON_NNFW_TRAININFO_BINDINGS_H__ diff --git a/runtime/onert/api/python/src/bindings/nnfw_api_wrapper_pybind.cc b/runtime/onert/api/python/src/bindings/nnfw_api_wrapper_pybind.cc index f78ee4ec8d3..df31f45f3e3 100644 --- a/runtime/onert/api/python/src/bindings/nnfw_api_wrapper_pybind.cc +++ b/runtime/onert/api/python/src/bindings/nnfw_api_wrapper_pybind.cc @@ -18,6 +18,7 @@ #include "nnfw_session_bindings.h" #include "nnfw_tensorinfo_bindings.h" +#include "nnfw_traininfo_bindings.h" using namespace onert::api::python; @@ -33,6 +34,22 @@ PYBIND11_MODULE(libnnfw_api_pybind, m) auto infer = m.def_submodule("infer", "Inference submodule"); infer.attr("nnfw_session") = m.attr("nnfw_session"); + // Bind experimental `NNFW_SESSION` class + auto experimental = m.def_submodule("experimental", "Experimental submodule"); + experimental.attr("nnfw_session") = m.attr("nnfw_session"); + bind_experimental_nnfw_session(experimental); + // Bind common `tensorinfo` class bind_tensorinfo(m); + + m.doc() = "NNFW Python Bindings for Training"; + + // Bind training enums + bind_nnfw_train_enums(m); + + // Bind training nnfw_loss_info + bind_nnfw_loss_info(m); + + // Bind_train_info + bind_nnfw_train_info(m); } diff --git a/runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc b/runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc index 166ae61d1b5..ca86263cc6f 100644 --- a/runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc +++ b/runtime/onert/api/python/src/bindings/nnfw_session_bindings.cc @@ -225,6 +225,47 @@ void bind_nnfw_session(py::module_ &m) "\ttensorinfo: Tensor info (shape, type, etc)"); } +// Bind the `NNFW_SESSION` class with experimental APIs +void bind_experimental_nnfw_session(py::module_ &m) +{ + // Add experimental APIs for the `NNFW_SESSION` class + m.attr("nnfw_session") + .cast>() + .def("train_get_traininfo", &NNFW_SESSION::train_get_traininfo, + "Retrieve training information for the model.") + .def("train_set_traininfo", &NNFW_SESSION::train_set_traininfo, py::arg("info"), + "Set training information for the model.") + .def("train_prepare", &NNFW_SESSION::train_prepare, "Prepare for training") + .def("train", &NNFW_SESSION::train, py::arg("update_weights") = true, + "Run a training step, optionally updating weights.") + .def("train_get_loss", &NNFW_SESSION::train_get_loss, py::arg("index"), + "Retrieve the training loss for a specific index.") + .def("train_set_input", &NNFW_SESSION::train_set_input, py::arg("index"), + py::arg("buffer"), "Set training input tensor for the given index (float).") + .def("train_set_input", &NNFW_SESSION::train_set_input, py::arg("index"), + py::arg("buffer"), "Set training input tensor for the given index (int).") + .def("train_set_input", &NNFW_SESSION::train_set_input, py::arg("index"), + py::arg("buffer"), "Set training input tensor for the given index (uint8).") + .def("train_set_expected", &NNFW_SESSION::train_set_expected, py::arg("index"), + py::arg("buffer"), "Set expected output tensor for the given index (float).") + .def("train_set_expected", &NNFW_SESSION::train_set_expected, py::arg("index"), + py::arg("buffer"), "Set expected output tensor for the given index (int).") + .def("train_set_expected", &NNFW_SESSION::train_set_expected, py::arg("index"), + py::arg("buffer"), "Set expected output tensor for the given index (uint8).") + .def("train_set_output", &NNFW_SESSION::train_set_output, py::arg("index"), + py::arg("buffer"), "Set output tensor for the given index (float).") + .def("train_set_output", &NNFW_SESSION::train_set_output, py::arg("index"), + py::arg("buffer"), "Set output tensor for the given index (int).") + .def("train_set_output", &NNFW_SESSION::train_set_output, py::arg("index"), + py::arg("buffer"), "Set output tensor for the given index (uint8).") + .def("train_export_circle", &NNFW_SESSION::train_export_circle, py::arg("path"), + "Export the trained model to a circle file.") + .def("train_import_checkpoint", &NNFW_SESSION::train_import_checkpoint, py::arg("path"), + "Import a training checkpoint from a file.") + .def("train_export_checkpoint", &NNFW_SESSION::train_export_checkpoint, py::arg("path"), + "Export the training checkpoint to a file."); +} + } // namespace python } // namespace api } // namespace onert diff --git a/runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc b/runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc new file mode 100644 index 00000000000..9e2411b879e --- /dev/null +++ b/runtime/onert/api/python/src/bindings/nnfw_traininfo_bindings.cc @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnfw_traininfo_bindings.h" + +#include "nnfw_api_wrapper.h" + +namespace py = pybind11; + +using namespace onert::api::python; + +// Declare binding train enums +void bind_nnfw_train_enums(py::module_ &m) +{ + // Bind NNFW_TRAIN_LOSS + py::enum_(m, "loss", py::module_local()) + .value("UNDEFINED", NNFW_TRAIN_LOSS_UNDEFINED) + .value("MEAN_SQUARED_ERROR", NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR) + .value("CATEGORICAL_CROSSENTROPY", NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY); + + // Bind NNFW_TRAIN_LOSS_REDUCTION + py::enum_(m, "loss_reduction", py::module_local()) + .value("UNDEFINED", NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED) + .value("SUM_OVER_BATCH_SIZE", NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE) + .value("SUM", NNFW_TRAIN_LOSS_REDUCTION_SUM); + + // Bind NNFW_TRAIN_OPTIMIZER + py::enum_(m, "optimizer", py::module_local()) + .value("UNDEFINED", NNFW_TRAIN_OPTIMIZER_UNDEFINED) + .value("SGD", NNFW_TRAIN_OPTIMIZER_SGD) + .value("ADAM", NNFW_TRAIN_OPTIMIZER_ADAM); + + // Bind NNFW_TRAIN_NUM_OF_TRAINABLE_OPS_SPECIAL_VALUES + py::enum_(m, "trainable_ops", py::module_local()) + .value("INCORRECT_STATE", NNFW_TRAIN_TRAINABLE_INCORRECT_STATE) + .value("ALL", NNFW_TRAIN_TRAINABLE_ALL) + .value("NONE", NNFW_TRAIN_TRAINABLE_NONE); +} + +// Declare binding loss info +void bind_nnfw_loss_info(py::module_ &m) +{ + py::class_(m, "lossinfo", py::module_local()) + .def(py::init<>()) // Default constructor + .def_readwrite("loss", &nnfw_loss_info::loss, "Loss type") + .def_readwrite("reduction_type", &nnfw_loss_info::reduction_type, "Reduction type"); +} + +// Declare binding train info +void bind_nnfw_train_info(py::module_ &m) +{ + py::class_(m, "traininfo", py::module_local()) + .def(py::init<>()) // Default constructor + .def_readwrite("learning_rate", &nnfw_train_info::learning_rate, "Learning rate") + .def_readwrite("batch_size", &nnfw_train_info::batch_size, "Batch size") + .def_readwrite("loss_info", &nnfw_train_info::loss_info, "Loss information") + .def_readwrite("opt", &nnfw_train_info::opt, "Optimizer type") + .def_readwrite("num_of_trainable_ops", &nnfw_train_info::num_of_trainable_ops, + "Number of trainable operations"); +}