perf: Check for cancellation on response thread (#54)
Co-authored-by: Iman Tabrizian <[email protected]>
kthui and Tabrizian authored Aug 7, 2024
1 parent a345a1d commit 843cbdd
Showing 1 changed file with 29 additions and 3 deletions.
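
What the change does: before this commit, generate() called response_sender.is_cancelled() on every iteration of its decode loop, paying that cost on the hot path. This commit moves the poll onto the dedicated response thread: response_loop() refreshes a cached flag in a shared response_state dict after each send, so the streaming decode loop only performs a cheap dict read. Below is a minimal, self-contained sketch of the pattern; FakeResponseSender and FINAL are illustrative stand-ins (the real backend uses Triton's pb_utils response sender, TRITONSERVER_RESPONSE_COMPLETE_FINAL, and the vLLM engine's response iterator).

# Minimal sketch of the pattern introduced by this commit.
# Stand-in names: FakeResponseSender and FINAL are illustrative only.
import asyncio
import queue
import threading


class FakeResponseSender:
    """Stand-in for the Triton response sender; is_cancelled() is the
    call whose cost this commit moves off the decode loop."""

    def __init__(self):
        self.cancelled = False

    def send(self, response, flags):
        print(f"sent {response!r} (flags={flags})")

    def is_cancelled(self):
        return self.cancelled


FINAL = 1  # stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL


def response_loop(response_queue):
    # Dedicated response thread: send each response, then refresh the
    # cached cancellation flag so the producer never has to poll itself.
    while True:
        item = response_queue.get()
        if item is None:  # shutdown signal
            break
        state, response, flags = item
        state["response_sender"].send(response, flags)
        if not state["last_response_generated"]:
            state["is_cancelled"] = state["response_sender"].is_cancelled()


async def generate(sender, response_queue, num_tokens=5):
    # Mirrors the shared dict this commit introduces.
    state = {
        "response_sender": sender,
        "is_cancelled": False,
        "last_response_generated": False,
    }
    for i in range(num_tokens):  # stands in for `async for output in ...`
        if state["is_cancelled"]:  # cheap dict read; the thread does the polling
            state["last_response_generated"] = True
            response_queue.put_nowait((state, "cancelled", FINAL))
            break
        flags = FINAL if i == num_tokens - 1 else 0
        if flags == FINAL:
            state["last_response_generated"] = True
        response_queue.put_nowait((state, f"token-{i}", flags))
        await asyncio.sleep(0)  # yield, as a real engine iterator would


response_queue = queue.Queue()
thread = threading.Thread(target=response_loop, args=(response_queue,))
thread.start()
asyncio.run(generate(FakeResponseSender(), response_queue))
response_queue.put(None)  # signal shutdown, as the real response_loop expects
thread.join()

Note that the non-streaming path (third hunk below) still polls response_sender.is_cancelled() directly: a non-streaming request enqueues nothing until it finishes, so the response thread would never refresh the cached flag while generation is in flight.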
src/model.py
@@ -287,9 +287,13 @@ def response_loop(self):
             # To signal shutdown a None item will be added to the queue.
             if item is None:
                 break
-            response_sender, response, response_flag = item
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
             try:
                 response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
             except Exception as e:
                 self.logger.log_error(
                     f"An error occurred while sending a response: {e}"
@@ -338,6 +342,11 @@ async def generate(self, request):
         Forwards single request to LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
         decrement_ongoing_request_count = True
         try:
@@ -399,10 +408,26 @@ async def generate(self, request):
             )

             async for output in response_iterator:
-                if response_sender.is_cancelled():
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
@@ -414,9 +439,10 @@
                     response = self.create_stream_response(output, prev_outputs_lengths)
                     flags = 0
                     if output.finished:
+                        response_state["last_response_generated"] = True
                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                         decrement_ongoing_request_count = False
-                    self._response_queue.put_nowait((response_sender, response, flags))
+                    self._response_queue.put_nowait((response_state, response, flags))
                     prev_outputs = output

                 last_output = output
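One thing the diff relies on but does not state: response_state is shared between the generate() coroutine and the response thread without a lock. Each field is accessed with a single dict-item get or set, which CPython's GIL makes effectively atomic, and the worst outcome of a stale read is one extra decode step before the cancellation is noticed, so the race is benign. A tiny demonstration of the write/read handshake (an observation about the pattern, not code from the commit):

import threading

# Stand-in for the state dict shared by generate() and response_loop().
state = {"is_cancelled": False}


def response_thread():
    # Single dict-item write, as in response_loop(); atomic under the GIL.
    state["is_cancelled"] = True


t = threading.Thread(target=response_thread)
t.start()
t.join()
# generate() observes the flag on its next dict read.
assert state["is_cancelled"]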
