Skip to content

Commit

Permalink
Require using lora adapter when performing quantized fine-tuning (#3492)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Aug 2, 2023
1 parent d146799 commit 734156b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
10 changes: 10 additions & 0 deletions ludwig/config_validation/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,13 @@ def check_llm_quantization_backend_incompatibility(config: "ModelConfig") -> Non
backend_type = config.backend.get("type", "local")
if config.model_type == MODEL_LLM and config.quantization and backend_type != "local":
raise ConfigValidationError(f"LLM with quantization requires the 'local' backend, found: '{backend_type}'")


@register_config_check
def check_qlora_requirements(config: "ModelConfig") -> None:  # noqa: F821
    """Checks that all the necessary settings are in place for QLoRA.

    Quantized fine-tuning (QLoRA) only trains low-rank adapter weights on top of
    the frozen quantized base model, so an LLM config that both enables
    quantization and trains must also declare the 'lora' adapter.

    Raises:
        ConfigValidationError: If quantization is enabled for a trainable LLM
            and no 'lora' adapter is configured.
    """
    # The check only applies to LLMs that will actually be trained; a trainer
    # type of "none" means zero-shot / inference-only usage, which is fine.
    if config.model_type != MODEL_LLM or config.trainer.type == "none":
        return

    if config.quantization and (not config.adapter or config.adapter.type != "lora"):
        # Fix: error message previously read "Fine-tuning and LLM ..." — grammar typo.
        raise ConfigValidationError("Fine-tuning an LLM with quantization requires using the 'lora' adapter")
33 changes: 33 additions & 0 deletions tests/ludwig/config_validation/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,36 @@ def test_check_llm_quantization_backend_incompatibility():
del config["quantization"]
config["backend"] = {"type": "ray"}
ModelConfig.from_dict(config)


def test_check_qlora():
    """Quantized LLM fine-tuning must be rejected unless the 'lora' adapter is set."""
    config = yaml.safe_load(
        """
        model_type: llm
        base_model: facebook/opt-350m
        quantization:
            bits: 4
        input_features:
            - name: sample
              type: text
        output_features:
            - name: label
              type: text
        trainer:
            type: finetune
        """
    )

    # Quantized fine-tuning with no adapter at all is invalid.
    with pytest.raises(ConfigValidationError):
        ModelConfig.from_dict(config)

    # A non-LoRA adapter is equally invalid for QLoRA.
    config["adapter"] = {"type": "adaption_prompt"}
    with pytest.raises(ConfigValidationError):
        ModelConfig.from_dict(config)

    # The 'lora' adapter satisfies the QLoRA requirement, so validation passes.
    config["adapter"] = {"type": "lora"}
    ModelConfig.from_dict(config)

0 comments on commit 734156b

Please sign in to comment.