diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py
index ea7871c8194..ea7f2190549 100644
--- a/ludwig/schema/model_types/utils.py
+++ b/ludwig/schema/model_types/utils.py
@@ -323,6 +323,11 @@ def set_llm_parameters(config: "ModelConfig") -> None:
     # PEFT PR: https://github.com/huggingface/peft/pull/1375
     _set_phi2_target_modules(config)
 
+    # HACK(Arnav): Set Gemma target modules when using LoRA
+    # GitHub issue: https://github.com/ludwig-ai/ludwig/issues/3937
+    # PEFT PR: https://github.com/huggingface/peft/pull/1499
+    _set_gemma_target_modules(config)
+
 
 def _set_llm_tokenizers(config: "ModelConfig") -> None:
     """Sets the tokenizers for the LLM model to the pretrained model name or path. This ensures that they use the
@@ -451,6 +456,22 @@ def _set_phi2_target_modules(config: "ModelConfig") -> None:
     config.adapter.target_modules = target_modules
 
 
+def _set_gemma_target_modules(config: "ModelConfig") -> None:
+    """If the base model is Gemma, LoRA is enabled and the target modules are not set, set the target modules to
+    maximize performance."""
+    if config.base_model not in {"google/gemma-2b", "google/gemma-2b-it", "google/gemma-7b", "google/gemma-7b-it"}:
+        return
+
+    if not config.adapter:
+        return
+
+    if config.adapter.type != "lora" or config.adapter.target_modules:
+        return
+
+    target_modules = ["q_proj", "v_proj"]
+    config.adapter.target_modules = target_modules
+
+
 @DeveloperAPI
 def contains_grid_search_parameters(hyperopt_config: HyperoptConfigDict) -> bool:
     """Returns True if any hyperopt parameter in the config is using the grid_search space."""