Add support for sharing an ORT session #1

Merged: 1 commit, Feb 4, 2024

152 changes: 116 additions & 36 deletions src/onnxruntime.cc
@@ -25,7 +25,6 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <stdint.h>

#include <mutex>
#include <vector>

@@ -81,10 +80,10 @@ class ModelState : public BackendModel {
// onnx file, return in 'session' and 'allocator' the ORT session
// and allocator.
TRITONSERVER_Error* LoadModel(
const std::string& artifact_name,
const std::string& artifact_name, const std::string& instance_name,
const TRITONSERVER_InstanceGroupKind instance_group_kind,
const int32_t instance_group_device_id, std::string* model_path,
OrtSession** session, OrtAllocator** default_allocator,
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
cudaStream_t stream);

const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
@@ -101,6 +100,11 @@ class ModelState : public BackendModel {
TRITONSERVER_Error* AutoCompleteIO(
const char* key, const OnnxTensorInfoMap& io_infos);

TRITONSERVER_Error* GetSessionForGroup(
const std::string& group_name, std::shared_ptr<OrtSession>& session);
TRITONSERVER_Error* SetSessionForGroup(
const std::string& group_name, const std::shared_ptr<OrtSession>& session);

// Session options used when creating a ORT session.
std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;

@@ -110,6 +114,17 @@ class ModelState : public BackendModel {
// is specified both in the output section and state section, it indicates
// that the backend must return the output state to the client too.
std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;

// Indicates whether the ORT session should be shared. This is a
// model-global setting that applies to all instances, so it is stored in
// the model state.
bool share_session_;

// Maps an instance group name to its ORT session. This is only used when
// share_session is set to true in the model parameters. share_session is a
// model-global setting and cannot be enabled per instance group, so the
// user should set it with care.
std::unordered_map<std::string, std::shared_ptr<OrtSession>>
groupInstanceSessionMap_;
};

TRITONSERVER_Error*
@@ -188,7 +203,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
}

ModelState::ModelState(TRITONBACKEND_Model* triton_model)
: BackendModel(triton_model)
: BackendModel(triton_model), share_session_(false)
{
// Create session options that will be cloned and used for each
// instance when creating that instance's session.
@@ -338,20 +353,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
}
}
}

// FIXME. Is it possible to share a single OrtSession across
// multiple instances? If so then should move loading and validation
// of the session to here instead of creating a session for each
// instance in ModelStateInstance::Create().

// This setting applies across multiple instance groups: if it is set, all
// instances within an instance group will share the ORT session.
{
bool share_session = false;
triton::common::TritonJson::Value params;
if (ModelConfig().Find("parameters", &params)) {
THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
params, "share_session", &share_session, false));
}
share_session_ = share_session;
}
}
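
For reference, a model would opt in to session sharing through its config.pbtxt parameters. A minimal sketch, assuming the usual Triton string-parameter convention that TryParseModelStringParameter reads:

parameters {
  key: "share_session"
  value: { string_value: "true" }
}

Because the setting is model-global, it affects every instance group in the model; there is no per-group override.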

TRITONSERVER_Error*
ModelState::LoadModel(
const std::string& artifact_name,
const std::string& artifact_name, const std::string& instance_name,
const TRITONSERVER_InstanceGroupKind instance_group_kind,
const int32_t instance_group_device_id, std::string* model_path,
OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
cudaStream_t stream)
{
// Get the group name for the instance
std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
// Find the ONNX file that describes the model itself. If the model
// configuration doesn't have an explicit model file specified then
// use the default name ("model.onnx").
@@ -363,6 +389,10 @@ ModelState::LoadModel(
*model_path = JoinPath(
{RepositoryPath(), std::to_string(Version()), cc_model_filename});

// get default cpu allocator
RETURN_IF_ORT_ERROR(
ort_api->GetAllocatorWithDefaultOptions(default_allocator));

// If the model path is a directory then the actual model is
// <dir>/model.onnx.
{
@@ -373,6 +403,20 @@ ModelState::LoadModel(
}
}

// Check if we are sharing the session. If so, fetch the cached session and
// return.
if (share_session_) {
if (GetSessionForGroup(instance_group_name, session) == nullptr) {
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("Reusing session for group: ") + instance_group_name)
.c_str());
// Return the session
return nullptr;
}
// On a lookup miss, fall through and create a new session.
}

{
bool exists;
RETURN_IF_ERROR(FileExists(*model_path, &exists));
@@ -636,12 +680,22 @@ ModelState::LoadModel(
glock.lock();
}

RETURN_IF_ERROR(OnnxLoader::LoadSession(
true /* is_path */, *model_path, soptions, session));
{
// The session is allocated by ONNX Runtime here but is freed when the last
// shared_ptr reference to it is released.
OrtSession* session_ptr;
RETURN_IF_ERROR(OnnxLoader::LoadSession(
true /* is_path */, *model_path, soptions, &session_ptr));

// get default cpu allocator
RETURN_IF_ORT_ERROR(
ort_api->GetAllocatorWithDefaultOptions(default_allocator));
session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());

if (share_session_) {
// The session itself was created fine; failing to cache it for sharing
// is not a critical error.
LOG_IF_ERROR(
SetSessionForGroup(instance_group_name, session),
"Failed to map ort session to the group for sharing");
}
}

return nullptr; // success
}
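
The ownership handoff above relies on std::shared_ptr with a custom deleter: the raw OrtSession* from OnnxLoader::LoadSession is wrapped once, and the deleter runs only when the last holder drops its reference. A self-contained sketch of the pattern, with FakeSession and the deleter body as stand-ins for the backend's real types:

#include <cstdio>
#include <memory>

struct FakeSession {};  // stand-in for OrtSession

// Stand-in for the backend's SessionDeleter; invoked exactly once, when
// the last shared_ptr copy is destroyed.
struct SessionDeleter {
  void operator()(FakeSession* s) const {
    std::puts("releasing session");
    delete s;
  }
};

int main() {
  std::shared_ptr<FakeSession> session(new FakeSession(), SessionDeleter());
  {
    // A second model instance reusing the cached session just copies the
    // shared_ptr; no new session is created.
    std::shared_ptr<FakeSession> other = session;
  }  // 'other' destroyed; use_count drops back to 1, session stays alive
  return 0;
}  // last reference released here, so "releasing session" is printed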
@@ -685,7 +739,7 @@ ModelState::AutoCompleteConfig()

// 'session' is cleaned up by its deleter. 'allocator' is the default
// allocator, which is managed by ONNX Runtime and doesn't need to be
// freed/released.
std::unique_ptr<OrtSession, SessionDeleter> session;
std::shared_ptr<OrtSession> session;
OrtAllocator* default_allocator;
std::string model_path;
{
@@ -714,12 +768,9 @@ ModelState::AutoCompleteConfig()
}
}
#endif // TRITON_ENABLE_GPU

OrtSession* sptr = nullptr;
RETURN_IF_ERROR(LoadModel(
artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
nullptr));
session.reset(sptr);
artifact_name, "", kind, 0, &model_path,
session, &default_allocator, nullptr));
}
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
@@ -881,6 +932,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
return nullptr; // success
}

TRITONSERVER_Error*
ModelState::GetSessionForGroup(
const std::string& group_name, std::shared_ptr<OrtSession>& session)
{
RETURN_ERROR_IF_TRUE(
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid group name"));
{
auto sessionEntry = groupInstanceSessionMap_.find(group_name);
RETURN_ERROR_IF_TRUE(
(sessionEntry == groupInstanceSessionMap_.end()),
TRITONSERVER_ERROR_NOT_FOUND, std::string("No such group"));

session = sessionEntry->second;
}
return nullptr;
}

TRITONSERVER_Error*
ModelState::SetSessionForGroup(
const std::string& group_name, const std::shared_ptr<OrtSession>& session)
{
RETURN_ERROR_IF_TRUE(
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid group name"));

groupInstanceSessionMap_[group_name] = session;
return nullptr;
}
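
Together these helpers give LoadModel a lookup-or-create protocol: a nullptr return from GetSessionForGroup is a cache hit, while an error return means the caller should create and publish a session. A condensed recap of the flow above (not additional backend code; error cleanup elided):

std::shared_ptr<OrtSession> session;
if (GetSessionForGroup(group_name, session) == nullptr) {
  return nullptr;  // hit: every instance in the group shares this session
}
// miss: create the session, then publish it for later instances
OrtSession* raw = nullptr;
RETURN_IF_ERROR(
    OnnxLoader::LoadSession(true /* is_path */, model_path, soptions, &raw));
session = std::shared_ptr<OrtSession>(raw, SessionDeleter());
LOG_IF_ERROR(
    SetSessionForGroup(group_name, session),
    "Failed to map ort session to the group for sharing");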

//
// ModelInstanceState
//
@@ -967,7 +1050,7 @@ class ModelInstanceState : public BackendModelInstance {

// Onnx Runtime variables that are used across runs on this
// instance.
OrtSession* session_;
std::shared_ptr<OrtSession> session_;
OrtAllocator* default_allocator_;
OrtMemoryInfo* cuda_allocator_info_;
const OrtMemoryInfo* cpu_allocator_info_;
@@ -1013,7 +1096,7 @@ ModelInstanceState::ModelInstanceState(
io_binding_(nullptr), output_buffer_(nullptr)
{
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
&default_allocator_, CudaStream()));

if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1026,7 +1109,7 @@ ModelInstanceState::ModelInstanceState(
ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));

THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
ort_api->CreateIoBinding(session_, &io_binding_));
ort_api->CreateIoBinding(session_.get(), &io_binding_));

THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));

@@ -1114,9 +1197,6 @@ ModelInstanceState::~ModelInstanceState()
ort_api->ReleaseRunOptions(runOptions_);
ort_api->ReleaseIoBinding(io_binding_);
ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
if (session_ != nullptr) {
OnnxLoader::UnloadSession(session_);
}
// 'default_allocator_' is default allocator which is managed by ONNX
// Runtime
}
@@ -1176,7 +1256,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
if (*have_control) {
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
InputInfos(session_, default_allocator_, input_tensor_infos));
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
const auto& iit = input_tensor_infos.find(tensor_name);
if (iit == input_tensor_infos.end()) {
return TRITONSERVER_ErrorNew(
@@ -1233,7 +1313,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
if (*have_control) {
OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(
InputInfos(session_, default_allocator_, input_tensor_infos));
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
const auto& iit = input_tensor_infos.find(tensor_name);
if (iit == input_tensor_infos.end()) {
return TRITONSERVER_ErrorNew(
@@ -1280,10 +1360,11 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
{
std::set<std::string> input_tensor_names;
RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));

OnnxTensorInfoMap input_tensor_infos;
RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
RETURN_IF_ERROR(
InputInfos(session_.get(), default_allocator_, input_tensor_infos));

if (input_tensor_infos.size() != expected_input_cnt) {
return TRITONSERVER_ErrorNew(
@@ -1368,10 +1449,10 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateOutputs()
{
std::set<std::string> output_tensor_names;
RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));

RETURN_IF_ERROR(
OutputInfos(session_, default_allocator_, output_tensor_infos_));
OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));

triton::common::TritonJson::Value ios;
RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1765,7 +1846,7 @@ ModelInstanceState::OrtRun(
const uint32_t response_count)
{
RETURN_IF_ORT_ERROR(
ort_api->RunWithBinding(session_, runOptions_, io_binding_));
ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
return nullptr;
}

@@ -2267,7 +2348,6 @@ ModelInstanceState::ReadOutputTensors(
}
}


} else {
char* output_buffer = nullptr;
RETURN_IF_ORT_ERROR(
17 changes: 17 additions & 0 deletions src/onnxruntime_utils.cc
@@ -493,5 +493,22 @@ CompareDimsSupported(
return nullptr; // success
}

std::string
GetInstanceGroupName(
const std::string& model_name, const std::string& instance_name)
{
if (model_name.empty() || instance_name.empty()) {
return "";
}

std::regex groupNameRegex("(" + model_name + "_[0-9])");
std::smatch groupName;

if (std::regex_search(instance_name, groupName, groupNameRegex)) {
return groupName.str(1);
}

return "";
}

}}} // namespace triton::backend::onnxruntime
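
For illustration, assuming Triton instance names carry a "<model>_<group index>" prefix (the naming convention comes from Triton core, not from this patch), the helper behaves like:

GetInstanceGroupName("densenet", "densenet_0_1");  // -> "densenet_0"
GetInstanceGroupName("densenet", "other_name");    // -> "" (no match)
GetInstanceGroupName("", "densenet_0");            // -> "" (invalid input)

Note that the single [0-9] in the pattern only distinguishes group indices 0 through 9.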
4 changes: 4 additions & 0 deletions src/onnxruntime_utils.h
@@ -27,6 +27,7 @@
#pragma once

#include <onnxruntime_c_api.h>
#include <regex>
#include <set>
#include <string>
#include <unordered_map>
@@ -149,4 +150,7 @@ TRITONSERVER_Error* CompareDimsSupported(
const std::vector<int64_t>& model_shape, const std::vector<int64_t>& dims,
const int max_batch_size, const bool compare_exact);

std::string GetInstanceGroupName(
const std::string& model_name, const std::string& instance_name);

}}} // namespace triton::backend::onnxruntime