Skip to content

Commit

Permalink
Add a ResponseQueue threshold (`--grpc-max-response-pool-size`) to bound the number of response objects held in the gRPC response queue
Browse files Browse the repository at this point in the history
  • Loading branch information
pskiran1 committed Jan 21, 2025
1 parent c9bcbd9 commit 2fba1dd
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 15 deletions.
14 changes: 14 additions & 0 deletions src/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ enum TritonOptionId {
OPTION_GRPC_ADDRESS,
OPTION_GRPC_HEADER_FORWARD_PATTERN,
OPTION_GRPC_INFER_ALLOCATION_POOL_SIZE,
OPTION_GRPC_MAX_RESPONSE_POOL_SIZE,
OPTION_GRPC_USE_SSL,
OPTION_GRPC_USE_SSL_MUTUAL,
OPTION_GRPC_SERVER_CERT,
Expand Down Expand Up @@ -536,6 +537,11 @@ TritonParser::SetupOptions()
"allocated for reuse. As long as the number of in-flight requests "
"doesn't exceed this value there will be no allocation/deallocation of "
"request/response objects."});
grpc_options_.push_back(
{OPTION_GRPC_MAX_RESPONSE_POOL_SIZE, "grpc-max-response-pool-size",
Option::ArgInt,
"The maximum number of inference response objects that can remain "
"allocated in the pool at any given time."});
grpc_options_.push_back(
{OPTION_GRPC_USE_SSL, "grpc-use-ssl", Option::ArgBool,
"Use SSL authentication for GRPC requests. Default is false."});
Expand Down Expand Up @@ -1438,6 +1444,14 @@ TritonParser::Parse(int argc, char** argv)
case OPTION_GRPC_INFER_ALLOCATION_POOL_SIZE:
lgrpc_options.infer_allocation_pool_size_ = ParseOption<int>(optarg);
break;
case OPTION_GRPC_MAX_RESPONSE_POOL_SIZE:
lgrpc_options.max_response_pool_size_ = ParseOption<int>(optarg);
if (lgrpc_options.max_response_pool_size_ <= 0) {
throw ParseException(
"Error: --grpc-max-response-pool-size must be greater "
"than 0.");
}
break;
case OPTION_GRPC_USE_SSL:
lgrpc_options.ssl_.use_ssl_ = ParseOption<bool>(optarg);
break;
Expand Down
10 changes: 6 additions & 4 deletions src/grpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2395,8 +2395,8 @@ Server::Server(
"ModelInferHandler", tritonserver_, trace_manager_, shm_manager_,
&service_, model_infer_cq_.get(),
options.infer_allocation_pool_size_ /* max_state_bucket_count */,
options.infer_compression_level_, restricted_kv,
options.forward_header_pattern_));
options.max_response_pool_size_, options.infer_compression_level_,
restricted_kv, options.forward_header_pattern_));
}

// Handler for streaming inference requests. Keeps one handler for streaming
Expand All @@ -2405,8 +2405,8 @@ Server::Server(
"ModelStreamInferHandler", tritonserver_, trace_manager_, shm_manager_,
&service_, model_stream_infer_cq_.get(),
options.infer_allocation_pool_size_ /* max_state_bucket_count */,
options.infer_compression_level_, restricted_kv,
options.forward_header_pattern_));
options.max_response_pool_size_, options.infer_compression_level_,
restricted_kv, options.forward_header_pattern_));
}

Server::~Server()
Expand Down Expand Up @@ -2472,6 +2472,8 @@ Server::GetOptions(Options& options, UnorderedMapType& options_map)
RETURN_IF_ERR(GetValue(
options_map, "infer_allocation_pool_size",
&options.infer_allocation_pool_size_));
RETURN_IF_ERR(GetValue(
options_map, "max_response_pool_size", &options.max_response_pool_size_));
RETURN_IF_ERR(GetValue(
options_map, "forward_header_pattern", &options.forward_header_pattern_));

