// Patch summary (reconstructed from a whitespace-collapsed diff; indentation
// and blank lines are best-effort):
//  * The InferHandlerState context gains 'ready_to_write_states_', a FIFO of
//    decoupled states that have a response ready to be written.
//  * Decoupled writes no longer cycle through the completion queue via
//    'WRITEREADY' / PutTaskBackToQueue(). Whoever makes a response ready, or
//    handles the 'WRITTEN' notification, starts the next write through the new
//    ModelStreamInferHandler::StateWriteResponse().
//  * Review fix: the non-decoupled branch of StreamInferResponseComplete()
//    keeps 'state->complete_ = is_complete;', which the pre-patch code executed
//    on both the decoupled and the non-decoupled path.
// NOTE(review): StreamInferResponseComplete() now calls
// IsGrpcContextCancelled() before taking 'step_mtx_', while the retained
// comment says the lock is needed for that check -- confirm this is intended.
// NOTE(review): template arguments (std::set / std::queue / std::lock_guard)
// appear to have been stripped by extraction; they are left as found.
diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h
index 6ef03807a2..0e1091feb8 100644
--- a/src/grpc/infer_handler.h
+++ b/src/grpc/infer_handler.h
@@ -1,4 +1,4 @@
-// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -979,6 +979,9 @@ class InferHandlerState {
     // Tracks all the states that have been created on this context.
     std::set all_states_;
 
+    // FIFO of decoupled states that have a response ready to be written.
+    std::queue ready_to_write_states_;
+
     // The step of the entire context.
     Steps step_;
 
diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc
index 269808c78a..585f88d536 100644
--- a/src/grpc/stream_infer_handler.cc
+++ b/src/grpc/stream_infer_handler.cc
@@ -359,13 +359,38 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
       response->mutable_infer_response()->Clear();
       // repopulate the id so that client knows which request failed.
       response->mutable_infer_response()->set_id(request.id());
-      state->step_ = Steps::WRITEREADY;
       if (!state->is_decoupled_) {
+        state->step_ = Steps::WRITEREADY;
         state->context_->WriteResponseIfReady(state);
       } else {
-        state->response_queue_->MarkNextResponseComplete();
-        state->complete_ = true;
-        state->context_->PutTaskBackToQueue(state);
+        InferHandler::State* writing_state = nullptr;
+        std::lock_guard lk1(state->context_->mu_);
+        {
+          std::lock_guard lk2(state->step_mtx_);
+          state->response_queue_->MarkNextResponseComplete();
+          state->context_->ready_to_write_states_.push(state);
+          if (!state->context_->ongoing_write_) {
+            // Only one write is allowed per gRPC stream / context at any time.
+            // If the stream is not currently writing, start writing the next
+            // ready response of the state at the front of
+            // 'ready_to_write_states_'. If that state still has responses
+            // ready to be written after the write is started, it is placed
+            // at the back of 'ready_to_write_states_'. If it has no other
+            // responses, the state will be marked as 'ISSUED' if the complete
+            // final flag has not yet been received from the backend, or
+            // completed if the complete final flag has been received.
+            // 'ongoing_write_' is reset once the completion queue returns a
+            // written state and no additional response on the stream is ready
+            // to be written.
+            state->context_->ongoing_write_ = true;
+            writing_state = state->context_->ready_to_write_states_.front();
+            state->context_->ready_to_write_states_.pop();
+          }
+          state->complete_ = true;
+        }
+        if (writing_state != nullptr) {
+          StateWriteResponse(writing_state);
+        }
       }
     }
 
@@ -451,7 +476,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
     // Decoupled state transitions
     //
     if (state->step_ == Steps::WRITTEN) {
-      state->context_->ongoing_write_ = false;
 #ifdef TRITON_ENABLE_TRACING
       state->trace_timestamps_.emplace_back(
           std::make_pair("GRPC_SEND_END", TraceManager::CaptureTimestamp()));
@@ -469,54 +493,44 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
         state->context_->finish_ok_ = false;
       }
 
