Add support for sharing an ORT session
Currently a new ORT session is created for every instance in a model
instance group. This change adds support for sharing a single session
across all instances of an instance group.

Sharing is enabled by setting 'share_session' to "true" in the
"parameters" section of the Triton model config. Example:
parameters [
.....
  {
    key: "share_session"
    value: {string_value: "true"}
  }
]

This is a global parameter and cannot be set per instance group, so
the user should determine whether session sharing makes sense for
their setup.
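
For reference, a hypothetical config.pbtxt sketch showing where the
parameter sits relative to the instance groups it affects (the model
name, batch size, and instance groups below are made up for
illustration; only the 'share_session' parameter is taken from this
change):

# illustrative example only; just the 'share_session' parameter is relevant
name: "densenet_onnx"
backend: "onnxruntime"
max_batch_size: 8
instance_group [
  { count: 2, kind: KIND_GPU, gpus: [ 0 ] },
  { count: 2, kind: KIND_GPU, gpus: [ 1 ] }
]
parameters [
  {
    key: "share_session"
    value: {string_value: "true"}
  }
]

With share_session enabled, the two instances in each group above would
reuse a single ORT session created for that group (one session per
group) instead of creating one session apiece.
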
samgoel01 authored and quic-suppugun committed Aug 23, 2022
1 parent abc3ee7 commit e000a48
Showing 3 changed files with 137 additions and 36 deletions.
152 changes: 116 additions & 36 deletions src/onnxruntime.cc
@@ -25,7 +25,6 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <stdint.h>

#include <mutex>
#include <vector>

@@ -81,10 +80,10 @@ class ModelState : public BackendModel {
// onnx file, return in 'session' and 'allocator' the ORT session
// and allocator.
TRITONSERVER_Error* LoadModel(
const std::string& artifact_name,
const std::string& artifact_name, const std::string& instance_name,
const TRITONSERVER_InstanceGroupKind instance_group_kind,
const int32_t instance_group_device_id, std::string* model_path,
OrtSession** session, OrtAllocator** default_allocator,
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
cudaStream_t stream);

const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
@@ -101,6 +100,11 @@ class ModelState : public BackendModel {
TRITONSERVER_Error* AutoCompleteIO(
const char* key, const OnnxTensorInfoMap& io_infos);

TRITONSERVER_Error* GetSessionForGroup(
const std::string& group_name, std::shared_ptr<OrtSession>& session);
TRITONSERVER_Error* SetSessionForGroup(
const std::string& group_name, const std::shared_ptr<OrtSession>& session);

// Session options used when creating a ORT session.
std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;

@@ -110,6 +114,17 @@ class ModelState : public BackendModel {
// is specified both in the output section and state section, it indicates
// that the backend must return the output state to the client too.
std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;

// Indicates whether the ORT session should be shared among the instances of
// an instance group. This is a model-wide setting that applies to all
// instances, so it is stored in the model state.
bool share_session_;

// Map from instance group name to the ORT session shared by that group. Only
// populated when share_session is set to true in the model parameters. Since
// share_session cannot be set per instance group, the user should enable it
// with care.
std::unordered_map<std::string, std::shared_ptr<OrtSession>>
groupInstanceSessionMap_;
};

TRITONSERVER_Error*
@@ -188,7 +203,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
}

ModelState::ModelState(TRITONBACKEND_Model* triton_model)
: BackendModel(triton_model)
: BackendModel(triton_model), share_session_(false)
{
// Create session options that will be cloned and used for each
// instance when creating that instance's session.
@@ -338,20 +353,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
}
}
}

// FIXME. Is it possible to share a single OrtSession across
// multiple instances? If so then should move loading and validation
// of the session to here instead of creating a session for each
// instance in ModelStateInstance::Create().

// This setting applies across all instance groups; it cannot be enabled for
// only a subset of them. When set, all instances within an instance group
// share a single ORT session.
{
bool share_session = false;
triton::common::TritonJson::Value params;
if (ModelConfig().Find("parameters", &params)) {
THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
params, "share_session", &share_session, false));
}
share_session_ = share_session;
}
}

TRITONSERVER_Error*
ModelState::LoadModel(
const std::string& artifact_name,
const std::string& artifact_name, const std::string& instance_name,
const TRITONSERVER_InstanceGroupKind instance_group_kind,
const int32_t instance_group_device_id, std::string* model_path,
OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
cudaStream_t stream)
{
// Get the group name for the instance
std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
// Find the ONNX file that describes the model itself. If the model
// configuration doesn't have an explicit model file specified then
// use the default name ("model.onnx").
@@ -363,6 +389,10 @@ ModelState::LoadModel(
*model_path = JoinPath(
{RepositoryPath(), std::to_string(Version()), cc_model_filename});

// get default cpu allocator
RETURN_IF_ORT_ERROR(
ort_api->GetAllocatorWithDefaultOptions(default_allocator));

// If the model path is a directory then the actual model is
// <dir>/model.onnx.
{
@@ -373,6 +403,20 @@ ModelState::LoadModel(
}
}

// Check if we are sharing sessions. If a session has already been created
// for this instance group, reuse it and return.
if (share_session_) {
  TRITONSERVER_Error* err = GetSessionForGroup(instance_group_name, session);
  if (err == nullptr) {
    LOG_MESSAGE(
        TRITONSERVER_LOG_INFO,
        (std::string("Reusing session for group: ") + instance_group_name)
            .c_str());
    // Return the shared session
    return nullptr;
  }
  // No session has been cached for this group yet; release the error and
  // fall through to create a new session below.
  TRITONSERVER_ErrorDelete(err);
}

{
bool exists;
RETURN_IF_ERROR(FileExists(*model_path, &exists));
@@ -636,12 +680,22 @@ ModelState::LoadModel(
glock.lock();
}

RETURN_IF_ERROR(OnnxLoader::LoadSession(
true /* is_path */, *model_path, soptions, session));
{
// The raw session is allocated by ONNX Runtime here but is released when the
// last shared_ptr reference to it is dropped
OrtSession* session_ptr;
RETURN_IF_ERROR(OnnxLoader::LoadSession(
true /* is_path */, *model_path, soptions, &session_ptr));

// get default cpu allocator
RETURN_IF_ORT_ERROR(
ort_api->GetAllocatorWithDefaultOptions(default_allocator));
session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());

if (share_session_) {
// The session itself was created successfully, so failing to cache it for
// sharing is not a critical error
LOG_IF_ERROR(
SetSessionForGroup(instance_group_name, session),
"Failed to map ort session to the group for sharing");
}
}

return nullptr; // success
}
@@ -685,7 +739,7 @@ ModelState::AutoCompleteConfig()

// Must cleanup 'session'. 'allocator' is default allocator which
// is managed by ONNX Runtime so don't need to free/release
std::unique_ptr<OrtSession, SessionDeleter> session;
std::shared_ptr<OrtSession> session;
OrtAllocator* default_allocator;
std::string model_path;
{
@@ -714,12 +768,9 @@
}
}
#endif // TRITON_ENABLE_GPU

OrtSession* sptr = nullptr;
RETURN_IF_ERROR(LoadModel(
artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
nullptr));
session.reset(sptr);
artifact_name, "", kind, 0, &model_path,
session, &default_allocator, nullptr));
}
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
@@ -881,6 +932,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
return nullptr; // success
}

TRITONSERVER_Error*
ModelState::GetSessionForGroup(
const std::string& group_name, std::shared_ptr<OrtSession>& session)
{
RETURN_ERROR_IF_TRUE(
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid group name"));
{
auto sessionEntry = groupInstanceSessionMap_.find(group_name);
RETURN_ERROR_IF_TRUE(
(sessionEntry == groupInstanceSessionMap_.end()),
TRITONSERVER_ERROR_NOT_FOUND, std::string("No such group"));

session = sessionEntry->second;
}
return nullptr;
}

TRITONSERVER_Error*
ModelState::SetSessionForGroup(
const std::string& group_name, const std::shared_ptr<OrtSession>& session)
{
RETURN_ERROR_IF_TRUE(
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid group name"));

groupInstanceSessionMap_[group_name] = session;
return nullptr;
}

//
// ModelInstanceState
//
@@ -967,7 +1050,7 @@ class ModelInstanceState : public BackendModelInstance {

// Onnx Runtime variables that are used across runs on this
// instance.
OrtSession* session_;
std::shared_ptr<OrtSession> session_;
OrtAllocator* default_allocator_;
OrtMemoryInfo* cuda_allocator_info_;
const OrtMemoryInfo* cpu_allocator_info_;
@@ -1013,7 +1096,7 @@ ModelInstanceState::ModelInstanceState(
io_binding_(nullptr), output_buffer_(nullptr)
{
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
&default_allocator_, CudaStream()));

if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1026,7 +1109,7 @@ ModelInstanceState::ModelInstanceState(
ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));

THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
ort_api->CreateIoBinding(session_, &io_binding_));
ort_api->CreateIoBinding(session_.get(), &io_binding_));

THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));

@@ -1114,9 +1197,6 @@ ModelInstanceState::~ModelInstanceState()
ort_api->ReleaseRunOptions(runOptions_);
ort_api->ReleaseIoBinding(io_binding_);
ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
if (session_ != nullptr) {
OnnxLoader::UnloadSession(session_);
}
// 'default_allocator_' is default allocator which is managed by ONNX
// Runtime
}
@@ -1176,7 +1256,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
if (*have_control) {
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
InputInfos(session_, default_allocator_, input_tensor_infos));
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
const auto& iit = input_tensor_infos.find(tensor_name);
if (iit == input_tensor_infos.end()) {
return TRITONSERVER_ErrorNew(
@@ -1233,7 +1313,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
if (*have_control) {
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
InputInfos(session_, default_allocator_, input_tensor_infos));
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
const auto& iit = input_tensor_infos.find(tensor_name);
if (iit == input_tensor_infos.end()) {
return TRITONSERVER_ErrorNew(
@@ -1280,10 +1360,11 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
{
std::set<std::string> input_tensor_names;
RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));

OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
RETURN_IF_ERROR(
InputInfos(session_.get(), default_allocator_, input_tensor_infos));

if (input_tensor_infos.size() != expected_input_cnt) {
return TRITONSERVER_ErrorNew(
@@ -1368,10 +1449,10 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateOutputs()
{
std::set<std::string> output_tensor_names;
RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));

RETURN_IF_ERROR(
OutputInfos(session_, default_allocator_, output_tensor_infos_));
OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));

triton::common::TritonJson::Value ios;
RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1765,7 +1846,7 @@ ModelInstanceState::OrtRun(
const uint32_t response_count)
{
RETURN_IF_ORT_ERROR(
ort_api->RunWithBinding(session_, runOptions_, io_binding_));
ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
return nullptr;
}

@@ -2267,7 +2348,6 @@ ModelInstanceState::ReadOutputTensors(
}
}


} else {
char* output_buffer = nullptr;
RETURN_IF_ORT_ERROR(
17 changes: 17 additions & 0 deletions src/onnxruntime_utils.cc
@@ -493,5 +493,22 @@ CompareDimsSupported(
return nullptr; // success
}

std::string
GetInstanceGroupName(
const std::string& model_name, const std::string& instance_name)
{
  if (model_name.empty() || instance_name.empty()) {
    return "";
  }

  // The instance group name is the '<model_name>_<number>' prefix of the
  // instance name, e.g. for model "resnet" and instance "resnet_0_1" the
  // regex captures "resnet_0". Instances belonging to the same group share
  // this prefix.
  std::regex groupNameRegex("(" + model_name + "_[0-9]+)");
  std::smatch groupName;

  if (std::regex_search(instance_name, groupName, groupNameRegex)) {
    return groupName.str(1);
  }

  return "";
}

}}} // namespace triton::backend::onnxruntime
4 changes: 4 additions & 0 deletions src/onnxruntime_utils.h
@@ -27,6 +27,7 @@
#pragma once

#include <onnxruntime_c_api.h>
#include <regex>
#include <set>
#include <string>
#include <unordered_map>
@@ -149,4 +150,7 @@ TRITONSERVER_Error* CompareDimsSupported(
const std::vector<int64_t>& model_shape, const std::vector<int64_t>& dims,
const int max_batch_size, const bool compare_exact);

std::string GetInstanceGroupName(
const std::string& model_name, const std::string& instance_name);

}}} // namespace triton::backend::onnxruntime
