Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] pmx: constants can be saved in one file #958

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/ppl/nn/models/pmx/load_model_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ namespace ppl { namespace nn { namespace pmx {
/** Options that control where constant data is read from when a pmx model is loaded. */
struct PPLNN_PUBLIC LoadModelOptions final {
/** load constant from external data files in `external_data_dir` if `EXTERNAL_MULTI_FILES` is enabled. */
const char* external_data_dir = nullptr;

/**
   in-memory buffer that constants flagged `EXTERNAL_ONE_FILE` are read from.
   must not be combined with a non-empty `external_data_dir`; `LoadModel()` reports RC_INVALID_VALUE in that case.
   NOTE(review): the buffer is only referenced, not copied, by the code visible in this change; presumably the
   caller has to keep it alive while constants are being loaded -- confirm.
   NOTE(review): an unnamed struct member is a compiler extension in C++ (only anonymous unions are standard).
*/
struct {
/** start address of the external constant data; each constant is read at its recorded data offset from here. */
const char* external_buffer = nullptr;
/** size of `external_buffer` (presumably in bytes -- confirm against callers). */
uint64_t external_buffer_size = 0;
};
};

}}} // namespace ppl::nn::pmx
Expand Down
3 changes: 3 additions & 0 deletions include/ppl/nn/models/pmx/save_model_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ namespace ppl { namespace nn { namespace pmx {
/** Options that control where constant data is written when a model is serialized to pmx. */
struct PPLNN_PUBLIC SaveModelOptions final {
/** save constants to external files if not null. one file per constant. */
const char* external_data_dir = nullptr;

/**
   save constants to one external file if not null. an empty string is treated the same as null.
   must not be set together with `external_data_dir`: serialization fails with RC_INVALID_VALUE
   when both are non-empty.
*/
const char* external_data_file = nullptr;
};

}}} // namespace ppl::nn::pmx
Expand Down
13 changes: 8 additions & 5 deletions src/ppl/nn/models/pmx/generated/pmx_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,30 @@ struct ModelBuilder;

enum ConstantFlag : uint32_t {
ConstantFlag_EXTERNAL_MULTI_FILES = 1,
ConstantFlag_EXTERNAL_ONE_FILE = 2,
ConstantFlag_MIN = ConstantFlag_EXTERNAL_MULTI_FILES,
ConstantFlag_MAX = ConstantFlag_EXTERNAL_MULTI_FILES
ConstantFlag_MAX = ConstantFlag_EXTERNAL_ONE_FILE
};

