diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index 907fb831ad..796d5b98d7 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -19,17 +19,12 @@
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.extras.misc import get_device_count
-from llamafactory.extras.packages import is_pillow_available, is_vllm_available
+from llamafactory.extras.misc import check_version, get_device_count
+from llamafactory.extras.packages import is_vllm_available
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
 
-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
-
 if is_vllm_available():
     from vllm import LLM, SamplingParams
     from vllm.lora.request import LoRARequest
@@ -51,11 +46,13 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     pipeline_parallel_size: int = 1,
+    image_resolution: int = 512 * 512,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
+    check_version("vllm>=0.4.3,<=0.6.5")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
@@ -88,15 +85,9 @@ def vllm_infer(
     inputs, prompts, labels = [], [], []
     for sample in dataset_module["train_dataset"]:
         if sample["images"]:
-            multi_modal_data = {"image": []}
-            for image in sample["images"]:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+            }
         else:
             multi_modal_data = None
diff --git a/setup.py b/setup.py
index 75d2b7e3aa..dd0cf74aa8 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,7 @@ def get_console_scripts() -> List[str]:
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<0.6.7"],
+    "vllm": ["vllm>=0.4.3,<=0.6.5"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
@@ -68,6 +68,7 @@ def get_console_scripts() -> List[str]:
         "msgpack",
         "referencing",
         "jsonschema_specifications",
+        "librosa",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index e0efbb6cc8..6f61f2ab86 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -163,26 +163,25 @@ def _process_args(
             generation_config=GenerationConfig(**generating_args),
             logits_processor=get_logits_processor(),
         )
-
-        if getattr(model.config, "model_type") == "minicpmv":
-            gen_kwargs['input_ids'] = inputs
 
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
-            elif isinstance(value, list) and isinstance(value[0], list):  # for minicpmv inputs
-                value = torch.stack([torch.stack([item for item in per_value]) for per_value in value])
+            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+                value = torch.stack([torch.stack(per_value) for per_value in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
-            if torch.is_tensor(value) and torch.is_floating_point(value):  # cast data dtype for paligemma
+            if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)
 
-            if torch.is_tensor(value):
-                gen_kwargs[key] = value.to(model.device)
-            else:
-                gen_kwargs[key] = value
+            gen_kwargs[key] = value.to(model.device)
+
+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            del gen_kwargs["image_sizes"]
+            gen_kwargs["tokenizer"] = tokenizer
 
         return gen_kwargs, prompt_length
@@ -214,14 +213,10 @@ def _chat(
             videos,
             input_kwargs,
         )
-
-        if getattr(model.config, "model_type") == "minicpmv":
-            gen_kwargs['tokenizer'] = tokenizer
-            del gen_kwargs['image_sizes']
-            generate_output = model._generate(**gen_kwargs)
-        else:
-            generate_output = model.generate(**gen_kwargs)
-
+        generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # for minicpm_o
+
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
             response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index beaacee591..36de8b50b2 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -133,7 +133,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == "vllm":
-        check_version("vllm>=0.4.3,<0.6.7")
+        check_version("vllm>=0.4.3,<=0.6.5")
         check_version("vllm", mandatory=True)
 
     if finetuning_args.use_galore:
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 80c1843619..5bd188cbdf 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -252,6 +252,13 @@ def patch_target_modules(
 
 _register_composite_model(
     model_type="minicpmv",
+    vision_model_keys=["vpm"],
+    language_model_keys=["llm"],
+)
+
+
+_register_composite_model(
+    model_type="minicpmo",
     vision_model_keys=["vpm", "apm", "resampler", "tts"],
     language_model_keys=["llm"],
 )
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 2ce84e8602..2a7e5ddf8b 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -109,6 +109,10 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
+
+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
 
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
@@ -145,7 +149,9 @@
     ):
         gen_config.do_sample = True
 
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
     if add_valuehead:
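For context, a minimal standalone sketch of the generate-output handling that `_chat` now performs, assuming a MiniCPM-o-style `generate()` that returns a tuple whose second element wraps the decoded sequences (as the `# for minicpm_o` lines in the hf_engine.py hunk indicate). The helper name `extract_response_ids` and the stand-in tensors are hypothetical, not part of the patch:

```python
import torch


def extract_response_ids(generate_output, prompt_length: int) -> torch.Tensor:
    # MiniCPM-o-style models return a tuple from generate(); the second element
    # wraps the decoded sequences (mirrors the hf_engine.py change above).
    if isinstance(generate_output, tuple):
        generate_output = generate_output[1][0]
    return generate_output[:, prompt_length:]  # drop the prompt tokens


# Stand-in tensors: a plain output and a tuple-wrapped output decode identically.
plain = torch.arange(12).reshape(1, 12)
wrapped = (None, [plain])
assert torch.equal(extract_response_ids(plain, 4), extract_response_ids(wrapped, 4))
```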