Expand Down
1 change: 1 addition & 0 deletions src/grpc/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ struct Options {
// requests doesn't exceed this value there will be no
// allocation/deallocation of request/response objects.
int infer_allocation_pool_size_{8};
int max_response_pool_size_{INT_MAX};
RestrictedFeatures restricted_protocols_;
std::string forward_header_pattern_;
};
Expand Down
32 changes: 23 additions & 9 deletions src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ struct RequestReleasePayload final {
template <typename ResponseType>
class ResponseQueue {
public:
explicit ResponseQueue() { Reset(); }
explicit ResponseQueue(const size_t max_response_queue_size)
: max_response_queue_size_(max_response_queue_size)
{
Reset();
}

~ResponseQueue()
{
Expand Down Expand Up @@ -160,7 +164,9 @@ class ResponseQueue {
// Allocates a response at the end of the queue
void AllocateResponse()
{
std::lock_guard<std::mutex> lock(mtx_);
std::unique_lock<std::mutex> lock(mtx_);
cv_.wait(
lock, [this] { return responses_.size() < max_response_queue_size_; });
alloc_count_++;

// Use a response from the reusable pool if available
Expand Down Expand Up @@ -257,6 +263,8 @@ class ResponseQueue {
reusable_pool_.push_back(response);
responses_.pop_front();
pop_count_++;

cv_.notify_one();
}

// Returns whether the queue is empty
Expand All @@ -282,6 +290,8 @@ class ResponseQueue {
std::deque<ResponseType*> responses_;
// Stores completed responses that can be reused
std::deque<ResponseType*> reusable_pool_;
std::condition_variable cv_;
size_t max_response_queue_size_;
std::mutex mtx_;

// Three counters are used to track and manage responses in the queue
Expand Down Expand Up @@ -1122,7 +1132,7 @@ class InferHandlerState {
}

explicit InferHandlerState(
TRITONSERVER_Server* tritonserver,
TRITONSERVER_Server* tritonserver, const size_t max_response_queue_size,
const std::shared_ptr<Context>& context, Steps start_step = Steps::START)
: tritonserver_(tritonserver), async_notify_state_(false)
{
Expand All @@ -1136,7 +1146,8 @@ class InferHandlerState {
delay_response_completion_ms_ =
ParseDebugVariable("TRITONSERVER_DELAY_RESPONSE_COMPLETION");

response_queue_.reset(new ResponseQueue<ResponseType>());
response_queue_.reset(
new ResponseQueue<ResponseType>(max_response_queue_size));
Reset(context, start_step);
}

Expand Down Expand Up @@ -1289,7 +1300,7 @@ class InferHandler : public HandlerBase {
const std::string& name,
const std::shared_ptr<TRITONSERVER_Server>& tritonserver,
ServiceType* service, ::grpc::ServerCompletionQueue* cq,
size_t max_state_bucket_count,
size_t max_state_bucket_count, size_t max_response_queue_size,
std::pair<std::string, std::string> restricted_kv,
const std::string& header_forward_pattern);
virtual ~InferHandler();
Expand Down Expand Up @@ -1326,7 +1337,8 @@ class InferHandler : public HandlerBase {
}

if (state == nullptr) {
state = new State(tritonserver, context, start_step);
state = new State(
tritonserver, max_response_queue_size_, context, start_step);
}

if (start_step == Steps::START) {
Expand Down Expand Up @@ -1427,6 +1439,7 @@ class InferHandler : public HandlerBase {
const size_t max_state_bucket_count_;
std::vector<State*> state_bucket_;

const size_t max_response_queue_size_;
std::pair<std::string, std::string> restricted_kv_;
std::string header_forward_pattern_;
re2::RE2 header_forward_regex_;
Expand All @@ -1440,11 +1453,12 @@ InferHandler<ServiceType, ServerResponderType, RequestType, ResponseType>::
const std::string& name,
const std::shared_ptr<TRITONSERVER_Server>& tritonserver,
ServiceType* service, ::grpc::ServerCompletionQueue* cq,
size_t max_state_bucket_count,
size_t max_state_bucket_count, size_t max_response_queue_size,
std::pair<std::string, std::string> restricted_kv,
const std::string& header_forward_pattern)
: name_(name), tritonserver_(tritonserver), service_(service), cq_(cq),
max_state_bucket_count_(max_state_bucket_count),
max_response_queue_size_(max_response_queue_size),
restricted_kv_(restricted_kv),
header_forward_pattern_(header_forward_pattern),
header_forward_regex_(header_forward_pattern_)
Expand Down Expand Up @@ -1600,12 +1614,12 @@ class ModelInferHandler
const std::shared_ptr<SharedMemoryManager>& shm_manager,
inference::GRPCInferenceService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count,
grpc_compression_level compression_level,
size_t max_response_queue_size, grpc_compression_level compression_level,
std::pair<std::string, std::string> restricted_kv,
const std::string& forward_header_pattern)
: InferHandler(
name, tritonserver, service, cq, max_state_bucket_count,
restricted_kv, forward_header_pattern),
max_response_queue_size, restricted_kv, forward_header_pattern),
trace_manager_(trace_manager), shm_manager_(shm_manager),
compression_level_(compression_level)
{
Expand Down
4 changes: 2 additions & 2 deletions src/grpc/stream_infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ class ModelStreamInferHandler
const std::shared_ptr<SharedMemoryManager>& shm_manager,
inference::GRPCInferenceService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count,
grpc_compression_level compression_level,
size_t max_response_queue_size, grpc_compression_level compression_level,
std::pair<std::string, std::string> restricted_kv,
const std::string& header_forward_pattern)
: InferHandler(
name, tritonserver, service, cq, max_state_bucket_count,
restricted_kv, header_forward_pattern),
max_response_queue_size, restricted_kv, header_forward_pattern),
trace_manager_(trace_manager), shm_manager_(shm_manager),
compression_level_(compression_level)
{
Expand Down

0 comments on commit 2fba1dd

Please sign in to comment.