diff --git a/docs/commands.md b/docs/commands.md
index 3566e0f6..8a15b09c 100755
--- a/docs/commands.md
+++ b/docs/commands.md
@@ -27,20 +27,38 @@ This mode supports a number of command-line arguments, the details of which can
 > install sglang
 
 ```bash
-git clone https://github.com/EvolvingLMMs-Lab/sglang.git
+git clone https://github.com/sgl-project/sglang.git
+# The current integration is tested on commit #1222 of sglang
 cd sglang;
-git checkout dev/onevision_local;
 pip install -e "python[srt]"
+
+# Install FlashInfer CUDA kernels
+pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
 
 > run sglang backend service with the following command
 
 ```bash
-# backend service
-python -m sglang.launch_server --model-path "\path\to\onevision" --tokenizer-path lmms-lab/llavanext-qwen-siglip-tokenizer --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava
+# After this update there is no need to run a separate command to set up the backend server;
+# the server is initialized automatically when the srt_api model is constructed.
 
 # launch lmms-eval srt_api model
-python -m accelerate.commands.launch --main_process_port=12580 --num_processes=1 lmms_eval --model=srt_api --model_args=modality=image,host=127.0.0.1,port=30000 --tasks=ai2d --batch_size=1 --log_samples --log_samples_suffix=debug --output_path=./logs/ --verbosity=DEBUG
+CKPT_PATH=$1
+TASK=$2
+MODALITY=$3
+TP_SIZE=$4
+echo $TASK
+TASK_SUFFIX="${TASK//,/_}"
+echo $TASK_SUFFIX
+
+python3 -m lmms_eval \
+    --model srt_api \
+    --model_args modality=$MODALITY,model_version=$CKPT_PATH,tp=$TP_SIZE,host=127.0.0.1,port=30000,timeout=600 \
+    --tasks $TASK \
+    --batch_size 1 \
+    --log_samples \
+    --log_samples_suffix $TASK_SUFFIX \
+    --output_path ./logs/
 ```
 
 You may need to install some dependencies for the above command to work (if you encounter some errors).
@@ -48,7 +66,6 @@ You may need to install some dependencies for the above command to work (if you
 
 ```bash
 pip install httpx==0.23.3
 pip install protobuf==3.20
-pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
 ```
 
diff --git a/docs/run_examples.md b/docs/run_examples.md
index 99c2e716..15c5715b 100644
--- a/docs/run_examples.md
+++ b/docs/run_examples.md
@@ -209,6 +209,7 @@ accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \
 ### SRT API MODEL
 
 To enable faster testing speed for larger llava model, you can use this srt api model to enable testing through sglang.
+You will need to first clone sglang from https://github.com/sgl-project/sglang. The current integration is tested on commit #1222 of sglang.
 Here are the scripts if you want to test the result in one script.
 
 ```bash
@@ -223,22 +224,16 @@ python3 -m pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
 
 CKPT_PATH=$1
-TOK_PATH=$2
-TASK=$3
-MODALITY=$4
+TASK=$2
+MODALITY=$3
+TP_SIZE=$4
 echo $TASK
 TASK_SUFFIX="${TASK//,/_}"
 echo $TASK_SUFFIX
 
-# Serve sglang in backend
-python3 -m sglang.launch_server --model-path ${CKPT_PATH} --tokenizer-path ${TOK_PATH} --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava &
-
-# Wait till the server is ready
-sleep 360;
-
 python3 -m lmms_eval \
     --model srt_api \
-    --model_args modality=$MODALITY,host=127.0.0.1,port=30000,timeout=600 \
+    --model_args modality=$MODALITY,model_version=$CKPT_PATH,tp=$TP_SIZE,host=127.0.0.1,port=30000,timeout=600 \
     --tasks $TASK \
     --batch_size 1 \
     --log_samples \
diff --git a/lmms_eval/models/srt_api.py b/lmms_eval/models/srt_api.py
index ca9d9e90..55a1ae65 100755
--- a/lmms_eval/models/srt_api.py
+++ b/lmms_eval/models/srt_api.py
@@ -1,10 +1,12 @@
 from accelerate import Accelerator, DistributedType
+import asyncio
 import base64
 from io import BytesIO
 from copy import deepcopy
 from decord import VideoReader, cpu
 import numpy as np
-from openai import OpenAI
+from multiprocessing import cpu_count
+from openai import AsyncOpenAI
 from PIL import Image
 import os
 import json
@@ -18,6 +20,12 @@
 from loguru import logger as eval_logger
 
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    popen_launch_server,
+)
+
 NUM_SECONDS_TO_SLEEP = 5
 
@@ -25,15 +33,19 @@ class SRT_API(lmms):
     def __init__(
         self,
-        api_key: str = "EMPTY",
-        model_version: str = "default",
+        api_key: str = "sk-123456",
+        model_version: str = "lmms-lab/llava-onevision-qwen2-72b-ov",
         modality: str = "video",
         host: str = "127.0.0.1",
         port: int = 30000,
         max_frames_num: int = 32,
         timeout: int = 60,
+        chat_template: str = "chatml-llava",
+        tp: int = 8,
+        chunked_prefill_size: int = 16384,
         continual_mode: bool = False,
         response_persistent_folder: str = None,
+        num_processes: int = cpu_count() // 2,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -63,7 +75,24 @@ def __init__(
             self.cache_mode = "start"
 
         accelerator = Accelerator()
-        self.client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
+        self.model = model_version
+        self.base_url = f"http://{host}:{port}"
+        self.api_key = api_key
+        self.chat_template = chat_template
+        other_args = []
+        other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
+        other_args.extend(["--tensor-parallel-size", str(tp)])
+        other_args.extend(["--chat-template", self.chat_template])
+        self.process = popen_launch_server(
+            self.model,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=self.api_key,
+            other_args=other_args,
+        )
+        self.base_url += "/v1"
+        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+        self.num_processes = num_processes
         # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
         if accelerator.num_processes > 1:
             assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
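For context, the server bootstrap that `__init__` now performs can be exercised on its own. The sketch below is assembled only from calls that appear in this diff (`popen_launch_server`, `AsyncOpenAI`) with the new defaults filled in; the `models.list()` liveness check and the plain `Popen.terminate()` teardown are illustrative assumptions, standing in for whatever cleanup the imported `kill_child_process` handles inside the class.

```python
# Hypothetical standalone sketch of the self-hosting flow added to SRT_API.__init__:
# launch an sglang server, point an AsyncOpenAI client at it, then tear it down.
import asyncio

from openai import AsyncOpenAI
from sglang.test.test_utils import DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, popen_launch_server

base_url = "http://127.0.0.1:30000"
api_key = "sk-123456"  # new default api_key from the diff

# popen_launch_server starts `python -m sglang.launch_server` as a subprocess
# and waits until the server responds (or the launch timeout expires)
process = popen_launch_server(
    "lmms-lab/llava-onevision-qwen2-72b-ov",  # new default model_version
    base_url,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    api_key=api_key,
    other_args=["--chunked-prefill-size", "16384", "--tensor-parallel-size", "8", "--chat-template", "chatml-llava"],
)

client = AsyncOpenAI(api_key=api_key, base_url=base_url + "/v1")

async def ping() -> None:
    models = await client.models.list()  # cheap liveness check against /v1/models
    print([m.id for m in models.data])

try:
    asyncio.run(ping())
finally:
    process.terminate()  # assumption: plain Popen teardown; the class imports kill_child_process for this
```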
@@ -130,100 +159,81 @@ def flatten(self, input):
                 new_list.append(j)
         return new_list
 
+    async def generate(self, request):
+        contexts, gen_kwargs, doc_to_visual, doc_id, task, split = request.args
+        visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
+        visuals = self.flatten(visuals)
+        imgs = []  # multiple images or frames for video
+        for visual in visuals:
+            if self.modality == "image":
+                img = self.encode_image(visual)
+                imgs.append(img)
+            elif self.modality == "video":
+                try:
+                    frames = self.encode_video(visual, self.max_frames_num)
+                    imgs.extend(frames)
+                except Exception as e:
+                    eval_logger.error(f"Exception : {e} \n When loading video {visual}")
+                    imgs = None
+                    break
+
+        # Handle video decode errors: if decord cannot load the video, skip this doc
+        if imgs is None:
+            return ""
+
+        messages = []
+
+        # put the images in the first place
+        content = []
+        for img in imgs:
+            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
+
+        content.append({"type": "text", "text": contexts})
+        messages.append({"role": "user", "content": content})
+
+        if "max_new_tokens" not in gen_kwargs:
+            gen_kwargs["max_new_tokens"] = 1024
+
+        if "temperature" not in gen_kwargs:
+            gen_kwargs["temperature"] = 0
+
+        for attempt in range(5):
+            try:
+                response = await self.client.chat.completions.create(model=self.model_version, messages=messages, temperature=gen_kwargs["temperature"], max_tokens=gen_kwargs["max_new_tokens"], timeout=self.timeout)
+                response_text = response.choices[0].message.content.strip()
+                break  # If successful, break out of the loop
+            except Exception as e:
+                eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.")
+                if attempt < 4:
+                    # time.sleep would block the event loop and stall every in-flight request
+                    await asyncio.sleep(NUM_SECONDS_TO_SLEEP)
+                else:  # If this was the last attempt, log and return an empty string
+                    eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.")
+                    response_text = ""
+
+        return response_text
+
     def generate_until(self, requests) -> List[str]:
         res = []
         pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
 
-        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
-            if self.continual_mode is True and self.cache_mode == "resume":
-                doc_uuid = f"{task}___{split}___{doc_id}"
-                if doc_uuid in self.response_cache:
-                    response_text = self.response_cache[doc_uuid]
-                    if response_text:
-                        res.append(response_text)
-                        pbar.update(1)
-                        continue
-
-            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
-            visuals = self.flatten(visuals)
-            imgs = []  # multiple images or frames for video
-            for visual in visuals:
-                if self.modality == "image":
-                    img = self.encode_image(visual)
-                    imgs.append(img)
-                elif self.modality == "video":
-                    try:
-                        frames = self.encode_video(visual, self.max_frames_num)
-                        imgs.extend(frames)
-                    except Exception as e:
-                        eval_logger.error(f"Exception : {e} \n When loading video {visual}")
-                        imgs = None
-                        break
-
-            # Handling video decode error
-            # If we can't even load using pyav, then we will skip
-            if imgs is None:
-                resps = ""
-                res.append(resps)
-                pbar.update(1)
-                continue
-
-            messages = []
-
-            # put the images in the first place
-            content = []
-            for img in imgs:
-                content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
-
-            content.append({"type": "text", "text": contexts})
-            messages.append({"role": "user", "content": content})
-            # if self.image_token not in contexts:  # single image format
-            #     content = []
-            #     for img in imgs:
-            #         content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
-
-            #     content.append({"type": "text", "text": contexts})
-            #     messages.append({"role": "user", "content": content})
-            # else:  # interleaved format
-            #     contexts = contexts.split(self.image_token)
-            #     for idx, img in enumerate(imgs):
-            #         content = [
-            #             {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}},
-            #             {"type": "text", "text": contexts[idx]},
-            #         ]
-            #         messages.append({"role": "user", "content": content})
-            #     messages.append({"role": "user", "content": [{"type": "text", "text": contexts[-1]}]})
-
-            if "max_new_tokens" not in gen_kwargs:
-                gen_kwargs["max_new_tokens"] = 1024
-
-            if "temperature" not in gen_kwargs:
-                gen_kwargs["temperature"] = 0
-
-            for attempt in range(5):
-                try:
-                    response = self.client.chat.completions.create(model=self.model_version, messages=messages, temperature=gen_kwargs["temperature"], max_tokens=gen_kwargs["max_new_tokens"], timeout=self.timeout)
-                    response_text = response.choices[0].message.content.strip()
-                    break  # If successful, break out of the loop
-                except Exception as e:
-                    eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.")
-                    if attempt < 4:
-                        time.sleep(NUM_SECONDS_TO_SLEEP)
-                    else:  # If this was the last attempt, log and return empty string
-                        eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.")
-                        response_text = ""
-
-            res.append(response_text)
-            pbar.update(1)
-
-            if self.continual_mode is True:  # Cache the response
-                doc_uuid = f"{task}___{split}___{doc_id}"
-                self.response_cache[doc_uuid] = response_text
-                with open(self.response_persistent_file, "w") as f:
-                    json.dump(self.response_cache, f)
-
+        async def run(requests):
+            sem = asyncio.Semaphore(self.num_processes)
+
+            async def _process(request):
+                async with sem:
+                    result = await self.generate(request)
+                    pbar.update(1)
+                    return result
+
+            # asyncio.gather returns results in submission order, so responses stay aligned with requests
+            return await asyncio.gather(*(_process(request) for request in requests))
+
+        res.extend(asyncio.run(run(requests)))
+
         pbar.close()
 
         return res
 
     def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
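The concurrency pattern behind the new `generate_until` is worth isolating: a semaphore caps the number of in-flight requests while `asyncio.gather` keeps results in submission order. Below is a self-contained sketch of that pattern; `fake_request` and the limits are illustrative stand-ins, not part of the diff.

```python
import asyncio
import random

async def fake_request(i: int) -> str:
    # Stand-in for an awaited chat.completions.create call
    await asyncio.sleep(random.random() * 0.1)
    return f"response-{i}"

async def run(n_requests: int, limit: int) -> list:
    sem = asyncio.Semaphore(limit)  # at most `limit` requests in flight at once

    async def _process(i: int) -> str:
        async with sem:
            return await fake_request(i)

    # gather returns results in submission order, regardless of completion order,
    # which keeps each response aligned with its originating request
    return await asyncio.gather(*(_process(i) for i in range(n_requests)))

if __name__ == "__main__":
    print(asyncio.run(run(20, limit=4)))
```

If completion-order processing is preferred instead (what `asyncio.as_completed` yields), each result must carry its request index so the caller can re-sort before returning.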