
update format
Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Nov 27, 2024
1 parent 157f649 commit 3b76699
Showing 8 changed files with 28 additions and 19 deletions.
10 changes: 5 additions & 5 deletions scripts/run_multi_models.sh
@@ -1,8 +1,8 @@
-# VLLM_SKIP_WARMUP=true python3 -m \
-# vllm.entrypoints.openai.mm_api_server \
-# --model mistralai/Mistral-7B-Instruct-v0.3 meta-llama/Llama-3.1-8B-Instruct \
-# --port 8080 --device hpu --dtype bfloat16 \
-# --gpu-memory-utilization=0.3 --use-v2-block-manager --max-model-len 4096 2>&1 > multi_models.log &
+VLLM_SKIP_WARMUP=true python3 -m \
+vllm.entrypoints.openai.mm_api_server \
+--model mistralai/Mistral-7B-Instruct-v0.3 meta-llama/Llama-3.1-8B-Instruct \
+--port 8080 --device hpu --dtype bfloat16 \
+--gpu-memory-utilization=0.3 --use-v2-block-manager --max-model-len 4096 2>&1 > multi_models.log &


bs=128
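
For reference, a minimal client sketch (not part of this commit): it assumes the mm_api_server launched above exposes the usual OpenAI-compatible /v1/completions route on port 8080 and selects the backend engine/tokenizer from the request's "model" field, as the serving changes below suggest.

# Hypothetical client: query both models served from a single port.
import requests

BASE_URL = "http://localhost:8080/v1/completions"

for model in ("mistralai/Mistral-7B-Instruct-v0.3",
              "meta-llama/Llama-3.1-8B-Instruct"):
    resp = requests.post(BASE_URL,
                         json={
                             "model": model,  # selects which loaded model answers
                             "prompt": "Hello, my name is",
                             "max_tokens": 32,
                         })
    resp.raise_for_status()
    print(model, "->", resp.json()["choices"][0]["text"])
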
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -2318,7 +2318,8 @@ class VllmConfig:
quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore
-model_configs: List[ModelConfig] = field(default=None, init=True) # type: ignore
+model_configs: List[ModelConfig] = field(default=None,
+init=True) # type: ignore

@staticmethod
def _get_quantization_config(
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -229,7 +229,7 @@ def __init__(
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
-model: str=None,
+model: str = None,
) -> None:

self.model_config = vllm_config.model_config
4 changes: 2 additions & 2 deletions vllm/engine/mm_arg_utils.py
@@ -903,10 +903,10 @@ def from_cli_args(cls, args: argparse.Namespace):
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args

-def create_model_configs(self)-> list[ModelConfig]:
+def create_model_configs(self) -> list[ModelConfig]:
return [self.create_model_config(model) for model in self.models]

-def create_model_config(self, model:str = None) -> ModelConfig:
+def create_model_config(self, model: str = None) -> ModelConfig:
return ModelConfig(
model=model if model is not None else self.model,
task=self.task,
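
A hedged sketch of how these helpers might be wired together (everything except create_model_configs, create_model_config and the model_configs field is an assumption, not code from this commit): one ModelConfig is built per entry in --model and the list is stored on the VllmConfig.model_configs field added in vllm/config.py above.

def build_multi_model_vllm_config(engine_args):
    """engine_args: an mm_arg_utils EngineArgs-style object with .models populated."""
    # Assumed existing helper that builds the base VllmConfig.
    vllm_config = engine_args.create_engine_config()
    # One ModelConfig per model name; mm_engine.py below fans these out
    # into one engine per model.
    vllm_config.model_configs = engine_args.create_model_configs()
    return vllm_config
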
8 changes: 5 additions & 3 deletions vllm/engine/multiprocessing/mm_client.py
@@ -361,7 +361,9 @@ async def get_input_preprocessor(self) -> InputPreprocessor:
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

-async def get_tokenizer_mm(self, model, lora_request: Optional[LoRARequest] = None):
+async def get_tokenizer_mm(self,
+model,
+lora_request: Optional[LoRARequest] = None):
for tokenizer in self.tokenizers:
if tokenizer.tokenizer_id == model:
return await tokenizer.get_lora_tokenizer_async(lora_request)
@@ -496,8 +498,8 @@ def generate(
assert (prompt is not None and sampling_params is not None
and request_id is not None)

-return self._process_request(prompt, sampling_params, request_id, model,
-lora_request, trace_headers,
+return self._process_request(prompt, sampling_params, request_id,
+model, lora_request, trace_headers,
prompt_adapter_request, priority)

@overload # DEPRECATED
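
For context, a hedged sketch of driving this client (keyword names follow the assert and the _process_request call above; the exact signature and return type in this fork may differ):

import uuid
from vllm import SamplingParams

async def complete(client, model: str, prompt: str) -> str:
    # 'model' picks which of the loaded engines/tokenizers serves the request.
    text = ""
    async for output in client.generate(prompt=prompt,
                                        sampling_params=SamplingParams(max_tokens=32),
                                        request_id=str(uuid.uuid4()),
                                        model=model):
        text = output.outputs[0].text
    return text
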
14 changes: 9 additions & 5 deletions vllm/engine/multiprocessing/mm_engine.py
@@ -69,12 +69,15 @@ def __init__(self,
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
kwargs['use_cached_outputs'] = True

# get configs from args and kwargs, determine how many models to load
original_vllm_config = kwargs.get('vllm_config')
-models_load = [model_config.model for model_config in original_vllm_config.model_configs ]
-self.engines = []
+models_load = [
+model_config.model
+for model_config in original_vllm_config.model_configs
+]
+self.engines = []

for i, model in enumerate(models_load):
vllm_config = copy.deepcopy(original_vllm_config)
vllm_config.model_config = original_vllm_config.model_configs[i]
@@ -192,7 +195,8 @@ def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""

while True:
-if not any(engine.has_unfinished_requests() for engine in self.engines):
+if not any(engine.has_unfinished_requests()
+for engine in self.engines):
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
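
The first hunk above builds one engine per configured model; a hedged sketch of that fan-out as a standalone helper (engine_factory stands in for the engine constructor, which the diff does not show):

import copy

def build_engines(original_vllm_config, engine_factory, **kwargs):
    models_load = [
        model_config.model
        for model_config in original_vllm_config.model_configs
    ]
    engines = []
    for i, model in enumerate(models_load):
        # Each engine gets its own deep-copied config with a single active model.
        vllm_config = copy.deepcopy(original_vllm_config)
        vllm_config.model_config = original_vllm_config.model_configs[i]
        engines.append(engine_factory(vllm_config=vllm_config, model=model, **kwargs))
    return engines
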
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -123,7 +123,8 @@ async def create_chat_completion(
prompt_adapter_request,
) = self._maybe_get_adapters(request)

-tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request)
+tokenizer = await self.engine_client.get_tokenizer_mm(
+request.model, lora_request)

tool_parser = self.tool_parser

3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -99,7 +99,8 @@ async def create_completion(
prompt_adapter_request,
) = self._maybe_get_adapters(request)

-tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request)
+tokenizer = await self.engine_client.get_tokenizer_mm(
+request.model, lora_request)

request_prompts, engine_prompts = await self._preprocess_completion(
request,
