Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 #6631

Merged
merged 6 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def get_console_scripts() -> List[str]:
"msgpack",
"referencing",
"jsonschema_specifications",
"librosa",
],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
Expand Down
10 changes: 10 additions & 0 deletions src/llamafactory/chat/hf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def _process_args(
for key, value in mm_inputs.items():
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
value = torch.stack([torch.stack(per_value) for per_value in value])
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)

Expand All @@ -176,6 +178,11 @@ def _process_args(

gen_kwargs[key] = value.to(model.device)

if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
gen_kwargs["input_ids"] = inputs
del gen_kwargs["image_sizes"]
gen_kwargs["tokenizer"] = tokenizer

return gen_kwargs, prompt_length

@staticmethod
Expand Down Expand Up @@ -207,6 +214,9 @@ def _chat(
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
if isinstance(generate_output, tuple):
generate_output = generate_output[1][0] # for minicpm_o

response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(
response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
Expand Down
7 changes: 7 additions & 0 deletions src/llamafactory/model/model_utils/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ def patch_target_modules(

_register_composite_model(
model_type="minicpmv",
vision_model_keys=["vpm"],
language_model_keys=["llm"],
)


_register_composite_model(
model_type="minicpmo",
vision_model_keys=["vpm", "apm", "resampler", "tts"],
language_model_keys=["llm"],
)
Expand Down
8 changes: 7 additions & 1 deletion src/llamafactory/model/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def patch_config(

if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn

if getattr(config, "model_type", None) == "minicpmo":
setattr(config, "init_audio", False)
setattr(config, "init_tts", False)

if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
Expand Down Expand Up @@ -145,7 +149,9 @@ def patch_model(
):
gen_config.do_sample = True

if "GenerationMixin" not in str(model.generate.__func__):
if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
model.generate.__func__
):
model.generate = MethodType(PreTrainedModel.generate, model)

if add_valuehead:
Expand Down
Loading