
Commit

Merge branch 'main' of https://github.com/hiyouga/LLaMA-Factory into minicpmv
BUAADreamer committed Jan 14, 2025
2 parents 30ff848 + 1c7663d commit eaba4a7
Showing 6 changed files with 37 additions and 37 deletions.

scripts/vllm_infer.py (23 changes: 7 additions & 16 deletions)
@@ -19,17 +19,12 @@
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.extras.misc import get_device_count
-from llamafactory.extras.packages import is_pillow_available, is_vllm_available
+from llamafactory.extras.misc import check_version, get_device_count
+from llamafactory.extras.packages import is_vllm_available
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
 
-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
-
 if is_vllm_available():
     from vllm import LLM, SamplingParams
     from vllm.lora.request import LoRARequest
@@ -51,11 +46,13 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     pipeline_parallel_size: int = 1,
+    image_resolution: int = 512 * 512,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
+    check_version("vllm>=0.4.3,<=0.6.5")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
 
@@ -88,15 +85,9 @@
     inputs, prompts, labels = [], [], []
     for sample in dataset_module["train_dataset"]:
         if sample["images"]:
-            multi_modal_data = {"image": []}
-            for image in sample["images"]:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+            }
         else:
             multi_modal_data = None
 
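
The replaced PIL loop is now delegated to the template's multimodal plugin, which also receives the new image_resolution argument (default 512 * 512). A rough standalone sketch of what a helper like _regularize_images is expected to do; treat the name, signature, and scaling rule below as illustrative assumptions, not the actual mm_plugin implementation:

    import math
    from typing import List, Union

    from PIL import Image
    from PIL.Image import Image as ImageObject


    def regularize_images(images: List[Union[str, ImageObject]], image_resolution: int = 512 * 512) -> List[ImageObject]:
        """Load paths, force RGB, and downscale anything above the pixel budget (hypothetical stand-in)."""
        results = []
        for image in images:
            if isinstance(image, str):  # a path on disk
                image = Image.open(image)

            if not isinstance(image, ImageObject):
                raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

            if image.mode != "RGB":
                image = image.convert("RGB")

            if image.width * image.height > image_resolution:  # shrink while keeping aspect ratio
                factor = math.sqrt(image_resolution / (image.width * image.height))
                image = image.resize((int(image.width * factor), int(image.height * factor)))

            results.append(image)

        return results

If the script exposes its keyword arguments as command-line flags (the Usage line in its docstring suggests it does), callers could presumably pass an image_resolution flag to tighten or relax that budget.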

setup.py (3 changes: 2 additions & 1 deletion)
@@ -54,7 +54,7 @@ def get_console_scripts() -> List[str]:
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<0.6.7"],
+    "vllm": ["vllm>=0.4.3,<=0.6.5"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
@@ -68,6 +68,7 @@ def get_console_scripts() -> List[str]:
         "msgpack",
         "referencing",
         "jsonschema_specifications",
+        "librosa",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],

src/llamafactory/chat/hf_engine.py (31 changes: 13 additions & 18 deletions)
@@ -163,26 +163,25 @@ def _process_args(
             generation_config=GenerationConfig(**generating_args),
             logits_processor=get_logits_processor(),
         )
 
-        if getattr(model.config, "model_type") == "minicpmv":
-            gen_kwargs['input_ids'] = inputs
 
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
-            elif isinstance(value, list) and isinstance(value[0], list):  # for minicpmv inputs
-                value = torch.stack([torch.stack([item for item in per_value]) for per_value in value])
+            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+                value = torch.stack([torch.stack(per_value) for per_value in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
-            if torch.is_tensor(value) and torch.is_floating_point(value):  # cast data dtype for paligemma
+            if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)
 
-            if torch.is_tensor(value):
-                gen_kwargs[key] = value.to(model.device)
-            else:
-                gen_kwargs[key] = value
+            gen_kwargs[key] = value.to(model.device)
 
+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            del gen_kwargs["image_sizes"]
+            gen_kwargs["tokenizer"] = tokenizer
 
         return gen_kwargs, prompt_length

@@ -214,14 +213,10 @@ def _chat(
             videos,
             input_kwargs,
         )
-
-        if getattr(model.config, "model_type") == "minicpmv":
-            gen_kwargs['tokenizer'] = tokenizer
-            del gen_kwargs['image_sizes']
-            generate_output = model._generate(**gen_kwargs)
-        else:
-            generate_output = model.generate(**gen_kwargs)
-
+        generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # for minicpm_o
+
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
             response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
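
For context on the revised stacking branch in _process_args: the diff's comments label nested lists as minicpmv inputs, the new condition checks every element before stacking twice, and every branch then ends in a tensor, so the engine can move the value to the model device unconditionally. A minimal sketch of the shape transformation, with an illustrative variable and sizes not taken from the real processor:

    import torch

    # Pretend a multimodal value (key name illustrative) arrived as a nested list:
    # a batch of 2 samples, each holding 3 image slices of shape (3, 224, 224).
    value = [[torch.randn(3, 224, 224) for _ in range(3)] for _ in range(2)]

    if isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
        value = torch.stack([torch.stack(per_value) for per_value in value])

    print(value.shape)  # torch.Size([2, 3, 3, 224, 224])

    # Every branch now yields a tensor, so the engine can finish with:
    # gen_kwargs[key] = value.to(model.device)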

src/llamafactory/hparams/parser.py (2 changes: 1 addition & 1 deletion)
@@ -133,7 +133,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == "vllm":
-        check_version("vllm>=0.4.3,<0.6.7")
+        check_version("vllm>=0.4.3,<=0.6.5")
         check_version("vllm", mandatory=True)
 
     if finetuning_args.use_galore:

src/llamafactory/model/model_utils/visual.py (7 changes: 7 additions & 0 deletions)
@@ -252,6 +252,13 @@ def patch_target_modules(
 
 _register_composite_model(
     model_type="minicpmv",
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
+
+
+_register_composite_model(
+    model_type="minicpmo",
+    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+    language_model_keys=["llm"],
+)

src/llamafactory/model/patcher.py (8 changes: 7 additions & 1 deletion)
@@ -109,6 +109,10 @@ def patch_config(
 
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
+
+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
 
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
@@ -145,7 +149,9 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
     if add_valuehead:
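
The widened guard in patch_model leaves the custom generate of minicpmv/minicpmo models untouched, while other models whose generate does not come from GenerationMixin still get rebound to the stock implementation. A toy illustration of the MethodType rebinding with stand-in classes, not the transformers ones:

    from types import MethodType


    class Base:
        def generate(self):
            return "GenerationMixin-style generate"


    class RemoteCodeModel(Base):
        def generate(self):  # remote-code models often ship their own generate
            return "custom generate"


    model = RemoteCodeModel()
    print(model.generate())  # custom generate

    # Same shape of check as the patch: only rebind when the bound method
    # does not already come from the standard implementation.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(Base.generate, model)

    print(model.generate())  # GenerationMixin-style generate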
