diff --git a/setup.py b/setup.py
index 1fbe95a944..dd0cf74aa8 100644
--- a/setup.py
+++ b/setup.py
@@ -68,6 +68,7 @@ def get_console_scripts() -> List[str]:
         "msgpack",
         "referencing",
         "jsonschema_specifications",
+        "librosa",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 879c407a55..6f61f2ab86 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -168,6 +168,8 @@ def _process_args(
         for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
+            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+                value = torch.stack([torch.stack(per_value) for per_value in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
@@ -176,6 +178,11 @@ def _process_args(
 
             gen_kwargs[key] = value.to(model.device)
 
+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            del gen_kwargs["image_sizes"]
+            gen_kwargs["tokenizer"] = tokenizer
+
         return gen_kwargs, prompt_length
 
     @staticmethod
@@ -207,6 +214,9 @@ def _chat(
             input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # for minicpm_o
+
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
             response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 80c1843619..5bd188cbdf 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -252,6 +252,13 @@ def patch_target_modules(
 
 _register_composite_model(
     model_type="minicpmv",
+    vision_model_keys=["vpm"],
+    language_model_keys=["llm"],
+)
+
+
+_register_composite_model(
+    model_type="minicpmo",
     vision_model_keys=["vpm", "apm", "resampler", "tts"],
     language_model_keys=["llm"],
 )
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 2ce84e8602..2a7e5ddf8b 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -109,6 +109,10 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
+
+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
 
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
 
@@ -145,7 +149,9 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
     if add_valuehead: