Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/hiyouga/LLaMA-Factory into minicpmv
Browse files Browse the repository at this point in the history
  • Loading branch information
BUAADreamer committed Jan 14, 2025
2 parents 30ff848 + 1c7663d commit 875d7ef
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 18 deletions.
23 changes: 7 additions & 16 deletions scripts/vllm_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.misc import get_device_count
from llamafactory.extras.packages import is_pillow_available, is_vllm_available
from llamafactory.extras.misc import check_version, get_device_count
from llamafactory.extras.packages import is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer


if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject


if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
Expand All @@ -51,11 +46,13 @@ def vllm_infer(
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
pipeline_parallel_size: int = 1,
image_resolution: int = 512 * 512,
):
r"""
Performs batch generation using vLLM engine, which supports tensor parallelism.
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
check_version("vllm>=0.4.3,<=0.6.5")
if pipeline_parallel_size > get_device_count():
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")

Expand Down Expand Up @@ -88,15 +85,9 @@ def vllm_infer(
inputs, prompts, labels = [], [], []
for sample in dataset_module["train_dataset"]:
if sample["images"]:
multi_modal_data = {"image": []}
for image in sample["images"]:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

if isinstance(image, str):
image = Image.open(image).convert("RGB")

multi_modal_data["image"].append(image)
multi_modal_data = {
"image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
}
else:
multi_modal_data = None

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_console_scripts() -> List[str]:
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<0.6.7"],
"vllm": ["vllm>=0.4.3,<=0.6.5"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/hparams/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)

if model_args.infer_backend == "vllm":
check_version("vllm>=0.4.3,<0.6.7")
check_version("vllm>=0.4.3,<=0.6.5")
check_version("vllm", mandatory=True)

if finetuning_args.use_galore:
Expand Down

0 comments on commit 875d7ef

Please sign in to comment.