-      // Finish the state if all the transactions associated with
-      // the state have completed.
-      if (state->IsComplete()) {
-        state->context_->DecrementRequestCounter();
-        finished = Finish(state);
-      } else {
-        std::lock_guard lock(state->step_mtx_);
-
-        // If there is an available response to be written
-        // to the stream, then transition directly to WRITEREADY
-        // state and enqueue itself to the completion queue to be
-        // taken up later. Otherwise, go to ISSUED state and wait
-        // for the callback to make a response available.
-        if (state->response_queue_->HasReadyResponse()) {
-          state->step_ = Steps::WRITEREADY;
-          state->context_->PutTaskBackToQueue(state);
-        } else {
-          state->step_ = Steps::ISSUED;
+      {
+        InferHandler::State* writing_state = nullptr;
+        std::lock_guard lk1(state->context_->mu_);
+        {
+          std::lock_guard lk2(state->step_mtx_);
+          if (!state->context_->ready_to_write_states_.empty()) {
+            writing_state = state->context_->ready_to_write_states_.front();
+            state->context_->ready_to_write_states_.pop();
+          } else {
+            state->context_->ongoing_write_ = false;
+          }
+          // Finish the state if all the transactions associated with
+          // the state have completed.
+          if (state != writing_state) {
+            if (state->IsComplete()) {
+              state->context_->DecrementRequestCounter();
+              finished = Finish(state);
+            } else {
+              state->step_ = Steps::ISSUED;
+            }
+          }
+        }
+        if (writing_state != nullptr) {
+          StateWriteResponse(writing_state);
         }
       }
     } else if (state->step_ == Steps::WRITEREADY) {
-      if (state->delay_response_ms_ != 0) {
-        // Will delay the write of the response by the specified time.
-        // This can be used to test the flow where there are other
-        // responses available to be written.
-        LOG_INFO << "Delaying the write of the response by "
-                 << state->delay_response_ms_ << " ms...";
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(state->delay_response_ms_));
-      }
-
       // Finish the state if all the transactions associated with
       // the state have completed.
       if (state->IsComplete()) {
         state->context_->DecrementRequestCounter();
         finished = Finish(state);
       } else {
-        // GRPC doesn't allow to issue another write till
-        // the notification from previous write has been
-        // delivered. If there is an ongoing write then
-        // defer writing and place the task at the back
-        // of the completion queue to be taken up later.
-        if (!state->context_->ongoing_write_) {
-          state->context_->ongoing_write_ = true;
-          state->context_->DecoupledWriteResponse(state);
-        } else {
-          state->context_->PutTaskBackToQueue(state);
-        }
+        LOG_ERROR << "Should not print this! Decoupled should NOT write via "
+                     "WRITEREADY!";
+        // Do not re-enqueue the state; park it in 'ISSUED'.
+        std::lock_guard lock(state->step_mtx_);
+        state->step_ = Steps::ISSUED;
       }
     }
   }
@@ -524,6 +538,31 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
   return !finished;
 }
 
+// Decoupled only. Caller must already own the stream's 'ongoing_write_'.
+void
+ModelStreamInferHandler::StateWriteResponse(InferHandler::State* state)
+{
+  if (state->delay_response_ms_ != 0) {
+    // Will delay the write of the response by the specified time.
+    // This can be used to test the flow where there are other
+    // responses available to be written.
+    LOG_INFO << "Delaying the write of the response by "
+             << state->delay_response_ms_ << " ms...";
+    std::this_thread::sleep_for(
+        std::chrono::milliseconds(state->delay_response_ms_));
+  }
+  {
+    std::lock_guard lock(state->step_mtx_);
+    state->step_ = Steps::WRITTEN;
+    // gRPC doesn't allow issuing another write until the notification for
+    // the previous write has been delivered.
+    state->context_->DecoupledWriteResponse(state);
+    if (state->response_queue_->HasReadyResponse()) {
+      state->context_->ready_to_write_states_.push(state);
+    }
+  }
+}
+
 bool
 ModelStreamInferHandler::Finish(InferHandler::State* state)
 {
@@ -701,45 +740,65 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     }
   }
 
