diff --git a/runtime/onert/api/nnfw/include/nnfw.h b/runtime/onert/api/nnfw/include/nnfw.h index f3745255dbd..a3a3d093945 100644 --- a/runtime/onert/api/nnfw/include/nnfw.h +++ b/runtime/onert/api/nnfw/include/nnfw.h @@ -213,16 +213,16 @@ NNFW_STATUS nnfw_create_session(nnfw_session **session); NNFW_STATUS nnfw_close_session(nnfw_session *session); /** - * @brief Load model from nnpackage file or directory + * @brief Load model from model file or nnpackage directory * - * The length of \p package_file_path must not exceed 1024 bytes including zero at the end. + * The length of \p file_path must not exceed 1024 bytes including zero at the end. * - * @param[in] session nnfw_session loading the given nnpackage file/dir - * @param[in] package_file_path Path to the nnpackage file or unzipped directory to be loaded + * @param[in] session nnfw_session loading the given file/dir + * @param[in] file_path Path to the model file or nnpackage directory to be loaded * * @return @c NNFW_STATUS_NO_ERROR if successful */ -NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *package_file_path); +NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *file_path); /** * @brief Apply i-th input's tensor info to resize input tensor diff --git a/runtime/onert/api/nnfw/src/nnfw_api.cc b/runtime/onert/api/nnfw/src/nnfw_api.cc index fe2a21ec50d..714d32c920a 100644 --- a/runtime/onert/api/nnfw/src/nnfw_api.cc +++ b/runtime/onert/api/nnfw/src/nnfw_api.cc @@ -78,15 +78,15 @@ NNFW_STATUS nnfw_close_session(nnfw_session *session) /* * Load model from nnpackage file or directory * - * @param session nnfw_session loading the given nnpackage file/dir - * @param package_file_path path to the nnpackage file or unzipped directory to be loaded + * @param session nnfw_session loading the given file/dir + * @param file_path path to the model file or nnpackage directory to be loaded * * @return NNFW_STATUS_NO_ERROR if successful */ -NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *pacakge_file_path) +NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *file_path) { NNFW_RETURN_ERROR_IF_NULL(session); - return session->load_model_from_nnpackage(pacakge_file_path); + return session->load_model_from_path(file_path); } /* diff --git a/runtime/onert/api/nnfw/src/nnfw_internal.cc b/runtime/onert/api/nnfw/src/nnfw_internal.cc index 7208a77587a..7c20066219b 100644 --- a/runtime/onert/api/nnfw/src/nnfw_internal.cc +++ b/runtime/onert/api/nnfw/src/nnfw_internal.cc @@ -48,7 +48,7 @@ NNFW_STATUS nnfw_load_circle_from_buffer(nnfw_session *session, uint8_t *buffer, NNFW_STATUS nnfw_load_model_from_modelfile(nnfw_session *session, const char *file_path) { NNFW_RETURN_ERROR_IF_NULL(session); - return session->load_model_from_modelfile(file_path); + return session->load_model_from_path(file_path); } NNFW_STATUS nnfw_train_export_circleplus(nnfw_session *session, const char *path) diff --git a/runtime/onert/api/nnfw/src/nnfw_session.cc b/runtime/onert/api/nnfw/src/nnfw_session.cc index b2c9c61f4da..fa506c22af3 100644 --- a/runtime/onert/api/nnfw/src/nnfw_session.cc +++ b/runtime/onert/api/nnfw/src/nnfw_session.cc @@ -301,64 +301,42 @@ NNFW_STATUS nnfw_session::load_circle_from_buffer(uint8_t *buffer, size_t size) return NNFW_STATUS_NO_ERROR; } -NNFW_STATUS nnfw_session::load_model_from_modelfile(const char *model_file_path) +NNFW_STATUS nnfw_session::load_model_from_path(const char *path) { if (!isStateInitialized()) return NNFW_STATUS_INVALID_STATE; - if (!model_file_path) + if (!path) { - std::cerr << "Model file path is null." << std::endl; + std::cerr << "Path is null." << std::endl; return NNFW_STATUS_UNEXPECTED_NULL; } - try - { - std::filesystem::path filename{model_file_path}; - if (!filename.has_extension()) - { - std::cerr << "Invalid model file path. Please use file with extension." << std::endl; - return NNFW_STATUS_ERROR; - } - - std::string model_type = filename.extension().string().substr(1); // + 1 to exclude dot - return loadModelFile(filename, model_type); - } - catch (const std::exception &e) + if (!null_terminating(path, MAX_PATH_LENGTH)) { - std::cerr << "Error during model loading : " << e.what() << std::endl; + std::cerr << "Path is too long" << std::endl; return NNFW_STATUS_ERROR; } -} - -NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir) -{ - if (!isStateInitialized()) - return NNFW_STATUS_INVALID_STATE; - if (!package_dir) + try { - std::cerr << "package_dir is null." << std::endl; - return NNFW_STATUS_UNEXPECTED_NULL; - } + std::filesystem::path filename{path}; + if (filename.has_extension()) + { + std::string model_type = filename.extension().string().substr(1); // + 1 to exclude dot + return loadModelFile(filename, model_type); + } - if (!null_terminating(package_dir, MAX_PATH_LENGTH)) - { - std::cerr << "nnpackage path is too long" << std::endl; - return NNFW_STATUS_ERROR; - } + const auto &package_dir = filename; - try - { // TODO : add support for zipped package file load - const std::filesystem::path package_path(package_dir); - if (!std::filesystem::is_directory(package_path)) + if (!std::filesystem::is_directory(package_dir)) { - std::cerr << "invalid nnpackage directory: " << package_path << std::endl; + std::cerr << "invalid path: " << package_dir << std::endl; return NNFW_STATUS_ERROR; } - const auto manifest_file_name = package_path / "metadata/MANIFEST"; + const auto manifest_file_name = package_dir / "metadata/MANIFEST"; std::ifstream mfs(manifest_file_name); // extract the filename of the first(index 0) model @@ -371,7 +349,7 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir) if (!configs.empty() && !configs[0].empty()) { - const auto filepath = package_path / "metadata" / configs[0].asString(); + const auto filepath = package_dir / "metadata" / configs[0].asString(); onert::util::CfgKeyValues keyValues; if (loadConfigure(filepath.string(), keyValues)) @@ -397,7 +375,7 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir) for (uint16_t i = 0; i < num_models; ++i) { - const auto model_file_path = package_path / models[i].asString(); + const auto model_file_path = package_dir / models[i].asString(); const auto model_type = model_types[i].asString(); auto model = loadModel(model_file_path.string(), model_type); if (model == nullptr) @@ -439,12 +417,15 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir) _nnpkg->verify(); _state = State::MODEL_LOADED; + + return NNFW_STATUS_NO_ERROR; } catch (const std::exception &e) { std::cerr << "Error during model loading : " << e.what() << std::endl; return NNFW_STATUS_ERROR; } + return NNFW_STATUS_NO_ERROR; } diff --git a/runtime/onert/api/nnfw/src/nnfw_session.h b/runtime/onert/api/nnfw/src/nnfw_session.h index 910ec71a6c1..efaf7d8b126 100644 --- a/runtime/onert/api/nnfw/src/nnfw_session.h +++ b/runtime/onert/api/nnfw/src/nnfw_session.h @@ -105,7 +105,7 @@ struct nnfw_session public: ~nnfw_session(); - NNFW_STATUS load_model_from_nnpackage(const char *package_file_path); + NNFW_STATUS load_model_from_path(const char *path); NNFW_STATUS prepare(); NNFW_STATUS run(); @@ -139,7 +139,6 @@ struct nnfw_session NNFW_STATUS set_config(const char *key, const char *value); NNFW_STATUS get_config(const char *key, char *value, size_t value_size); NNFW_STATUS load_circle_from_buffer(uint8_t *buffer, size_t size); - NNFW_STATUS load_model_from_modelfile(const char *file_path); // // Experimental API