diff --git a/src/model.py b/src/model.py
index 42df8fe1..0ed0e454 100644
--- a/src/model.py
+++ b/src/model.py
@@ -287,9 +287,13 @@ def response_loop(self):
             # To signal shutdown a None item will be added to the queue.
             if item is None:
                 break
-            response_sender, response, response_flag = item
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
             try:
                 response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
             except Exception as e:
                 self.logger.log_error(
                     f"An error occurred while sending a response: {e}"
@@ -338,6 +342,11 @@ async def generate(self, request):
         Forwards single request to LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
         decrement_ongoing_request_count = True
         try:
@@ -399,10 +408,26 @@ async def generate(self, request):
             )

             async for output in response_iterator:
-                if response_sender.is_cancelled():
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
@@ -414,9 +439,10 @@
                     response = self.create_stream_response(output, prev_outputs_lengths)
                     flags = 0
                     if output.finished:
+                        response_state["last_response_generated"] = True
                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                         decrement_ongoing_request_count = False
-                    self._response_queue.put_nowait((response_sender, response, flags))
+                    self._response_queue.put_nowait((response_state, response, flags))
                 prev_outputs = output

             last_output = output