Skip to content

Commit

Permalink
perf: Improve response throughput of a single gRPC stream (#7404)
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui authored Jul 12, 2024
1 parent d1780d1 commit 3dbf09e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 71 deletions.
5 changes: 4 additions & 1 deletion src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -979,6 +979,9 @@ class InferHandlerState {
// Tracks all the states that have been created on this context.
std::set<InferHandlerStateType*> all_states_;

// Ready to write queue for decoupled
std::queue<InferHandlerStateType*> ready_to_write_states_;

// The step of the entire context.
Steps step_;

Expand Down
196 changes: 127 additions & 69 deletions src/grpc/stream_infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,38 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
response->mutable_infer_response()->Clear();
// repopulate the id so that client knows which request failed.
response->mutable_infer_response()->set_id(request.id());
state->step_ = Steps::WRITEREADY;
if (!state->is_decoupled_) {
state->step_ = Steps::WRITEREADY;
state->context_->WriteResponseIfReady(state);
} else {
state->response_queue_->MarkNextResponseComplete();
state->complete_ = true;
state->context_->PutTaskBackToQueue(state);
InferHandler::State* writing_state = nullptr;
std::lock_guard<std::recursive_mutex> lk1(state->context_->mu_);
{
std::lock_guard<std::recursive_mutex> lk2(state->step_mtx_);
state->response_queue_->MarkNextResponseComplete();
state->context_->ready_to_write_states_.push(state);
if (!state->context_->ongoing_write_) {
// Only one write is allowed per gRPC stream / context at any time.
// If the stream is not currently writing, start writing the next
// ready to write response from the next ready to write state from
// 'ready_to_write_states_'. If there are other responses on the
// state ready to be written after starting the write, the state
// will be placed at the back of the 'ready_to_write_states_'. If
// there are no other responses, the state will be marked as 'ISSUED'
// if complete final flag is not received yet from the backend or
// completed if complete final flag is received.
// The 'ongoing_write_' will reset once the completion queue returns
// a written state and no additional response on the stream is ready
// to be written.
state->context_->ongoing_write_ = true;
writing_state = state->context_->ready_to_write_states_.front();
state->context_->ready_to_write_states_.pop();
}
state->complete_ = true;
}
if (writing_state != nullptr) {
StateWriteResponse(writing_state);
}
}
}

Expand Down Expand Up @@ -451,7 +476,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
// Decoupled state transitions
//
if (state->step_ == Steps::WRITTEN) {
state->context_->ongoing_write_ = false;
#ifdef TRITON_ENABLE_TRACING
state->trace_timestamps_.emplace_back(
std::make_pair("GRPC_SEND_END", TraceManager::CaptureTimestamp()));
Expand All @@ -469,61 +493,76 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
state->context_->finish_ok_ = false;
}

// Finish the state if all the transactions associated with
// the state have completed.
if (state->IsComplete()) {
state->context_->DecrementRequestCounter();
finished = Finish(state);
} else {
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);

// If there is an available response to be written
// to the stream, then transition directly to WRITEREADY
// state and enqueue itself to the completion queue to be
// taken up later. Otherwise, go to ISSUED state and wait
// for the callback to make a response available.
if (state->response_queue_->HasReadyResponse()) {
state->step_ = Steps::WRITEREADY;
state->context_->PutTaskBackToQueue(state);
} else {
state->step_ = Steps::ISSUED;
{
InferHandler::State* writing_state = nullptr;
std::lock_guard<std::recursive_mutex> lk1(state->context_->mu_);
{
std::lock_guard<std::recursive_mutex> lk2(state->step_mtx_);
if (!state->context_->ready_to_write_states_.empty()) {
writing_state = state->context_->ready_to_write_states_.front();
state->context_->ready_to_write_states_.pop();
} else {
state->context_->ongoing_write_ = false;
}
// Finish the state if all the transactions associated with
// the state have completed.
if (state != writing_state) {
if (state->IsComplete()) {
state->context_->DecrementRequestCounter();
finished = Finish(state);
} else {
state->step_ = Steps::ISSUED;
}
}
}
if (writing_state != nullptr) {
StateWriteResponse(writing_state);
}
}
} else if (state->step_ == Steps::WRITEREADY) {
if (state->delay_response_ms_ != 0) {
// Will delay the write of the response by the specified time.
// This can be used to test the flow where there are other
// responses available to be written.
LOG_INFO << "Delaying the write of the response by "
<< state->delay_response_ms_ << " ms...";
std::this_thread::sleep_for(
std::chrono::milliseconds(state->delay_response_ms_));
}

// Finish the state if all the transactions associated with
// the state have completed.
if (state->IsComplete()) {
state->context_->DecrementRequestCounter();
finished = Finish(state);
} else {
// GRPC doesn't allow to issue another write till
// the notification from previous write has been
// delivered. If there is an ongoing write then
// defer writing and place the task at the back
// of the completion queue to be taken up later.
if (!state->context_->ongoing_write_) {
state->context_->ongoing_write_ = true;
state->context_->DecoupledWriteResponse(state);
} else {
state->context_->PutTaskBackToQueue(state);
}
LOG_ERROR << "Should not print this! Decoupled should NOT write via "
"WRITEREADY!";
// Remove the state from the completion queue
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
state->step_ = Steps::ISSUED;
}
}
}

