diff --git a/vllm/entrypoints/queue_llm.py b/vllm/entrypoints/queue_llm.py index 56d6b611f8a37..f2a06f115bdbf 100644 --- a/vllm/entrypoints/queue_llm.py +++ b/vllm/entrypoints/queue_llm.py @@ -18,7 +18,6 @@ from vllm.utils import Counter, deprecate_kwargs import multiprocessing as mp import queue -import threading logger = init_logger(__name__) @@ -167,28 +166,36 @@ def start(self, use_tqdm: bool = True): self._pull_tokens_from_input_queue(block=True) self._run_engine(use_tqdm=use_tqdm) + def _pull_all_tokens_from_input_queue(self, block: bool = True): + while self._pull_tokens_from_input_queue(block=False): + pass + if block: + self._pull_tokens_from_input_queue(block) + def _pull_tokens_from_input_queue(self, block: bool = True): try: input = self.input_queue.get() if block else self.input_queue.get_nowait() if input is None: self.finish = True - for sample_id, token_ids in input: - inputs = self._convert_v1_inputs( - prompts=None, - prompt_token_ids=token_ids, - multi_modal_data=None, - ) - - self._validate_and_add_requests( - inputs=inputs, - params=self.sampling_params, - request_id=sample_id, - ) + else: + for sample_id, token_ids in input: + inputs = self._convert_v1_inputs( + prompts=None, + prompt_token_ids=token_ids, + multi_modal_data=None, + ) + + self._validate_and_add_requests( + inputs=inputs, + params=self.sampling_params, + request_id=sample_id, + ) except queue.Empty: - pass + return False except Exception as e: logger.error(f"Unexpected exception during pulling tokens: {e}") - + return False + return True def _convert_v1_inputs( self, @@ -291,24 +298,31 @@ def _run_engine( ) # Run the engine. total_toks = 0 - first_token_sent = set() - while not self.finish and self.llm_engine.has_unfinished_requests(): - self._pull_tokens_from_input_queue(block=False) + request_stats = {} + while not self.finish or self.llm_engine.has_unfinished_requests(): + block = not self.llm_engine.has_unfinished_requests() and not self.finish + self._pull_all_tokens_from_input_queue(block=block) step_outputs = self.llm_engine.step() for output in step_outputs: - if len(output.outputs) > 0 and (output.request_id not in first_token_sent): - self.first_token_queue.put(output) - first_token_sent.add(output.request_id) - if output.finished: - self.result_queue.put_nowait(output) - first_token_sent.remove(output.request_id) - if use_tqdm: - if isinstance(output, RequestOutput): - # Calculate tokens only for RequestOutput - total_toks += sum( - len(stp.token_ids) for stp in output.outputs) - spd = total_toks / pbar.format_dict["elapsed"] - pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" - pbar.update(1) + output_len = len(output.outputs[0].token_ids) + if output_len > 0 and (output.request_id not in request_stats): + self.first_token_queue.put((output.request_id, output.outputs[0].token_ids)) + request_stats[output.request_id] = output_len + if request_stats[output.request_id] < output_len: + self.result_queue.put_nowait((output.request_id, output.outputs[0].token_ids[request_stats[output.request_id]: output_len])) + if output.finished: + # signal end of stream with None + self.result_queue.put_nowait((output.request_id, None)) + del request_stats[output.request_id] + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + total_toks += sum( + len(stp.token_ids) for stp in output.outputs) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" + pbar.update(1) + else: + request_stats[output.request_id] = output_len if use_tqdm: pbar.close()