From 1e11d2696f11753cef89ff983f0db369cb6ca02b Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Tue, 26 Nov 2024 13:49:05 +0800
Subject: [PATCH] works now

---
 vllm/engine/llm_engine.py                     |  1 -
 vllm/engine/mm_arg_utils.py                   |  2 +-
 vllm/engine/multiprocessing/__init__.py       |  5 ++++
 vllm/engine/multiprocessing/mm_client.py      | 11 ++++++-
 vllm/engine/multiprocessing/mm_engine.py      | 30 +++++++++----------
 vllm/entrypoints/openai/mm_api_server.py      |  1 -
 vllm/entrypoints/openai/serving_chat.py       |  2 +-
 vllm/entrypoints/openai/serving_completion.py |  3 +-
 8 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index e774b3cbd726d..71db41e0ffe47 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1443,7 +1443,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
 
             outputs = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
-
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
diff --git a/vllm/engine/mm_arg_utils.py b/vllm/engine/mm_arg_utils.py
index cb80c6a51cded..9ed775441bae1 100644
--- a/vllm/engine/mm_arg_utils.py
+++ b/vllm/engine/mm_arg_utils.py
@@ -911,7 +911,7 @@ def create_model_config(self, model:str = None) -> ModelConfig:
             model=model if model is not None else self.model,
             task=self.task,
             # We know this is not None because we set it in __post_init__
-            tokenizer=cast(str, self.tokenizer),
+            tokenizer=cast(str, model),
             tokenizer_mode=self.tokenizer_mode,
             trust_remote_code=self.trust_remote_code,
             allowed_local_media_path=self.allowed_local_media_path,
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 34c161e9395ae..46b95986fa96c 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -27,6 +27,7 @@ class RPCProcessRequest:
     prompt: PromptType
     params: Union[SamplingParams, PoolingParams]
     request_id: str
+    model: str
     lora_request: Optional[LoRARequest] = None
     trace_headers: Optional[Mapping[str, str]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
@@ -39,6 +40,7 @@ def __init__(
         inputs: PromptType,
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
+        model: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -52,6 +54,7 @@ def __init__(
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
+        model: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -68,6 +71,7 @@ def __init__(
         prompt: Optional[PromptType] = None,
         params: Optional[Union[SamplingParams, PoolingParams]] = None,
         request_id: Optional[str] = None,
+        model: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -85,6 +89,7 @@ def __init__(
         self.prompt = prompt
         self.params = params
         self.request_id = request_id
+        self.model = model
         self.lora_request = lora_request
         self.trace_headers = trace_headers
         self.prompt_adapter_request = prompt_adapter_request
diff --git a/vllm/engine/multiprocessing/mm_client.py b/vllm/engine/multiprocessing/mm_client.py
index 6491c45471a37..99c963ab0792e 100644
--- a/vllm/engine/multiprocessing/mm_client.py
+++ b/vllm/engine/multiprocessing/mm_client.py
@@ -361,6 +361,12 @@ async def get_input_preprocessor(self) -> InputPreprocessor:
     async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)
 
+    async def get_tokenizer_mm(self, model, lora_request: Optional[LoRARequest] = None):
+        for tokenizer in self.tokenizers:
+            if tokenizer.tokenizer_id == model:
+                return await tokenizer.get_lora_tokenizer_async(lora_request)
+        raise ValueError(f"Tokenizer for model {model} not found.")
+
     async def get_decoding_config(self) -> DecodingConfig:
         return self.decoding_config
 
@@ -458,6 +464,7 @@ def generate(
         prompt: Optional[PromptType] = None,
         sampling_params: Optional[SamplingParams] = None,
         request_id: Optional[str] = None,
+        model: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -489,7 +496,7 @@ def generate(
         assert (prompt is not None and sampling_params is not None
                 and request_id is not None)
 
-        return self._process_request(prompt, sampling_params, request_id,
+        return self._process_request(prompt, sampling_params, request_id, model,
                                      lora_request, trace_headers,
                                      prompt_adapter_request, priority)
 
@@ -570,6 +577,7 @@ async def _process_request(
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
+        model: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -618,6 +626,7 @@ async def _process_request(
                     prompt=prompt,
                     params=params,
                     request_id=request_id,
+                    model=model,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
                     prompt_adapter_request=prompt_adapter_request,
diff --git a/vllm/engine/multiprocessing/mm_engine.py b/vllm/engine/multiprocessing/mm_engine.py
index 0d96255a4de94..6f4729f398a34 100644
--- a/vllm/engine/multiprocessing/mm_engine.py
+++ b/vllm/engine/multiprocessing/mm_engine.py
@@ -71,20 +71,19 @@ def __init__(self,
 
         # get configs from args and kwargs, determine how many models to load
         vllm_config = kwargs.get('vllm_config')
-        print(f"aaaa {vllm_config}")
         models_load = [
             model_config.model for model_config in vllm_config.model_configs
         ]
         self.engines = []
         for i, model in enumerate(models_load):
-            print(f"create engine for model: {model}")
             vllm_config.model_config = vllm_config.model_configs[i]
             self.engines.append(LLMEngine(model=model, *args, **kwargs))
         self.log_requests = log_requests
 
         self.use_async_sockets = use_async_sockets
-        # if self.use_async_sockets:
-        #     self.engine.process_request_outputs_callback = \
-        #         self._async_socket_engine_callback
+        if self.use_async_sockets:
+            for engine in self.engines:
+                engine.process_request_outputs_callback = \
+                    self._async_socket_engine_callback
 
         self.ctx = zmq.Context()  # type: ignore[attr-defined]
@@ -215,7 +214,7 @@ def engine_step(self) -> List[RequestOutput]:
         try:
             res = []
             for engine in self.engines:
-                res.append(engine.step())
+                res += engine.step()
             return res
         except SystemExit:
             raise
@@ -269,14 +268,16 @@ def _handle_process_request(self, request: RPCProcessRequest):
             self._send_outputs(rpc_err)
 
         try:
-            self.engine.add_request(
-                request_id=request_id,
-                prompt=request.prompt,
-                params=request.params,
-                lora_request=request.lora_request,
-                trace_headers=request.trace_headers,
-                prompt_adapter_request=request.prompt_adapter_request,
-                priority=request.priority)
+            for engine in self.engines:
+                if engine.model_config.model == request.model:
+                    engine.add_request(
+                        request_id=request_id,
+                        prompt=request.prompt,
+                        params=request.params,
+                        lora_request=request.lora_request,
+                        trace_headers=request.trace_headers,
+                        prompt_adapter_request=request.prompt_adapter_request,
+                        priority=request.priority)
 
             if self.log_requests:
                 logger.info("Added request %s.", request.request_id)
@@ -372,7 +373,6 @@ def signal_handler(*_) -> None:
 def run_mm_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
                   ipc_path: str, engine_alive):
     try:
-        print(f"bbbb {engine_args}")
         engine = MMLLMEngine.from_engine_args(engine_args=engine_args,
                                               usage_context=usage_context,
                                               ipc_path=ipc_path)
diff --git a/vllm/entrypoints/openai/mm_api_server.py b/vllm/entrypoints/openai/mm_api_server.py
index 297caab8543ca..99db253c9da33 100644
--- a/vllm/entrypoints/openai/mm_api_server.py
+++ b/vllm/entrypoints/openai/mm_api_server.py
@@ -128,7 +128,6 @@ async def build_async_engine_client_from_engine_args(
 
     Returns the Client or None if the creation failed.
     """
 
-    # Fall back
     # TODO: fill out feature matrix.
     if (MMLLMEngineClient.is_unsupported_config(engine_args)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 54ca0463bcab1..db048bf04c2da 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -123,7 +123,7 @@ async def create_chat_completion(
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            tokenizer = await self.engine_client.get_tokenizer(lora_request)
+            tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request)
 
             tool_parser = self.tool_parser
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 936aae8f1c267..44ecf1798a997 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -99,7 +99,7 @@ async def create_completion(
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            tokenizer = await self.engine_client.get_tokenizer(lora_request)
+            tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request)
 
             request_prompts, engine_prompts = self._preprocess_completion(
                 request,
@@ -148,6 +148,7 @@ async def create_completion(
                     engine_prompt,
                     sampling_params,
                     request_id_item,
+                    request.model,
                     lora_request=lora_request,
                     prompt_adapter_request=prompt_adapter_request,
                     trace_headers=trace_headers,
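
Usage sketch (not part of the applied diff): the snippet below illustrates how a caller would be expected to exercise the model-aware client API added above, i.e. get_tokenizer_mm() and the new model argument threaded through generate(), _process_request() and RPCProcessRequest. It assumes an already-constructed multi-model engine client object; the name `client`, the model id "facebook/opt-125m", the prompt text, and the request id are illustrative assumptions, not code from this repository.

    # Minimal sketch. `client` is assumed to be an initialized multi-model
    # engine client exposing the signatures introduced in this patch.
    from vllm import SamplingParams

    async def query_one_model(client, model="facebook/opt-125m"):
        # Tokenizer lookup is now keyed by model name (see get_tokenizer_mm above).
        tokenizer = await client.get_tokenizer_mm(model)
        print("prompt tokens:", len(tokenizer.encode("Hello, my name is")))

        # generate() now carries the model name, so RPCProcessRequest.model lets
        # MMLLMEngine route the request to the engine whose model_config matches.
        final_output = None
        async for output in client.generate(
                prompt="Hello, my name is",
                sampling_params=SamplingParams(max_tokens=16),
                request_id="cmpl-demo-0",
                model=model):
            final_output = output
        if final_output is not None:
            print(final_output.outputs[0].text)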