Skip to content

Commit

Permalink
[onert] Use filesystem path for model path (#14404)
Browse files Browse the repository at this point in the history
This commit updates nnfw_api_internal to use filesystem path for model path.

ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <[email protected]>
  • Loading branch information
hseok-oh authored Dec 4, 2024
1 parent 240bda7 commit 0241200
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 48 deletions.
73 changes: 26 additions & 47 deletions runtime/onert/api/nnfw/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@

#include <misc/string_helpers.h>

#include <dirent.h>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
Expand Down Expand Up @@ -244,7 +242,7 @@ nnfw_session::nnfw_session()
: _nnpkg{nullptr}, _coptions{onert::compiler::CompilerOptions::fromGlobalConfig()},
_compiler_artifact{nullptr}, _execution{nullptr}, _kernel_registry{nullptr},
_train_info{nullptr}, _quant_manager{std::make_unique<onert::odc::QuantizeManager>()},
_codegen_manager{std::make_unique<onert::odc::CodegenManager>()}, _model_path{""}
_codegen_manager{std::make_unique<onert::odc::CodegenManager>()}, _model_path{}
{
// DO NOTHING
}
Expand Down Expand Up @@ -350,19 +348,17 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir)
return NNFW_STATUS_ERROR;
}

// TODO : add support for zipped package file load
DIR *dir;
if (!(dir = opendir(package_dir)))
{
std::cerr << "invalid nnpackge directory: " << package_dir << std::endl;
return NNFW_STATUS_ERROR;
}
closedir(dir);

try
{
std::string package_path(package_dir);
std::string manifest_file_name = package_path + "/metadata/MANIFEST";
// TODO : add support for zipped package file load
const std::filesystem::path package_path(package_dir);
if (!std::filesystem::is_directory(package_path))
{
std::cerr << "invalid nnpackage directory: " << package_path << std::endl;
return NNFW_STATUS_ERROR;
}

const auto manifest_file_name = package_path / "metadata/MANIFEST";
std::ifstream mfs(manifest_file_name);

// extract the filename of the first(index 0) model
Expand All @@ -375,10 +371,10 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir)

if (!configs.empty() && !configs[0].empty())
{
auto filepath = package_path + std::string("/metadata/") + configs[0].asString();
const auto filepath = package_path / "metadata" / configs[0].asString();

onert::util::CfgKeyValues keyValues;
if (loadConfigure(filepath, keyValues))
if (loadConfigure(filepath.string(), keyValues))
{
onert::util::setConfigKeyValues(keyValues);
}
Expand All @@ -401,12 +397,12 @@ NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir)

for (uint16_t i = 0; i < num_models; ++i)
{
auto model_file_path = package_path + std::string("/") + models[i].asString();
auto model_type = model_types[i].asString();
auto model = loadModel(model_file_path, model_type);
const auto model_file_path = package_path / models[i].asString();
const auto model_type = model_types[i].asString();
auto model = loadModel(model_file_path.string(), model_type);
if (model == nullptr)
return NNFW_STATUS_ERROR;
_model_path = std::string(model_file_path); // TODO Support multiple models
_model_path = model_file_path; // TODO Support multiple models
model->bindKernelBuilder(_kernel_registry->getBuilder());
_nnpkg->push(onert::ir::ModelIndex{i}, std::move(model));
}
Expand Down Expand Up @@ -985,7 +981,7 @@ NNFW_STATUS nnfw_session::loadModelFile(const std::string &model_file_path,
return NNFW_STATUS_ERROR;

_nnpkg = std::make_shared<onert::ir::NNPkg>(std::move(model));
_model_path = model_file_path;
_model_path = std::filesystem::path(model_file_path);
_compiler_artifact.reset();
_execution.reset();
_train_info = loadTrainingInfo(_nnpkg->primary_model());
Expand Down Expand Up @@ -1658,7 +1654,7 @@ NNFW_STATUS nnfw_session::train_export_circle(const char *path)

try
{
onert::exporter::CircleExporter exporter(_model_path, std::string{path});
onert::exporter::CircleExporter exporter(_model_path.string(), std::string{path});
exporter.updateWeight(_execution);
}
catch (const std::exception &e)
Expand Down Expand Up @@ -1686,7 +1682,7 @@ NNFW_STATUS nnfw_session::train_export_circleplus(const char *path)

try
{
onert::exporter::CircleExporter exporter(_model_path, std::string{path});
onert::exporter::CircleExporter exporter(_model_path.string(), std::string{path});
exporter.updateWeight(_execution);
exporter.updateMetadata(_train_info);
}
Expand Down Expand Up @@ -1854,7 +1850,7 @@ NNFW_STATUS nnfw_session::quantize()
return NNFW_STATUS_INVALID_STATE;
}