inline const ConstantFlag (&EnumValuesConstantFlag())[1] {
inline const ConstantFlag (&EnumValuesConstantFlag())[2] {
static const ConstantFlag values[] = {
ConstantFlag_EXTERNAL_MULTI_FILES
ConstantFlag_EXTERNAL_MULTI_FILES,
ConstantFlag_EXTERNAL_ONE_FILE
};
return values;
}

inline const char * const *EnumNamesConstantFlag() {
static const char * const names[2] = {
static const char * const names[3] = {
"EXTERNAL_MULTI_FILES",
"EXTERNAL_ONE_FILE",
nullptr
};
return names;
}

inline const char *EnumNameConstantFlag(ConstantFlag e) {
if (flatbuffers::IsOutRange(e, ConstantFlag_EXTERNAL_MULTI_FILES, ConstantFlag_EXTERNAL_MULTI_FILES)) return "";
if (flatbuffers::IsOutRange(e, ConstantFlag_EXTERNAL_MULTI_FILES, ConstantFlag_EXTERNAL_ONE_FILE)) return "";
const size_t index = static_cast<size_t>(e) - static_cast<size_t>(ConstantFlag_EXTERNAL_MULTI_FILES);
return EnumNamesConstantFlag()[index];
}
Expand Down
27 changes: 18 additions & 9 deletions src/ppl/nn/models/pmx/graph_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ class PmxConstantVisitor final : public ConstantVisitor {
public:
PmxConstantVisitor(const ir::GraphTopo* topo, const uint8_t* shared_data, const RuntimeGraphInfo* info,
const flatbuffers::Vector<flatbuffers::Offset<ppl::nn::pmx::Constant>>* fb_constants,
const string& external_data_dir)
const LoadModelOptions& opt)
: topo_(topo)
, shared_data_(shared_data)
, info_(info)
, fb_constants_(fb_constants)
, external_data_dir_(external_data_dir) {}
, opt_(opt) {}

RetCode ForEach(const function<RetCode(edgeid_t, uint64_t)>& f) const override {
for (auto fb_constant = fb_constants_->begin(); fb_constant != fb_constants_->end(); ++fb_constant) {
Expand All @@ -164,7 +164,7 @@ class PmxConstantVisitor final : public ConstantVisitor {
}
} else {
uint64_t fsize = 0;
const string path = external_data_dir_ + "/" +
const string path = opt_.external_data_dir + string("/") +
string((const char*)shared_data_ + fb_constant->data_offset(), fb_constant->data_bytes());
auto rc = ppl::nn::utils::GetFileSize(path.c_str(), &fsize);
if (rc != RC_SUCCESS) {
Expand Down Expand Up @@ -200,10 +200,10 @@ class PmxConstantVisitor final : public ConstantVisitor {
if (flags & ConstantFlag_EXTERNAL_MULTI_FILES) {
string path;
if (fb_constant->data_offset() == UINT64_MAX) {
path = external_data_dir_ + "/" +
path = opt_.external_data_dir + string("/") +
utils::GenOutputFileName(topo_->GetEdge(fb_constant->edge_id())->GetName());
} else {
path = external_data_dir_ + "/" +
path = opt_.external_data_dir + string("/") +
string((const char*)shared_data_ + fb_constant->data_offset(), fb_constant->data_bytes());
}

Expand All @@ -219,8 +219,14 @@ class PmxConstantVisitor final : public ConstantVisitor {
return rc;
}
} else {
// Constants flagged EXTERNAL_ONE_FILE are stored in the caller-supplied external buffer;
// every other in-memory constant is read from the shared data embedded in the model.
const uint8_t* shared_data =
    (flags & ConstantFlag_EXTERNAL_ONE_FILE) ? (const uint8_t*)opt_.external_buffer : shared_data_;
// Validate the buffer that was actually selected above. Checking the member `shared_data_`
// here would let a model that needs the external buffer pass when none was supplied, and the
// callback below would then read through a null pointer.
if (!shared_data) {
    LOG(ERROR) << "constant data buffer is null.";
    return RC_INVALID_VALUE;
}
auto rc =
f(edge, shared_data_ + fb_constant->data_offset(), fb_constant->data_bytes(), shape_ref->second);
f(edge, shared_data + fb_constant->data_offset(), fb_constant->data_bytes(), shape_ref->second);
if (rc != RC_SUCCESS) {
LOG(ERROR) << "exec callback for constant[" << edge->GetName() << "] failed: " << GetRetCodeStr(rc);
return rc;
Expand All @@ -235,7 +241,7 @@ class PmxConstantVisitor final : public ConstantVisitor {
const uint8_t* shared_data_;
const RuntimeGraphInfo* info_;
const flatbuffers::Vector<flatbuffers::Offset<ppl::nn::pmx::Constant>>* fb_constants_;
const string& external_data_dir_;
const LoadModelOptions& opt_;
};

static RetCode ParseGraphDataPartitions(const GraphData* fb_data, const ir::GraphTopo* topo,
Expand All @@ -247,8 +253,11 @@ static RetCode ParseGraphDataPartitions(const GraphData* fb_data, const ir::Grap
DeserializationContext deser_ctx;
deser_ctx.shapes = &info->shapes;

const string external_data_dir =
LoadModelOptions internal_opt;
internal_opt.external_data_dir =
(opt.external_data_dir && opt.external_data_dir[0] != '\0') ? opt.external_data_dir : ".";
internal_opt.external_buffer = opt.external_buffer;
internal_opt.external_buffer_size = opt.external_buffer_size;

for (auto fb_partition = fb_partitions->begin(); fb_partition != fb_partitions->end(); ++fb_partition) {
auto engine = seq2engine[fb_partition->engine_id()];
Expand All @@ -257,7 +266,7 @@ static RetCode ParseGraphDataPartitions(const GraphData* fb_data, const ir::Grap
partition.engine = engine;

PmxConstantVisitor visitor(topo, fb_data->shared_data()->data(), info, fb_partition->constants(),
external_data_dir);
internal_opt);
auto status = engine->LoadConstants(visitor, &partition.constants);
if (status != RC_SUCCESS) {
LOG(ERROR) << "LoadConstants of engine[" << engine->GetName() << "] failed: " << GetRetCodeStr(status);
Expand Down
21 changes: 21 additions & 0 deletions src/ppl/nn/models/pmx/runtime_builder_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ RetCode RuntimeBuilderImpl::LoadModel(const char* model_buf, uint64_t buf_len, c
const LoadModelOptions& opt) {
RetCode status;

if (opt.external_data_dir && opt.external_data_dir[0] != '\0') {
if (opt.external_buffer) {
LOG(ERROR) << "only one of `external_data_dir` and `external_buffer` can be set.";
return RC_INVALID_VALUE;
}
}

auto fb_model = pmx::GetModel(model_buf);
if (!fb_model) {
LOG(ERROR) << "parse ppl model failed.";
Expand Down Expand Up @@ -110,6 +117,13 @@ RetCode RuntimeBuilderImpl::LoadModel(const char* model_buf, uint64_t buf_len, c
}

RetCode RuntimeBuilderImpl::LoadModel(const char* model_file, const Resources& resources, const LoadModelOptions& opt) {
if (opt.external_data_dir && opt.external_data_dir[0] != '\0') {
if (opt.external_buffer && opt.external_buffer_size > 0) {
LOG(ERROR) << "only one of `external_data_dir` and `external_buffer` can be set.";
return RC_INVALID_VALUE;
}
}

Mmap fm;
auto status = fm.Init(model_file, Mmap::READ);
if (status != RC_SUCCESS) {
Expand All @@ -123,6 +137,8 @@ RetCode RuntimeBuilderImpl::LoadModel(const char* model_file, const Resources& r
const LoadModelOptions* opt_ptr;
if (opt.external_data_dir && opt.external_data_dir[0] != '\0') {
opt_ptr = &opt;
} else if (opt.external_buffer) {
opt_ptr = &opt;
} else {
new_opt = opt;
opt_ptr = &new_opt;
Expand Down Expand Up @@ -171,6 +187,11 @@ RetCode RuntimeBuilderImpl::Serialize(const char* fmt, const void* options, util
const pmx::SaveModelOptions* opt;
if (options) {
opt = (const pmx::SaveModelOptions*)options;
if ((opt->external_data_dir && opt->external_data_dir[0] != '\0') &&
(opt->external_data_file && opt->external_data_file[0] != '\0')) {
LOG(ERROR) << "only one of `external_data_dir` and `external_data_file` can be set.";
return RC_INVALID_VALUE;
}
} else {
opt = &default_opt;
}
Expand Down
3 changes: 2 additions & 1 deletion src/ppl/nn/models/pmx/schema/pmx.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ table NodeInfo {
}

// Bit flags describing where the data of a constant is stored.
// The values must stay in sync with the generated `ConstantFlag` enum in pmx_generated.h,
// which defines EXTERNAL_ONE_FILE as 2; use the next free bit (0x2), not 0x10.
enum ConstantFlag : uint32 {
    EXTERNAL_MULTI_FILES = 0x1,
    EXTERNAL_ONE_FILE = 0x2,
}

table Constant {
Expand Down
25 changes: 19 additions & 6 deletions src/ppl/nn/models/pmx/serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ static RetCode CreateFbConstants(FlatBufferBuilder* builder, const SaveModelOpti
return rc;
}
} else {
if (options.external_data_file && options.external_data_file[0] != '\0') {
flags |= ConstantFlag_EXTERNAL_ONE_FILE;
}
offset = FindOrInsertData(data, shared_data, shared_data_items);
}

Expand Down Expand Up @@ -420,13 +423,23 @@ static RetCode CreateFbGraphData(FlatBufferBuilder* builder, const SaveModelOpti
return status;
}

if (shared_data.size() > INT32_MAX) {
LOG(ERROR) << "not supported: size of constant data is larger than 2 GB.";
return RC_UNSUPPORTED;
}
if (options.external_data_file && options.external_data_file[0] != '\0') {
status = SaveData(shared_data, options.external_data_file);
if (status != RC_SUCCESS) {
LOG(ERROR) << "save data to file [" << options.external_data_file << "] failed.";
return status;
}

*fb_data = CreateGraphData(*builder, fb_shapes, fb_partitions, 0);
} else {
if (shared_data.size() > INT32_MAX) {
LOG(ERROR) << "not supported: size of constant data is larger than 2 GB.";
return RC_UNSUPPORTED;
}

auto fb_shared_data = builder->CreateVector(shared_data);
*fb_data = CreateGraphData(*builder, fb_shapes, fb_partitions, fb_shared_data);
auto fb_shared_data = builder->CreateVector(shared_data);
*fb_data = CreateGraphData(*builder, fb_shapes, fb_partitions, fb_shared_data);
}
return RC_SUCCESS;
}

Expand Down
26 changes: 26 additions & 0 deletions tools/pplnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Define_string_opt("--onnx-model", g_flag_onnx_model, "", "onnx model file");
#ifdef PPLNN_ENABLE_PMX_MODEL
Define_string_opt("--pmx-model", g_flag_pmx_model, "", "pmx model file");
Define_string_opt("--pmx-external-data-dir", g_flag_pmx_external_data_dir, "", "dir that contains external data");
Define_string_opt("--pmx-external-data-file", g_flag_pmx_external_data_file, "", "data file that contains all external data");
Define_string_opt("--export-pmx-model", g_flag_export_pmx_model, "", "dump model to <filename> in pmx format");
Define_string_opt("--save-pmx-model", g_flag_save_pmx_model, "", "deprecated. use `--export-pmx-model` instead.");
#endif
Expand Down Expand Up @@ -1244,8 +1245,14 @@ int main(int argc, char* argv[]) {
return -1;
}
pmx::SaveModelOptions opt;
if (!g_flag_pmx_external_data_dir.empty() && !g_flag_pmx_external_data_file.empty()) {
LOG(ERROR) << "only one of `--pmx-external-data-dir` and `pmx-external-data-file` can be set.";
return -1;
}
if (!g_flag_pmx_external_data_dir.empty()) {
opt.external_data_dir = g_flag_pmx_external_data_dir.c_str();
} else if (!g_flag_pmx_external_data_file.empty()) {
opt.external_data_file = g_flag_pmx_external_data_file.c_str();
}
status = builder->Serialize("pmx", &opt, &fds);
if (status != RC_SUCCESS) {
Expand Down Expand Up @@ -1275,9 +1282,22 @@ int main(int argc, char* argv[]) {
resources.engines = engine_ptrs.data();
resources.engine_num = engine_ptrs.size();

Mmap constant_file_buf;
pmx::LoadModelOptions opt;
if (!g_flag_pmx_external_data_dir.empty() && !g_flag_pmx_external_data_file.empty()) {
LOG(ERROR) << "only one of `--pmx-external-data-dir` and `pmx-external-data-file` can be set.";
return -1;
}
if (!g_flag_pmx_external_data_dir.empty()) {
opt.external_data_dir = g_flag_pmx_external_data_dir.c_str();
} else if (!g_flag_pmx_external_data_file.empty()) {
auto rc = constant_file_buf.Init(g_flag_pmx_external_data_file.c_str(), Mmap::READ);
if (rc != RC_SUCCESS) {
LOG(ERROR) << "mmap weight file [" << g_flag_pmx_external_data_file << "] failed.";
return -1;
}
opt.external_buffer = constant_file_buf.GetData();
opt.external_buffer_size = constant_file_buf.GetSize();
}

auto status = builder->LoadModel(g_flag_pmx_model.c_str(), resources, opt);
Expand All @@ -1301,8 +1321,14 @@ int main(int argc, char* argv[]) {
}

pmx::SaveModelOptions opt;
if (!g_flag_pmx_external_data_dir.empty() && !g_flag_pmx_external_data_file.empty()) {
LOG(ERROR) << "only one of `--pmx-external-data-dir` and `pmx-external-data-file` can be set.";
return -1;
}
if (!g_flag_pmx_external_data_dir.empty()) {
opt.external_data_dir = g_flag_pmx_external_data_dir.c_str();
} else if (!g_flag_pmx_external_data_file.empty()) {
opt.external_data_file = g_flag_pmx_external_data_file.c_str();
}

status = builder->Serialize("pmx", &opt, &fds);
Expand Down
Loading