Skip to content

Commit

Permalink
[feature] export partial model loading APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline committed Jul 18, 2024
1 parent a5b99eb commit 9ad9838
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
12 changes: 12 additions & 0 deletions include/ppl/nn/models/onnx/runtime_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ class PPLNN_PUBLIC RuntimeBuilder {
virtual ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len,
const char* model_file_dir = nullptr) = 0;

/** @brief load partial model from a file */
virtual ppl::common::RetCode LoadModel(const char* model_file, const char** inputs, uint32_t nr_input,
const char** outputs, uint32_t nr_output) = 0;

/**
@brief load partial model from a buffer
@param model_file_dir used to parse external data. can be nullptr if no external data.
*/
virtual ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len, const char** inputs,
uint32_t nr_input, const char** outputs, uint32_t nr_output,
const char* model_file_dir = nullptr) = 0;

/**
@brief set resources for preprocessing and creating `Runtime`.
MUST be called before `Preprocess()` and `CreateRuntime()`.
Expand Down
31 changes: 31 additions & 0 deletions src/ppl/nn/models/onnx/runtime_builder_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,37 @@ RetCode RuntimeBuilderImpl::LoadModel(const char* model_file) {
return LoadModel(fm.GetData(), fm.GetSize(), parent_dir.c_str());
}

RetCode RuntimeBuilderImpl::LoadModel(const char* model_buf, uint64_t buf_len, const char** inputs, uint32_t nr_input,
const char** outputs, uint32_t nr_output, const char* model_file_dir) {
auto status = ModelParser::Parse(model_buf, buf_len, model_file_dir, inputs, nr_input, outputs, nr_output, &model_);
if (status != RC_SUCCESS) {
LOG(ERROR) << "parse graph failed: " << GetRetCodeStr(status);
return status;
}

return RC_SUCCESS;
}

RetCode RuntimeBuilderImpl::LoadModel(const char* model_file, const char** inputs, uint32_t nr_input,
const char** outputs, uint32_t nr_output) {
Mmap fm;
auto status = fm.Init(model_file, Mmap::READ);
if (status != RC_SUCCESS) {
LOG(ERROR) << "mapping file [" << model_file << "] faild.";
return status;
}

string parent_dir;
auto pos = string(model_file).find_last_of("/\\");
if (pos == string::npos) {
parent_dir = ".";
} else {
parent_dir.assign(model_file, pos);
}

return LoadModel(fm.GetData(), fm.GetSize(), inputs, nr_input, outputs, nr_output, parent_dir.c_str());
}

RetCode RuntimeBuilderImpl::SetResources(const Resources& resource) {
resource_.engines.resize(resource.engine_num);
for (uint32_t i = 0; i < resource.engine_num; ++i) {
Expand Down
5 changes: 5 additions & 0 deletions src/ppl/nn/models/onnx/runtime_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class RuntimeBuilderImpl final : public RuntimeBuilder {
ppl::common::RetCode LoadModel(const char* model_file) override;
ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len,
const char* model_file_dir = nullptr) override;
ppl::common::RetCode LoadModel(const char* model_file, const char** inputs, uint32_t nr_input, const char** outputs,
uint32_t nr_output) override;
ppl::common::RetCode LoadModel(const char* model_buf, uint64_t buf_len, const char** inputs, uint32_t nr_input,
const char** outputs, uint32_t nr_output,
const char* model_file_dir = nullptr) override;
ppl::common::RetCode SetResources(const Resources&) override;
ppl::common::RetCode ReserveTensor(const char* tensor_name) override;
ppl::common::RetCode Preprocess() override;
Expand Down

0 comments on commit 9ad9838

Please sign in to comment.