auto result = _quant_manager->quantize(_model_path);
auto result = _quant_manager->quantize(_model_path.string());
if (!result)
return NNFW_STATUS_INVALID_STATE;

Expand Down Expand Up @@ -1902,7 +1898,8 @@ NNFW_STATUS nnfw_session::codegen(const char *target, NNFW_CODEGEN_PREF pref)
}

std::string target_str{target};
if (target_str.empty() || target_str.substr(target_str.size() - 4) != "-gen")
if (target_str.empty() || target_str.size() < 5 ||
target_str.substr(target_str.size() - 4) != "-gen")
{
std::cerr << "Error during nnfw_session::codegen : Invalid target" << std::endl;
return NNFW_STATUS_ERROR;
Expand All @@ -1929,38 +1926,20 @@ NNFW_STATUS nnfw_session::codegen(const char *target, NNFW_CODEGEN_PREF pref)
}

assert(_codegen_manager != nullptr);
auto export_model_path = _codegen_manager->exportModelPath();
auto export_model_path = std::filesystem::path(_codegen_manager->exportModelPath());
const auto model_type = target_str.substr(0, target_str.size() - 4);
// If the export_model_path is not set, it generates a compiled model path
// automatically.
if (export_model_path.empty())
{
// model path always has a dot. (valid extension)
auto dotidx = _model_path.rfind('.');
assert(dotidx != std::string::npos);
auto genidx = target_str.rfind("-gen");
assert(genidx != std::string::npos);
// The compiled model path is the same directory of the original model/package with
// target backend extension.
export_model_path = _model_path.substr(0, dotidx + 1) + target_str.substr(0, genidx);
_codegen_manager->exportModelPath(export_model_path);
export_model_path = _model_path.replace_extension(model_type);
_codegen_manager->exportModelPath(export_model_path.string());
}

_codegen_manager->codegen(_model_path, target, codegen_pref);

// Replace model
// TODO Support buffer replace, not file reload
// TODO: Use std::filesystem::path when we can use c++17.
auto dotidx = export_model_path.rfind('.');
if (dotidx == std::string::npos)
{
std::cerr << "Error during nnfw_session::codegen : Invalid compiled model path. Please use a "
"path that includes the extension."
<< std::endl;
return NNFW_STATUS_ERROR;
}

std::string model_type = export_model_path.substr(dotidx + 1); // + 1 to exclude dot

// Replace model
// TODO Support buffer replace, not file reload
return loadModelFile(export_model_path, model_type);
Expand Down
3 changes: 2 additions & 1 deletion runtime/onert/api/nnfw/src/nnfw_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <util/TracingCtx.h>

#include <filesystem>
#include <memory>
#include <string>
#include <thread>
Expand Down Expand Up @@ -211,7 +212,7 @@ struct nnfw_session
// const char *path;
// const uint8 *buf;
// }
std::string _model_path;
std::filesystem::path _model_path;
};

#endif // __API_NNFW_API_INTERNAL_H__

0 comments on commit 0241200

Please sign in to comment.