From e000a481dc04d6cfa78df2bbde585fd71521ed11 Mon Sep 17 00:00:00 2001 From: Sameer Goel Date: Fri, 6 May 2022 21:13:43 -0600 Subject: [PATCH] Add support for sharing an ORT session For every instance in a model instance group a new ORT session is created. This code adds support to share a session per instance group. This support can be enabled by defining 'share_session' to true in triton model config "parameters". Example: parameters [ ..... { key: "share_session" value: {string_value: "true"} } ] This is a global parameter and cannot be defined per instance group. The user should determine if the parameter makes sense for their setup. --- src/onnxruntime.cc | 152 +++++++++++++++++++++++++++++---------- src/onnxruntime_utils.cc | 17 +++++ src/onnxruntime_utils.h | 4 ++ 3 files changed, 137 insertions(+), 36 deletions(-) diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index 790c640..fa405e0 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -25,7 +25,6 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include - #include #include @@ -81,10 +80,10 @@ class ModelState : public BackendModel { // onnx file, return in 'session' and 'allocator' the ORT session // and allocator. TRITONSERVER_Error* LoadModel( - const std::string& artifact_name, + const std::string& artifact_name, const std::string& instance_name, const TRITONSERVER_InstanceGroupKind instance_group_kind, const int32_t instance_group_device_id, std::string* model_path, - OrtSession** session, OrtAllocator** default_allocator, + std::shared_ptr& session, OrtAllocator** default_allocator, cudaStream_t stream); const std::map>& ModelOutputs() @@ -101,6 +100,11 @@ class ModelState : public BackendModel { TRITONSERVER_Error* AutoCompleteIO( const char* key, const OnnxTensorInfoMap& io_infos); + TRITONSERVER_Error* GetSessionForGroup( + const std::string& group_name, std::shared_ptr& session); + TRITONSERVER_Error* SetSessionForGroup( + const std::string& group_name, const std::shared_ptr& session); + // Session options used when creating a ORT session. std::unique_ptr session_options_; @@ -110,6 +114,17 @@ class ModelState : public BackendModel { // is specified both in the output section and state section, it indicates // that the backend must return the output state to the client too. std::map> model_outputs_; + + // Indicate if an onnxrt session should be shared or not. This is a model + // global and applies to all instances. So, storing it in the model state + bool share_session_; + + // maintain a map of group id to onnx_rt session. This is only useful if + // share_session is set to true in parameters. share_session is a global model + // config and the user should be careful when setting this. There is no way to + // set this per instance group. + std::unordered_map> + groupInstanceSessionMap_; }; TRITONSERVER_Error* @@ -188,7 +203,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) } ModelState::ModelState(TRITONBACKEND_Model* triton_model) - : BackendModel(triton_model) + : BackendModel(triton_model), share_session_(false) { // Create session options that will be cloned and used for each // instance when creating that instance's session. @@ -338,20 +353,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) } } } - - // FIXME. Is it possible to share a single OrtSession across - // multiple instances? If so then should move loading and validation - // of the session to here instead of creating a session for each - // instance in ModelStateInstance::Create(). + + // This setting will apply across multiple instance groups. + // If this value is set all instances within an instance group will share + // the ort session + { + bool share_session; + triton::common::TritonJson::Value params; + if (ModelConfig().Find("parameters", ¶ms)) { + THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter( + params, "share_session", &share_session, false)); + } + share_session_ = share_session; + } } TRITONSERVER_Error* ModelState::LoadModel( - const std::string& artifact_name, + const std::string& artifact_name, const std::string& instance_name, const TRITONSERVER_InstanceGroupKind instance_group_kind, const int32_t instance_group_device_id, std::string* model_path, - OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream) + std::shared_ptr& session, OrtAllocator** default_allocator, + cudaStream_t stream) { + // Get the group name for the instance + std::string instance_group_name(GetInstanceGroupName(Name(), instance_name)); // Find the ONNX file that describes the model itself. If the model // configuration doesn't have an explicit model file specified then // use the default name ("model.onnx"). @@ -363,6 +389,10 @@ ModelState::LoadModel( *model_path = JoinPath( {RepositoryPath(), std::to_string(Version()), cc_model_filename}); + // get default cpu allocator + RETURN_IF_ORT_ERROR( + ort_api->GetAllocatorWithDefaultOptions(default_allocator)); + // If the model path is a directory then the actual model is // /model.onnx. { @@ -373,6 +403,20 @@ ModelState::LoadModel( } } + // Check is we are sharing the session. If so get the session pointer and + // return + if (share_session_) { + if (GetSessionForGroup(instance_group_name, session) == nullptr) { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Reusing session for group: ") + instance_group_name) + .c_str()); + // Return the session + return nullptr; + } + // In case of error carry on with the code + } + { bool exists; RETURN_IF_ERROR(FileExists(*model_path, &exists)); @@ -636,12 +680,22 @@ ModelState::LoadModel( glock.lock(); } - RETURN_IF_ERROR(OnnxLoader::LoadSession( - true /* is_path */, *model_path, soptions, session)); + { + // This will be allocated by OnnxRT here but will be freed when the last + // instance of shared_ptr is released + OrtSession* session_ptr; + RETURN_IF_ERROR(OnnxLoader::LoadSession( + true /* is_path */, *model_path, soptions, &session_ptr)); - // get default cpu allocator - RETURN_IF_ORT_ERROR( - ort_api->GetAllocatorWithDefaultOptions(default_allocator)); + session = std::shared_ptr(session_ptr, SessionDeleter()); + + if (share_session_) { + // The session was created fine this is not a critical error + LOG_IF_ERROR( + SetSessionForGroup(instance_group_name, session), + "Failed to map ort session to the group for sharing"); + } + } return nullptr; // success } @@ -685,7 +739,7 @@ ModelState::AutoCompleteConfig() // Must cleanup 'session'. 'allocator' is default allocator which // is managed by ONNX Runtime so don't need to free/release - std::unique_ptr session; + std::shared_ptr session; OrtAllocator* default_allocator; std::string model_path; { @@ -714,12 +768,9 @@ ModelState::AutoCompleteConfig() } } #endif // TRITON_ENABLE_GPU - - OrtSession* sptr = nullptr; RETURN_IF_ERROR(LoadModel( - artifact_name, kind, 0, &model_path, &sptr, &default_allocator, - nullptr)); - session.reset(sptr); + artifact_name, "", kind, 0, &model_path, + session, &default_allocator, nullptr)); } OnnxTensorInfoMap input_tensor_infos; RETURN_IF_ERROR( @@ -881,6 +932,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos) return nullptr; // success } +TRITONSERVER_Error* +ModelState::GetSessionForGroup( + const std::string& group_name, std::shared_ptr& session) +{ + RETURN_ERROR_IF_TRUE( + group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG, + std::string("Invalid group name")); + { + std::unordered_map>::iterator + sessionEntry; + sessionEntry = groupInstanceSessionMap_.find(group_name); + RETURN_ERROR_IF_TRUE( + (sessionEntry == groupInstanceSessionMap_.end()), + TRITONSERVER_ERROR_NOT_FOUND, std::string("No such group")); + + session = sessionEntry->second; + } + return nullptr; +} + +TRITONSERVER_Error* +ModelState::SetSessionForGroup( + const std::string& group_name, const std::shared_ptr& session) +{ + RETURN_ERROR_IF_TRUE( + group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG, + std::string("Invalid group name")); + + groupInstanceSessionMap_[group_name] = session; + return nullptr; +} + // // ModelInstanceState // @@ -967,7 +1050,7 @@ class ModelInstanceState : public BackendModelInstance { // Onnx Runtime variables that are used across runs on this // instance. - OrtSession* session_; + std::shared_ptr session_; OrtAllocator* default_allocator_; OrtMemoryInfo* cuda_allocator_info_; const OrtMemoryInfo* cpu_allocator_info_; @@ -1013,7 +1096,7 @@ ModelInstanceState::ModelInstanceState( io_binding_(nullptr), output_buffer_(nullptr) { THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel( - ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_, + ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_, &default_allocator_, CudaStream())); if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { @@ -1026,7 +1109,7 @@ ModelInstanceState::ModelInstanceState( ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_)); THROW_IF_BACKEND_INSTANCE_ORT_ERROR( - ort_api->CreateIoBinding(session_, &io_binding_)); + ort_api->CreateIoBinding(session_.get(), &io_binding_)); THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_)); @@ -1114,9 +1197,6 @@ ModelInstanceState::~ModelInstanceState() ort_api->ReleaseRunOptions(runOptions_); ort_api->ReleaseIoBinding(io_binding_); ort_api->ReleaseMemoryInfo(cuda_allocator_info_); - if (session_ != nullptr) { - OnnxLoader::UnloadSession(session_); - } // 'default_allocator_' is default allocator which is managed by ONNX // Runtime } @@ -1176,7 +1256,7 @@ ModelInstanceState::ValidateBooleanSequenceControl( if (*have_control) { OnnxTensorInfoMap input_tensor_infos; RETURN_IF_ERROR( - InputInfos(session_, default_allocator_, input_tensor_infos)); + InputInfos(session_.get(), default_allocator_, input_tensor_infos)); const auto& iit = input_tensor_infos.find(tensor_name); if (iit == input_tensor_infos.end()) { return TRITONSERVER_ErrorNew( @@ -1233,7 +1313,7 @@ ModelInstanceState::ValidateTypedSequenceControl( if (*have_control) { OnnxTensorInfoMap input_tensor_infos; RETURN_IF_ERROR( - InputInfos(session_, default_allocator_, input_tensor_infos)); + InputInfos(session_.get(), default_allocator_, input_tensor_infos)); const auto& iit = input_tensor_infos.find(tensor_name); if (iit == input_tensor_infos.end()) { return TRITONSERVER_ErrorNew( @@ -1280,10 +1360,11 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs(const size_t expected_input_cnt) { std::set input_tensor_names; - RETURN_IF_ERROR(InputNames(session_, input_tensor_names)); + RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names)); OnnxTensorInfoMap input_tensor_infos; - RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos)); + RETURN_IF_ERROR( + InputInfos(session_.get(), default_allocator_, input_tensor_infos)); if (input_tensor_infos.size() != expected_input_cnt) { return TRITONSERVER_ErrorNew( @@ -1368,10 +1449,10 @@ TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() { std::set output_tensor_names; - RETURN_IF_ERROR(OutputNames(session_, output_tensor_names)); + RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names)); RETURN_IF_ERROR( - OutputInfos(session_, default_allocator_, output_tensor_infos_)); + OutputInfos(session_.get(), default_allocator_, output_tensor_infos_)); triton::common::TritonJson::Value ios; RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios)); @@ -1765,7 +1846,7 @@ ModelInstanceState::OrtRun( const uint32_t response_count) { RETURN_IF_ORT_ERROR( - ort_api->RunWithBinding(session_, runOptions_, io_binding_)); + ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_)); return nullptr; } @@ -2267,7 +2348,6 @@ ModelInstanceState::ReadOutputTensors( } } - } else { char* output_buffer = nullptr; RETURN_IF_ORT_ERROR( diff --git a/src/onnxruntime_utils.cc b/src/onnxruntime_utils.cc index e46532b..96528cb 100644 --- a/src/onnxruntime_utils.cc +++ b/src/onnxruntime_utils.cc @@ -493,5 +493,22 @@ CompareDimsSupported( return nullptr; // success } +std::string +GetInstanceGroupName( + const std::string& model_name, const std::string& instance_name) +{ + std::regex groupNameRegex('(' + model_name + '_' + "[0-9]" + ')'); + std::smatch groupName; + + if (model_name.empty() || instance_name.empty()) { + return ""; + } + + if (std::regex_search(instance_name, groupName, groupNameRegex)) { + return groupName.str(1); + } + + return ""; +} }}} // namespace triton::backend::onnxruntime diff --git a/src/onnxruntime_utils.h b/src/onnxruntime_utils.h index f42bf33..cc0d481 100644 --- a/src/onnxruntime_utils.h +++ b/src/onnxruntime_utils.h @@ -27,6 +27,7 @@ #pragma once #include +#include #include #include #include @@ -149,4 +150,7 @@ TRITONSERVER_Error* CompareDimsSupported( const std::vector& model_shape, const std::vector& dims, const int max_batch_size, const bool compare_exact); +std::string GetInstanceGroupName( + const std::string& model_name, const std::string& instance_name); + }}} // namespace triton::backend::onnxruntime