-  // Update states to signal that response/error is ready to write to stream
-  {
+  if (state->IsGrpcContextCancelled()) {
     // Need to hold lock because the handler thread processing context
     // cancellation might have cancelled or marked the state for cancellation.
     std::lock_guard lock(state->step_mtx_);
 
-    if (state->IsGrpcContextCancelled()) {
-      LOG_VERBOSE(1)
-          << "ModelStreamInferHandler::StreamInferResponseComplete, "
-          << state->unique_id_
-          << ", skipping writing response because of transaction was cancelled";
-
-      // If this was the final callback for the state
-      // then cycle through the completion queue so
-      // that state object can be released.
-      if (is_complete) {
-        state->step_ = Steps::CANCELLED;
-        state->context_->PutTaskBackToQueue(state);
-      }
+    LOG_VERBOSE(1)
+        << "ModelStreamInferHandler::StreamInferResponseComplete, "
+        << state->unique_id_
+        << ", skipping writing response because of transaction was cancelled";
 
-      state->complete_ = is_complete;
-      return;
+    // If this was the final callback for the state
+    // then cycle through the completion queue so
+    // that state object can be released.
+    if (is_complete) {
+      state->step_ = Steps::CANCELLED;
+      state->context_->PutTaskBackToQueue(state);
     }
 
-    if (state->is_decoupled_) {
+    state->complete_ = is_complete;
+    return;
+  }
+
+  if (state->is_decoupled_) {
+    InferHandler::State* writing_state = nullptr;
+    std::lock_guard lk1(state->context_->mu_);
+    {
+      std::lock_guard lk2(state->step_mtx_);
+      bool has_prev_ready_response = state->response_queue_->HasReadyResponse();
       if (response) {
         state->response_queue_->MarkNextResponseComplete();
       }
-      if (state->step_ == Steps::ISSUED) {
+      if (!has_prev_ready_response && response) {
+        state->context_->ready_to_write_states_.push(state);
+      }
+      if (!state->context_->ongoing_write_ &&
+          !state->context_->ready_to_write_states_.empty()) {
+        state->context_->ongoing_write_ = true;
+        writing_state = state->context_->ready_to_write_states_.front();
+        state->context_->ready_to_write_states_.pop();
+      }
+      if (is_complete && state->response_queue_->IsEmpty() &&
+          state->step_ == Steps::ISSUED) {
+        // The response queue is empty and the complete final flag is received:
+        // mark the state 'WRITEREADY' and re-queue it for later clean-up.
         state->step_ = Steps::WRITEREADY;
         state->context_->PutTaskBackToQueue(state);
       }
-    } else {
-      state->step_ = Steps::WRITEREADY;
-      if (is_complete) {
-        state->context_->WriteResponseIfReady(state);
-      }
+      state->complete_ = is_complete;
+    }
+    if (writing_state != nullptr) {
+      StateWriteResponse(writing_state);
+    }
+  } else {  // non-decoupled
+    std::lock_guard lock(state->step_mtx_);
+    state->step_ = Steps::WRITEREADY;
+    if (is_complete) {
+      state->context_->WriteResponseIfReady(state);
     }
     state->complete_ = is_complete;
   }
 }
 
diff --git a/src/grpc/stream_infer_handler.h b/src/grpc/stream_infer_handler.h
index 60c4530227..e5163eac59 100644
--- a/src/grpc/stream_infer_handler.h
+++ b/src/grpc/stream_infer_handler.h
@@ -1,4 +1,4 @@
-// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -112,6 +112,7 @@ class ModelStreamInferHandler
   static void StreamInferResponseComplete(
       TRITONSERVER_InferenceResponse* response, const uint32_t flags,
       void* userp);
+  static void StateWriteResponse(InferHandler::State* state);
   bool Finish(State* state);
 
   TraceManager* trace_manager_;