Add default LoRA target modules for Gemma (#3936)
arnavgarg1 authored Feb 22, 2024
1 parent cde34ef commit a1af35b
Showing 1 changed file with 21 additions and 0 deletions.
ludwig/schema/model_types/utils.py
@@ -323,6 +323,11 @@ def set_llm_parameters(config: "ModelConfig") -> None:
# PEFT PR: https://github.com/huggingface/peft/pull/1375
_set_phi2_target_modules(config)

# HACK(Arnav): Set Gemma target modules when using LoRA
# GitHub issue: https://github.com/ludwig-ai/ludwig/issues/3937
# PEFT PR: https://github.com/huggingface/peft/pull/1499
_set_gemma_target_modules(config)


def _set_llm_tokenizers(config: "ModelConfig") -> None:
"""Sets the tokenizers for the LLM model to the pretrained model name or path. This ensures that they use the
@@ -451,6 +456,22 @@ def _set_phi2_target_modules(config: "ModelConfig") -> None:
config.adapter.target_modules = target_modules


def _set_gemma_target_modules(config: "ModelConfig") -> None:
"""If the base model is Gemma, LoRA is enabled and the target modules are not set, set the target modules to
maximize performance."""
if config.base_model not in {"google/gemma-2b", "google/gemma-2b-it", "google/gemma-7b", "google/gemma-7b-it"}:
return

if not config.adapter:
return

if config.adapter.type != "lora" or config.adapter.target_modules:
return

target_modules = ["q_proj", "v_proj"]
config.adapter.target_modules = target_modules


@DeveloperAPI
def contains_grid_search_parameters(hyperopt_config: HyperoptConfigDict) -> bool:
"""Returns True if any hyperopt parameter in the config is using the grid_search space."""
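Usage note (not part of the commit): the sketch below shows how this default takes effect once a config is validated, assuming Ludwig's ModelConfig.from_dict entry point runs set_llm_parameters during schema validation; the feature names are hypothetical.

# A minimal sketch, assuming the default is applied at config-validation time.
from ludwig.schema.model_types.base import ModelConfig

config = ModelConfig.from_dict(
    {
        "model_type": "llm",
        "base_model": "google/gemma-2b",
        "input_features": [{"name": "prompt", "type": "text"}],
        "output_features": [{"name": "answer", "type": "text"}],
        # target_modules intentionally omitted, so the new default applies
        "adapter": {"type": "lora"},
    }
)

print(config.adapter.target_modules)  # expected: ["q_proj", "v_proj"]

Because _set_gemma_target_modules returns early when adapter.target_modules is already set, a user-supplied list is never overridden; the default only fills in when the field is left empty.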