return !finished;
}

// For decoupled only. Caller must ensure exclusive write.
//
// Issues the gRPC write for the next ready response of 'state' and, if the
// state already has further responses ready, re-enqueues the state onto the
// context's 'ready_to_write_states_' queue so the next completion-queue event
// can continue draining it. The caller guarantees no other write is in flight
// on this stream (see the 'ongoing_write_' checks at the call sites).
void
ModelStreamInferHandler::StateWriteResponse(InferHandler::State* state)
{
  if (state->delay_response_ms_ != 0) {
    // Will delay the write of the response by the specified time.
    // This can be used to test the flow where there are other
    // responses available to be written.
    LOG_INFO << "Delaying the write of the response by "
             << state->delay_response_ms_ << " ms...";
    std::this_thread::sleep_for(
        std::chrono::milliseconds(state->delay_response_ms_));
  }
  {
    // 'step_mtx_' is a recursive mutex, so this is safe even if the caller
    // already holds it when invoking this function.
    std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
    // Mark the state as written BEFORE issuing the write, so the step is
    // consistent by the time the completion queue returns this state.
    state->step_ = Steps::WRITTEN;
    // gRPC doesn't allow issuing another write until the notification from
    // the previous write has been delivered.
    state->context_->DecoupledWriteResponse(state);
    if (state->response_queue_->HasReadyResponse()) {
      // More responses are already queued on this state: place it at the
      // back of the ready-to-write queue so writes are interleaved fairly
      // across the states sharing this stream.
      state->context_->ready_to_write_states_.push(state);
    }
  }
}

bool
ModelStreamInferHandler::Finish(InferHandler::State* state)
{
Expand Down Expand Up @@ -701,45 +740,64 @@ ModelStreamInferHandler::StreamInferResponseComplete(
}
}

// Update states to signal that response/error is ready to write to stream
{
if (state->IsGrpcContextCancelled()) {
// Need to hold lock because the handler thread processing context
// cancellation might have cancelled or marked the state for cancellation.
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);

if (state->IsGrpcContextCancelled()) {
LOG_VERBOSE(1)
<< "ModelStreamInferHandler::StreamInferResponseComplete, "
<< state->unique_id_
<< ", skipping writing response because of transaction was cancelled";

// If this was the final callback for the state
// then cycle through the completion queue so
// that state object can be released.
if (is_complete) {
state->step_ = Steps::CANCELLED;
state->context_->PutTaskBackToQueue(state);
}
LOG_VERBOSE(1)
<< "ModelStreamInferHandler::StreamInferResponseComplete, "
<< state->unique_id_
<< ", skipping writing response because of transaction was cancelled";

state->complete_ = is_complete;
return;
// If this was the final callback for the state
// then cycle through the completion queue so
// that state object can be released.
if (is_complete) {
state->step_ = Steps::CANCELLED;
state->context_->PutTaskBackToQueue(state);
}

if (state->is_decoupled_) {
state->complete_ = is_complete;
return;
}

if (state->is_decoupled_) {
InferHandler::State* writing_state = nullptr;
std::lock_guard<std::recursive_mutex> lk1(state->context_->mu_);
{
std::lock_guard<std::recursive_mutex> lk2(state->step_mtx_);
bool has_prev_ready_response = state->response_queue_->HasReadyResponse();
if (response) {
state->response_queue_->MarkNextResponseComplete();
}
if (state->step_ == Steps::ISSUED) {
if (!has_prev_ready_response && response) {
state->context_->ready_to_write_states_.push(state);
}
if (!state->context_->ongoing_write_ &&
!state->context_->ready_to_write_states_.empty()) {
state->context_->ongoing_write_ = true;
writing_state = state->context_->ready_to_write_states_.front();
state->context_->ready_to_write_states_.pop();
}
if (is_complete && state->response_queue_->IsEmpty() &&
state->step_ == Steps::ISSUED) {
// The response queue is empty and complete final flag is received, so
// mark the state as 'WRITEREADY' so it can be cleaned up later.
state->step_ = Steps::WRITEREADY;
state->context_->PutTaskBackToQueue(state);
}
} else {
state->step_ = Steps::WRITEREADY;
if (is_complete) {
state->context_->WriteResponseIfReady(state);
}
state->complete_ = is_complete;
}
if (writing_state != nullptr) {
StateWriteResponse(writing_state);
}
} else { // non-decoupled
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
state->step_ = Steps::WRITEREADY;
if (is_complete) {
state->context_->WriteResponseIfReady(state);
}

state->complete_ = is_complete;
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/grpc/stream_infer_handler.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -112,6 +112,7 @@ class ModelStreamInferHandler
static void StreamInferResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags,
void* userp);
static void StateWriteResponse(InferHandler::State* state);
bool Finish(State* state);

TraceManager* trace_manager_;
Expand Down

0 comments on commit 3dbf09e

Please sign in to comment.