Skip to content

Commit

Permalink
Keep the model metadata on the generated EP context model (use bridge…
Browse files Browse the repository at this point in the history
… api) (microsoft#22860)

In addition to the
[PR](microsoft#22825), which directly
uses the internal graph API, this PR updates the provider bridge API so that the
TRT EP and OpenVINO EP also keep the model metadata on the generated EP context model.
  • Loading branch information
chilo-ms authored Dec 2, 2024
1 parent 1128882 commit 49a80df
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer,
const bool& ep_context_embed_mode,
std::string&& model_blob_str,
const std::string& openvino_sdk_version) const {
auto model_build = graph_viewer.CreateModel(logger);
auto& metadata = graph_viewer.GetGraph().GetModel().MetaData();
auto model_build = graph_viewer.CreateModel(logger, metadata);
auto& graph_build = model_build->MainGraph();

// Get graph inputs and outputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ struct ProviderHost {

// GraphViewer
virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;
virtual std::unique_ptr<Model> GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger) = 0;
// Creates a Model for the given GraphViewer. The ModelMetaData argument is supplied by the
// caller; no default is declared at this level, so every call through ProviderHost must pass one.
virtual std::unique_ptr<Model> GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger, const ModelMetaData&) = 0;

virtual const std::string& GraphViewer__Name(const GraphViewer* p) noexcept = 0;
virtual const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1022,11 +1022,13 @@ struct Graph final {
PROVIDER_DISALLOW_ALL(Graph)
};

// Model metadata represented as string key -> string value pairs.
using ModelMetaData = std::unordered_map<std::string, std::string>;

class GraphViewer final {
public:
static void operator delete(void* p) { g_host->GraphViewer__operator_delete(reinterpret_cast<GraphViewer*>(p)); }

std::unique_ptr<Model> CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); }
// Creates a Model for this viewer by forwarding to g_host->GraphViewer__CreateModel.
// `metadata` is passed through unchanged; when the caller omits it, an empty ModelMetaData is sent.
std::unique_ptr<Model> CreateModel(const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) const { return g_host->GraphViewer__CreateModel(this, logger, metadata); }

const std::string& Name() const noexcept { return g_host->GraphViewer__Name(this); }
const std::filesystem::path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); }
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,

// Serialize modelproto to string
auto new_graph_viewer = graph_build.CreateGraphViewer();
auto model = new_graph_viewer->CreateModel(*logger);
auto& metadata = graph_viewer.GetGraph().GetModel().MetaData();
auto model = new_graph_viewer->CreateModel(*logger, metadata);
auto model_proto = model->ToProto();
new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1156,8 +1156,8 @@ struct ProviderHostImpl : ProviderHost {

// GraphViewer (wrapped)
void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }
std::unique_ptr<Model> GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger) override {
return std::make_unique<Model>(graph_viewer->Name(), true, ModelMetaData(), PathString(),
std::unique_ptr<Model> GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) override {
return std::make_unique<Model>(graph_viewer->Name(), true, metadata, PathString(),
#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(),
#else
Expand Down

0 comments on commit 49a80df

Please sign in to comment.