From 9ad9838b988383d4c8a5a1e74e4e21db1348a179 Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Thu, 18 Jul 2024 16:16:55 +0800 Subject: [PATCH] [feature] export partial model loading APIs --- include/ppl/nn/models/onnx/runtime_builder.h | 12 +++++++ .../nn/models/onnx/runtime_builder_impl.cc | 31 +++++++++++++++++++ src/ppl/nn/models/onnx/runtime_builder_impl.h | 5 +++ 3 files changed, 48 insertions(+) diff --git a/include/ppl/nn/models/onnx/runtime_builder.h b/include/ppl/nn/models/onnx/runtime_builder.h index 9943d6829..9ec30efc9 100644 --- a/include/ppl/nn/models/onnx/runtime_builder.h +++ b/include/ppl/nn/models/onnx/runtime_builder.h @@ -46,6 +46,18 @@ class PPLNN_PUBLIC RuntimeBuilder { virtual ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len, const char* model_file_dir = nullptr) = 0; + /** @brief load partial model from a file */ + virtual ppl::common::RetCode LoadModel(const char* model_file, const char** inputs, uint32_t nr_input, + const char** outputs, uint32_t nr_output) = 0; + + /** + @brief load partial model from a buffer + @param model_file_dir used to parse external data. can be nullptr if no external data. + */ + virtual ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len, const char** inputs, + uint32_t nr_input, const char** outputs, uint32_t nr_output, + const char* model_file_dir = nullptr) = 0; + /** @brief set resources for preprocessing and creating `Runtime`. MUST be called before `Preprocess()` and `CreateRuntime()`. diff --git a/src/ppl/nn/models/onnx/runtime_builder_impl.cc b/src/ppl/nn/models/onnx/runtime_builder_impl.cc index 8608e2920..16a5685ff 100644 --- a/src/ppl/nn/models/onnx/runtime_builder_impl.cc +++ b/src/ppl/nn/models/onnx/runtime_builder_impl.cc @@ -71,6 +71,37 @@ RetCode RuntimeBuilderImpl::LoadModel(const char* model_file) { return LoadModel(fm.GetData(), fm.GetSize(), parent_dir.c_str()); } +RetCode RuntimeBuilderImpl::LoadModel(const char* model_buf, uint64_t buf_len, const char** inputs, uint32_t nr_input, + const char** outputs, uint32_t nr_output, const char* model_file_dir) { + auto status = ModelParser::Parse(model_buf, buf_len, model_file_dir, inputs, nr_input, outputs, nr_output, &model_); + if (status != RC_SUCCESS) { + LOG(ERROR) << "parse graph failed: " << GetRetCodeStr(status); + return status; + } + + return RC_SUCCESS; +} + +RetCode RuntimeBuilderImpl::LoadModel(const char* model_file, const char** inputs, uint32_t nr_input, + const char** outputs, uint32_t nr_output) { + Mmap fm; + auto status = fm.Init(model_file, Mmap::READ); + if (status != RC_SUCCESS) { + LOG(ERROR) << "mapping file [" << model_file << "] faild."; + return status; + } + + string parent_dir; + auto pos = string(model_file).find_last_of("/\\"); + if (pos == string::npos) { + parent_dir = "."; + } else { + parent_dir.assign(model_file, pos); + } + + return LoadModel(fm.GetData(), fm.GetSize(), inputs, nr_input, outputs, nr_output, parent_dir.c_str()); +} + RetCode RuntimeBuilderImpl::SetResources(const Resources& resource) { resource_.engines.resize(resource.engine_num); for (uint32_t i = 0; i < resource.engine_num; ++i) { diff --git a/src/ppl/nn/models/onnx/runtime_builder_impl.h b/src/ppl/nn/models/onnx/runtime_builder_impl.h index 51e7d4fc1..d53540180 100644 --- a/src/ppl/nn/models/onnx/runtime_builder_impl.h +++ b/src/ppl/nn/models/onnx/runtime_builder_impl.h @@ -34,6 +34,11 @@ class RuntimeBuilderImpl final : public RuntimeBuilder { ppl::common::RetCode LoadModel(const char* model_file) override; ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len, const char* model_file_dir = nullptr) override; + ppl::common::RetCode LoadModel(const char* model_file, const char** inputs, uint32_t nr_input, const char** outputs, + uint32_t nr_output) override; + ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len, const char** inputs, uint32_t nr_input, + const char** outputs, uint32_t nr_output, + const char* model_file_dir = nullptr) override; ppl::common::RetCode SetResources(const Resources&) override; ppl::common::RetCode ReserveTensor(const char* tensor_name) override; ppl::common::RetCode Preprocess() override;