Add LoHa Pass
xiaoyu-work committed Jan 30, 2025
1 parent f4c4f38 commit 1ced888
Showing 7 changed files with 216 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/options.md
@@ -410,6 +410,7 @@ Please also find the detailed options from following table for each pass:
| [SplitModel](../../reference/pass.rst#_split_model) | Split an ONNX model into multiple smaller sub-models based on predefined assignments. |
| [LoRA](../../reference/pass.rst#_lora) | Run LoRA fine-tuning on a Hugging Face PyTorch model. |
| [QLoRA](../../reference/pass.rst#_qlora) | Run QLoRA fine-tuning on a Hugging Face PyTorch model. |
| [LoHa](../../reference/pass.rst#_loha) | Run LoHa fine-tuning on a Hugging Face PyTorch model. |
| [LoftQ](../../reference/pass.rst#_loftq) | Run LoftQ fine-tuning on a Hugging Face PyTorch model. |
| [QuantizationAwareTraining](../../reference/pass.rst#_onnx_quantization_aware_training) | Run quantization aware training on PyTorch model. |
| [OpenVINOConversion](../../reference/pass.rst#_openvino_conversion) | Converts PyTorch, ONNX or TensorFlow Model to OpenVino Model. |
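For orientation, here is a minimal sketch of how the newly documented LoHa pass could be wired into a workflow config, written as a Python dict so it can be handed to Olive's `olive_run` entry point. The model path, data-config name, and output directory are placeholders rather than anything from this commit; the pass options (`r`, `alpha`, `target_modules`) use the defaults defined in `olive/passes/pytorch/lora.py` below.

```python
# Sketch only: hypothetical workflow using the new LoHa pass.
from olive.workflows import run as olive_run  # assumed entry point, as used in Olive examples

workflow = {
    "input_model": {"type": "HfModel", "model_path": "my-org/my-model"},  # placeholder model
    "data_configs": [
        # named data configs would go here, as in examples/llama2 and examples/phi3
    ],
    "passes": {
        "f": {
            "type": "LoHa",                     # the pass added by this commit
            "train_data_config": "train_data",  # placeholder data config name
            "r": 64,                            # renamed from lora_r in this commit
            "alpha": 16,                        # renamed from lora_alpha in this commit
            "target_modules": "all-linear",     # LoHa default from the pass config
        }
    },
    "output_dir": "models/loha",
}

# olive_run(workflow)  # uncomment to execute; LoHa requires peft >= 0.12.0
```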
6 changes: 6 additions & 0 deletions docs/source/reference/pass.rst
@@ -197,6 +197,12 @@ QLoRA
-----
.. autoconfigclass:: olive.passes.QLoRA

.. _loha:

LoHa
-----
.. autoconfigclass:: olive.passes.LoHa

.. _loftq:

LoftQ
2 changes: 1 addition & 1 deletion examples/llama2/llama2_qlora.json
@@ -33,7 +33,7 @@
"max_steps": 150,
"logging_steps": 50.0
},
"lora_r": 64,
"r": 64,
"lora_alpha": 16,
"eval_data_config": "eval_data"
},
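The only change in this example is the `lora_r` → `r` rename, which existing user configs will need to mirror (the phi3 template below gets the same treatment). For reference, a sketch of the updated pass options as a Python dict, with values copied from the diff:

```python
# QLoRA pass options from examples/llama2/llama2_qlora.json after this commit.
qlora_options = {
    "training_args": {"max_steps": 150, "logging_steps": 50.0},
    "r": 64,            # previously "lora_r"
    "lora_alpha": 16,   # left unchanged in this file
    "eval_data_config": "eval_data",
}
```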
4 changes: 2 additions & 2 deletions examples/phi3/phi3_template.json
@@ -65,7 +65,7 @@
"type": "LoRA",
"train_data_config": "tiny_codes_train",
"eval_data_config": "tiny_codes_eval",
"lora_r": 64,
"r": 64,
"training_args": {
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
@@ -80,7 +80,7 @@
"type": "QLoRA",
"train_data_config": "tiny_codes_train",
"eval_data_config": "tiny_codes_eval",
"lora_r": 64,
"r": 64,
"training_args": {
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
2 changes: 1 addition & 1 deletion olive/cli/finetune.py
@@ -107,7 +107,7 @@ def _get_run_config(self, tempdir: str) -> Dict:
((*finetune_key, "type"), self.args.method),
((*finetune_key, "torch_dtype"), self.args.torch_dtype),
((*finetune_key, "training_args"), self.parse_training_args()),
((*finetune_key, "lora_r"), self.args.lora_r),
((*finetune_key, "r"), self.args.lora_r),
((*finetune_key, "lora_alpha"), self.args.lora_alpha),
("output_dir", self.args.output_path),
("log_severity_level", self.args.log_level),
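The `--lora_r` CLI flag itself is unchanged; only the key it writes into the generated run config is renamed. A small sketch of the resulting overrides, where `finetune_key` and the values are illustrative placeholders:

```python
# Sketch only: mirrors the mapping in olive/cli/finetune.py above.
finetune_key = ("passes", "f")  # hypothetical path to the fine-tuning pass entry
overrides = [
    ((*finetune_key, "type"), "LoRA"),
    ((*finetune_key, "r"), 64),           # value of --lora_r, now stored under "r"
    ((*finetune_key, "lora_alpha"), 16),  # value of --lora_alpha, key untouched by this diff
]
```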
221 changes: 192 additions & 29 deletions olive/passes/pytorch/lora.py
@@ -133,20 +133,18 @@ def create_training_args(self) -> transformers.TrainingArguments:


class LoRABase(Pass):
"""Base class for LoRA and QLoRA fine-tuning passes."""
"""Base class for LoRA fine-tuning passes."""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
"lora_r": PassConfigParam(
"r": PassConfigParam(
type_=int,
default_value=64,
search_defaults=Categorical([16, 32, 64]),
description="Lora R dimension.",
),
"lora_alpha": PassConfigParam(
type_=float, default_value=16, description="The alpha parameter for Lora scaling."
description="R dimension.",
),
"alpha": PassConfigParam(type_=float, default_value=16, description="The alpha parameter for scaling."),
"lora_dropout": PassConfigParam(
type_=float, default_value=0.05, description="The dropout probability for Lora layers."
),
@@ -286,9 +284,7 @@ def data_generator(data_config):
return train_dataset, eval_dataset

@staticmethod
def prepare_model_for_lora_finetuning(
model: "PreTrainedModel", use_gradient_checkpointing: bool
) -> "PreTrainedModel":
def prepare_model_for_finetuning(model: "PreTrainedModel", use_gradient_checkpointing: bool) -> "PreTrainedModel":
"""Prepare the model for fine-tuning.
Freeze base model's layers and prepare model for gradient checkpointing if necessary.
@@ -299,6 +295,13 @@ def prepare_model_for_lora_finetuning(
:param use_gradient_checkpointing: Whether to use gradient checkpointing.
:return: The prepared model.
"""
if use_gradient_checkpointing and not model.supports_gradient_checkpointing:
logger.warning(
"gradient_checkpointing is True, but model does not support gradient checkpointing! Setting"
" gradient_checkpoing to False"
)
use_gradient_checkpointing = False

for param in model.parameters():
# freeze base model's layers
param.requires_grad = False
@@ -393,8 +396,8 @@ def init_lora_adapters(

peft_task_type = get_peft_task_type_from_task(task, fail_on_not_found=True)
lora_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
r=config.r,
lora_alpha=config.alpha,
lora_dropout=config.lora_dropout,
target_modules=target_modules,
bias="none",
@@ -430,14 +433,7 @@ def enable_lora(
from peft import PeftModel

logger.debug("Enabling LoRA fine-tuning")
if config.training_args.gradient_checkpointing and not model.supports_gradient_checkpointing:
logger.warning(
"gradient_checkpointing is True, but model does not support gradient checkpointing! Setting"
" gradient_checkpoing to False"
)
config.training_args.gradient_checkpointing = False

model = self.prepare_model_for_lora_finetuning(model, config.training_args.gradient_checkpointing)
model = self.prepare_model_for_finetuning(model, config.training_args.gradient_checkpointing)

# set model_parallel and is_parallelizable to True
# we are using "auto" device_map, so model_parallel is True or doing DDP
@@ -587,6 +583,20 @@ def count_trainable_parameters(model) -> str:
f"|| trainable%: {100 * trainable_params / all_param:.2f}"
)

@staticmethod
def check_target_modules(model: HfModelHandler):
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

model_type = model.get_hf_model_type()
if model_type not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
if model_type in MODELS_TO_LORA_TARGET_MODULES_MAPPING:
return MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type]
else:
raise ValueError(
f"Model type {model_type} is not recognized by peft or olive. Please provide 'target_modules'."
)
return None
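The new `check_target_modules` helper factors out the lookup that `LoRA._run_for_config` previously did inline and that the new LoHa pass reuses: peft's built-in mapping takes precedence, then Olive's `MODELS_TO_LORA_TARGET_MODULES_MAPPING`, and otherwise the caller must supply `target_modules` explicitly. A standalone sketch of that resolution order, with an illustrative fallback table standing in for Olive's real one:

```python
# Sketch only: replicates the lookup order of check_target_modules above.
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as PEFT_MAPPING

OLIVE_MAPPING = {"phi3": ["qkv_proj"]}  # illustrative entry, not the real olive table

def resolve_target_modules(model_type: str, explicit=None):
    if explicit:                     # user-supplied target_modules always wins
        return explicit
    if model_type in PEFT_MAPPING:   # peft resolves these itself, so the pass returns None
        return None
    if model_type in OLIVE_MAPPING:  # olive's own fallback table
        return OLIVE_MAPPING[model_type]
    raise ValueError(f"Model type {model_type} is not recognized by peft or olive.")
```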


class LoRA(LoRABase):
"""Run LoRA fine-tuning on a Hugging Face PyTorch model."""
@@ -600,8 +610,6 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
return config

def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

# convert config to pass config class
# this will validate the config and convert to the correct types
config = self._config_class(**config)
@@ -613,14 +621,7 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_
config.training_args = config.training_args or HFTrainingArguments()

# check if peft or olive has target modules for the model
model_type = model.get_hf_model_type()
if not config.target_modules and model_type not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
if model_type in MODELS_TO_LORA_TARGET_MODULES_MAPPING:
config.target_modules = MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type]
else:
raise ValueError(
f"Model type {model_type} is not recognized by peft or olive. Please provide 'target_modules'."
)
config.target_modules = config.target_modules or self.check_target_modules(model)

# get new model
pytorch_model = self.load_base_pytorch_model(model, config)
@@ -770,6 +771,168 @@ def get_quant_model(
return deepcopy(model), pytorch_model, bnb_quant_config, quantized_modules


class LoHa(LoRABase):
"""Run LoHa fine-tuning on a Hugging Face PyTorch model."""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
config = {
"rank_dropout": PassConfigParam(
type_=float,
default_value=0.05,
description="The dropout probability for rank dimension during training.",
),
"module_dropout": PassConfigParam(
type_=float,
default_value=0.05,
description="The dropout probability for disabling LoHa modules during training.",
),
"use_effective_conv2d": PassConfigParam(
type_=bool,
default_value=True,
description="Use parameter effective decomposition for Conv2d with ksize > 1.",
),
"target_modules": PassConfigParam(
type_=Optional[Union[List[str], str]],
default_value="all-linear",
description="The names of the modules to apply the adapter to.",
),
"exclude_modules": PassConfigParam(
type_=Optional[Union[List[str], str]], default_value=None, description="Modules to exclude from LoHa."
),
"init_weights": PassConfigParam(
type_=bool, default_value=True, description="Whether to perform initialization of adapter weights."
),
"layers_to_transform": PassConfigParam(
type_=List[int], default_value=None, description="The layer indices to transform."
),
"layers_pattern": PassConfigParam(
type_=List[str],
default_value=None,
description="The layer pattern name, used only if layers_to_transform is different from None.",
),
"rank_pattern": PassConfigParam(
type_=Dict,
default_value={},
description="The mapping from layer names or regexp expression "
"to ranks which are different from the default rank specified by r.",
),
"alpha_pattern": PassConfigParam(
type_=Dict,
default_value={},
description="The mapping from layer names or regexp expression "
"to alphas which are different from the default alpha specified by alpha.",
),
}
config.update(super()._default_config(accelerator_spec))
return config

def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
# convert config to pass config class
# this will validate the config and convert to the correct types
config = self._config_class(**config)

# check dependencies
self.check_dependencies(config)

# use default training args if not provided
config.training_args = config.training_args or HFTrainingArguments()

# check if peft or olive has target modules for the model
config.target_modules = config.target_modules or self.check_target_modules(model)

# get new model
pytorch_model = self.load_base_pytorch_model(model, config)

# add loha modules
pytorch_model = self.enable_loha(pytorch_model, config)

# train and return new model
return self.train_and_save_new_model(
pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path
)

def enable_loha(
self,
model: "PreTrainedModel",
config: ConfigBase,
) -> "PeftModel":
"""Enable LoHa fine-tuning on a Hugging Face PyTorch model.
Add padding token to tokenizer and resize model embedding layer if needed.
Prepare model for fine-tuning by freezing master weights and enabling gradient checkpointing if needed.
Load or initialize LoHa adapters.
:param model: The Hugging Face PyTorch model to enable LoHa fine-tuning on.
:param config: The config for the pass run.
:return: The LoHa model.
"""
logger.debug("Enabling LoHa fine-tuning")
model = self.prepare_model_for_finetuning(model, config.training_args.gradient_checkpointing)

# set model_parallel and is_parallelizable to True
# we are using "auto" device_map, so model_parallel is True or doing DDP
# don't want the trainer to do Data Parallel
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)

logger.debug(
"The number of trainable parameters in the original model: %s", self.count_trainable_parameters(model)
)
logger.debug("Initializing LoHa adapters from config")
loha_model = self.init_loha_adapters(model, config)
logger.debug(
"The number of trainable parameters in the LoHa model: %s", self.count_trainable_parameters(loha_model)
)
# no need to cast loha modules to model's dtype, we dont do peft.prepare_model_for_kbit_training so the modules
# are already in the same dtype as the model
# casting to dtype is risky since for awq quant linear, it also casts the scales to dtype and but the qlinear
# expects scales to be in half
return loha_model

def init_loha_adapters(
self,
model: "PreTrainedModel",
config: ConfigBase,
) -> "PeftModel":
"""Initialize LoHa adapters.
:param model: The Hugging Face PyTorch model to add LoHa adapters to.
:param config: The config for the pass run.
:return: The LoHa model.
"""
from peft import LoHaConfig, LoHaModel

loha_config = LoHaConfig(
r=config.r,
alpha=config.alpha,
rank_dropout=config.rank_dropout,
module_dropout=config.module_dropout,
use_effective_conv2d=config.use_effective_conv2d,
target_modules=config.target_modules,
exclude_modules=config.exclude_modules,
init_weights=config.init_weights,
layers_to_transform=config.layers_to_transform,
layers_pattern=config.layers_pattern,
rank_pattern=config.rank_pattern,
alpha_pattern=config.alpha_pattern,
modules_to_save=config.modules_to_save,
)

return LoHaModel(model, loha_config, "default")

@classmethod
def check_dependencies(cls, config: ConfigBase):
"""Check dependencies for the pass."""
super().check_dependencies(config)

from peft import __version__ as peft_version

# LoHa is only supported after peft 0.12.0
if version.parse(peft_version) < version.parse("0.12.0"):
raise ImportError(f"Please install peft >= 0.12.0 to use {cls.__name__} pass.")
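Taken together, here is a hedged sketch of a LoHa pass entry that exercises the LoHa-specific options defined in `_default_config` above; the module names and pattern values are illustrative and not taken from this commit:

```python
# Sketch only: option names come from LoHa._default_config above; values are placeholders.
loha_pass = {
    "type": "LoHa",
    "train_data_config": "tiny_codes_train",  # data config name borrowed from the phi3 example
    "r": 16,
    "alpha": 16,
    "rank_dropout": 0.05,
    "module_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj"],   # overrides the "all-linear" default
    "rank_pattern": {"k_proj": 8},            # per-module rank override
    "alpha_pattern": {"k_proj": 4},           # per-module alpha override
    "training_args": {"per_device_train_batch_size": 1, "max_steps": 150},
}
```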


class LoftQ(QLoRABase):
"""Run LoftQ fine-tuning on a Hugging Face PyTorch model."""

16 changes: 13 additions & 3 deletions test/unit_test/passes/pytorch/test_lora.py
@@ -4,6 +4,7 @@
# --------------------------------------------------------------------------
import platform
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
@@ -12,7 +13,7 @@
from olive.data.template import huggingface_data_config_template
from olive.model import HfModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.pytorch.lora import LoftQ, LoRA, QLoRA
from olive.passes.pytorch.lora import LoftQ, LoHa, LoRA, QLoRA

# pylint: disable=redefined-outer-name

@@ -42,9 +43,9 @@ def get_pass_config(model_name, task, **kwargs):
return {
"train_data_config": data_config,
# hidden sizes are 4 or 16
# will have invalid adapter weights since `in_features` and/or `out_features` say 64 (lora_r) even though
# will have invalid adapter weights since `in_features` and/or `out_features` say 64 (r) even though
# the actual weights are 4 or 16. Bug not from our code, it's from peft
"lora_r": 4,
"r": 4,
"training_args": {
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
@@ -112,3 +113,12 @@ def test_loftq(tmp_path):
# assert
assert Path(out.get_resource("model_path")).exists()
assert Path(out.get_resource("adapter_path")).exists()


@patch("transformers.Trainer.train")
def test_loha(mock_train, tmp_path):
# execute
out = run_finetuning(LoHa, tmp_path, torch_dtype="float32")

# assert
assert Path(out.get_resource("adapter_path")).exists()
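Because `transformers.Trainer.train` is patched out, the new test exercises only adapter setup and saving and should run quickly on CPU. A sketch of invoking it in isolation from the repository root (peft >= 0.12.0 assumed):

```python
# Sketch only: run the new LoHa unit test by itself.
import pytest

pytest.main(["test/unit_test/passes/pytorch/test_lora.py", "-k", "test_loha", "-v"])
```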
