[FEATURE] Support Merging LoRA Weights Into Base Model (Issue-3603) (#…
alexsherstinsky authored Sep 27, 2023
1 parent a3b7709 commit 3dc8f4b
Showing 5 changed files with 248 additions and 15 deletions.
26 changes: 24 additions & 2 deletions ludwig/api.py
@@ -429,8 +429,9 @@ def train(
will contain the training statistics, TensorBoard logs, the saved
model and the training progress files.
:param random_seed: (int, default: `42`) a random seed that will be
used anywhere there is a call to a random number generator: data
splitting, parameter initialization and training set shuffling
used anywhere there is a call to a random number generator: data
splitting, parameter initialization and training set shuffling
:param kwargs: (dict, default: {}) a dictionary of optional parameters.
# Return
@@ -645,6 +646,9 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
)
(self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats

# For an LLM model trained with a LoRA adapter, handle merge and unload postprocessing directives.
self._merge_and_unload()

# Calibrates output feature probabilities on validation set if calibration is enabled.
# Must be done after training, and before final model parameters are saved.
if self.backend.is_coordinator():
@@ -807,6 +811,21 @@ def train_online(

self.model = self._online_trainer.train_online(training_dataset)

def _merge_and_unload(self) -> None:
"""For an LLM model trained with a LoRA adapter, handle merge and unload postprocessing directives.
First, check that the model is of the "llm" type. Then check if the "adapter" configuration section contains
the "postprocessor" subsection and apply the "merge_adapter_into_base_model" and "progressbar" directives.
"""
if (
self.config_obj.model_type == "llm"
and self.config_obj.adapter is not None
and self.config_obj.adapter.postprocessor is not None
and self.config_obj.adapter.postprocessor.merge_adapter_into_base_model
and hasattr(self.model, "merge_and_unload")
):
self.model.merge_and_unload(progressbar=self.config_obj.adapter.postprocessor.progressbar)

def _tune_batch_size(self, trainer, dataset, random_seed: int = default_random_seed):
"""Sets AUTO batch-size-related parameters based on the trainer, backend type, and number of workers.
@@ -1643,6 +1662,9 @@ def load(
# load model weights
ludwig_model.load_weights(model_dir)

# The LoRA layers appear to be loaded again (possibly due to a bug); hence, we merge and unload again.
ludwig_model._merge_and_unload()

# load train set metadata
ludwig_model.training_set_metadata = backend.broadcast_return(
lambda: load_metadata(os.path.join(model_dir, TRAIN_SET_METADATA_FILE_NAME))
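To make the new postprocessor path concrete, here is a minimal sketch of a config that would exercise `_merge_and_unload()` automatically at the end of `train()`. The base model, feature names, dataset path, and trainer settings are illustrative placeholders, not values taken from this commit.

```python
from ludwig.api import LudwigModel

# Illustrative LLM fine-tuning config with a LoRA adapter whose weights are merged into
# the base model after training (merge_adapter_into_base_model defaults to False).
config = {
    "model_type": "llm",
    "base_model": "facebook/opt-350m",  # placeholder base model
    "input_features": [{"name": "instruction", "type": "text"}],
    "output_features": [{"name": "output", "type": "text"}],
    "adapter": {
        "type": "lora",
        "postprocessor": {
            "merge_adapter_into_base_model": True,
            "progressbar": True,
        },
    },
    "trainer": {"type": "finetune", "epochs": 1, "batch_size": 1},
}

model = LudwigModel(config)
model.train(dataset="train.csv")  # placeholder dataset; _merge_and_unload() runs after training
```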
2 changes: 2 additions & 0 deletions ludwig/constants.py
@@ -283,6 +283,8 @@
PROMPT = "prompt"
ADAPTER = "adapter"
PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights"
MERGE_ADAPTER_INTO_BASE_MODEL = "merge_adapter_into_base_model"
PROGRESSBAR = "progressbar"

# CrossEntropyLoss for LLMs
IGNORE_INDEX_TOKEN_ID = -100
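As a small illustration (mirroring how the tests below build their adapter arguments), the new constants combine with the existing `POSTPROCESSOR` key like this:

```python
from ludwig.constants import MERGE_ADAPTER_INTO_BASE_MODEL, POSTPROCESSOR, PROGRESSBAR

# "postprocessor" fragment of an adapter config that opts into merging after training.
adapter_args = {
    POSTPROCESSOR: {
        MERGE_ADAPTER_INTO_BASE_MODEL: True,
        PROGRESSBAR: True,
    },
}
```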
15 changes: 15 additions & 0 deletions ludwig/models/llm.py
@@ -451,6 +451,21 @@ def generate(

return outputs

def merge_and_unload(self, progressbar: bool = False) -> None:
"""This method merges the LoRa layers into the base model. This is needed if someone wants to use the base
model as a standalone model. The implementation calls merge_and_unload() of the underlying LoraModel class
(in peft).
Args:
progressbar (bool): whether to show a progressbar indicating the unload and merge process
"""
from peft import LoraModel

if isinstance(self.model.base_model, LoraModel):
self.model.base_model.merge_and_unload(progressbar=progressbar)
else:
raise ValueError("This operation requires an LLM model trained with a LoRA adapter.")

def _unpack_inputs(
self,
inputs: Union[
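For callers that skip the config-driven path, the sketch below shows one way to invoke the new method directly on a trained `LudwigModel`; the helper name is hypothetical, and the error handling simply reflects the `ValueError` raised above.

```python
from ludwig.api import LudwigModel


def merge_lora_into_base(ludwig_model: LudwigModel, progressbar: bool = True) -> None:
    """Hypothetical helper: fold LoRA weights into the base model of a trained Ludwig LLM."""
    try:
        # ludwig_model.model is the ludwig.models.llm.LLM wrapper defined in this file;
        # merge_and_unload() delegates to peft's LoraModel.merge_and_unload().
        ludwig_model.model.merge_and_unload(progressbar=progressbar)
    except ValueError:
        # The model was not trained with a LoRA adapter; nothing to merge.
        pass
```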
28 changes: 28 additions & 0 deletions ludwig/schema/llms/peft.py
@@ -25,6 +25,32 @@ def wrap(config: BaseAdapterConfig):
return wrap


@DeveloperAPI
@ludwig_dataclass
class LoraPostprocessorConfig(schema_utils.BaseMarshmallowConfig):
"""This Dataclass is a schema for the nested postprocessing config under adapter of type "lora"."""

merge_adapter_into_base_model: bool = schema_utils.Boolean(
default=False,
description="""Instructs whether or not the fine-tuned LoRA weights are to be merged into the base LLM model so
that the complete fine-tuned model is available to be used and/or persisted, and then reused upon loading as a single
model (rather than having to load base and fine-tuned models separately).""",
)
progressbar: bool = schema_utils.Boolean(
default=False,
description="Instructs whether or not to show a progress bar indicating the unload and merge process.",
)


@DeveloperAPI
class LoraPostprocessorConfigField(schema_utils.DictMarshmallowField):
def __init__(self):
super().__init__(LoraPostprocessorConfig)

def _jsonschema_type_mapping(self):
return schema_utils.unload_jsonschema_from_marshmallow_class(LoraPostprocessorConfig, title="LoraPostprocessor")


@DeveloperAPI
@ludwig_dataclass
class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
@@ -34,6 +60,8 @@ class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
default=None, description="Path to pretrained weights.", allow_none=True
)

postprocessor: LoraPostprocessorConfig = LoraPostprocessorConfigField().get_default_field()

@abstractmethod
def to_config(self, **kwargs) -> "PeftConfig":
pass
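A quick sketch of the resulting defaults, assuming the dataclass can be instantiated directly like other Ludwig schema configs: merging is opt-in, and the progress bar is off unless requested.

```python
from ludwig.schema.llms.peft import LoraPostprocessorConfig

postprocessor = LoraPostprocessorConfig()
assert postprocessor.merge_adapter_into_base_model is False  # merging is opt-in
assert postprocessor.progressbar is False
```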
192 changes: 179 additions & 13 deletions tests/integration_tests/test_llm.py
@@ -1,5 +1,5 @@
import os
from typing import Dict, Tuple
from typing import Dict, Tuple, Union

import numpy as np
import pandas as pd
@@ -15,11 +15,14 @@
EPOCHS,
GENERATION,
INPUT_FEATURES,
MERGE_ADAPTER_INTO_BASE_MODEL,
MODEL_LLM,
MODEL_TYPE,
OUTPUT_FEATURES,
POSTPROCESSOR,
PREPROCESSING,
PRETRAINED_ADAPTER_WEIGHTS,
PROGRESSBAR,
PROMPT,
TRAINER,
TYPE,
@@ -352,6 +355,84 @@ def _prepare_finetuning_test(
return train_df, prediction_df, config


def _finetune_strategy_requires_cuda(finetune_strategy_name: str, quantization_args: Union[dict, None]) -> bool:
"""This method returns whether or not a given finetine_strategy requires CUDA.
For all finetune strategies, except "qlora", the decision is based just on the name of the finetine_strategy; in the
case of qlora, if the quantization dictionary is non-empty (i.e., contains quantization specifications), then the
original finetine_strategy name of "lora" is interpreted as "qlora" and used in the lookup, based on the list of
finetine strategies requiring CUDA.
"""
cuda_only_finetune_strategy_names: list[str] = [
"prompt_tuning",
"prefix_tuning",
"p_tuning",
"qlora",
]

if finetune_strategy_name == "lora" and quantization_args:
finetune_strategy_name = "qlora"

return finetune_strategy_name in cuda_only_finetune_strategy_names


def _verify_lm_lora_finetuning_layers(
attention_layer: torch.nn.Module,
merge_adapter_into_base_model: bool,
expected_lora_in_features: int,
expected_lora_out_features: int,
) -> bool:
"""This method verifies that LoRA finetuning layers have correct types and shapes, depending on whether or not
the optional "model.merge_and_unload()" method (based on the "merge_adapter_into_base_model" directive) was
executed.
If merge_adapter_into_base_model is True, then both LoRA projection layers, V and Q, in the attention layer must
contain square weight matrices (with the dimensions expected_lora_in_features by expected_lora_in_features).
However, if merge_adapter_into_base_model is False, then the LoRA part of the attention layer must include lora_A
and lora_B children layers for each of the V and Q projections, such that the product of the lora_B and lora_A
matrices is a square matrix (with the dimensions expected_lora_in_features by expected_lora_in_features) for each
of the V and Q projections.
"""
success: bool = True
success = success and isinstance(attention_layer.v_proj, torch.nn.Linear)
success = success and isinstance(attention_layer.q_proj, torch.nn.Linear)
if merge_adapter_into_base_model:
success = success and (attention_layer.v_proj.in_features, attention_layer.v_proj.out_features) == (
expected_lora_in_features,
expected_lora_out_features,
)
success = success and (attention_layer.q_proj.in_features, attention_layer.q_proj.out_features) == (
expected_lora_in_features,
expected_lora_out_features,
)
success = success and not list(attention_layer.v_proj.children())
success = success and not list(attention_layer.q_proj.children())
else:
v_proj_named_children: dict[str, torch.nn.Module] = dict(attention_layer.v_proj.named_children())
assert isinstance(v_proj_named_children["lora_A"]["default"], torch.nn.Linear)
assert (
v_proj_named_children["lora_A"]["default"].in_features,
v_proj_named_children["lora_A"]["default"].out_features,
) == (expected_lora_in_features, expected_lora_out_features)
assert isinstance(v_proj_named_children["lora_B"]["default"], torch.nn.Linear)
assert (
v_proj_named_children["lora_B"]["default"].in_features,
v_proj_named_children["lora_B"]["default"].out_features,
) == (expected_lora_out_features, expected_lora_in_features)
q_proj_named_children: dict[str, torch.nn.Module] = dict(attention_layer.q_proj.named_children())
assert isinstance(q_proj_named_children["lora_A"]["default"], torch.nn.Linear)
assert (
q_proj_named_children["lora_A"]["default"].in_features,
q_proj_named_children["lora_A"]["default"].out_features,
) == (expected_lora_in_features, expected_lora_out_features)
assert isinstance(q_proj_named_children["lora_B"]["default"], torch.nn.Linear)
assert (
q_proj_named_children["lora_B"]["default"].in_features,
q_proj_named_children["lora_B"]["default"].out_features,
) == (expected_lora_out_features, expected_lora_in_features)

return success


# TODO(arnav): p-tuning and prefix tuning have errors when enabled that seem to stem from DDP:
#
# prefix tuning:
@@ -376,8 +457,12 @@ def _prepare_finetuning_test(
(None, {}),
("lora", {}),
("lora", {"r": 4, "dropout": 0.1}),
("lora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: True, PROGRESSBAR: True}}),
("lora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: False}}),
("adalora", {}),
("adalora", {"init_r": 8, "beta1": 0.8}),
("adalora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: True, PROGRESSBAR: True}}),
("adalora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: False}}),
("adaption_prompt", {}),
("adaption_prompt", {"adapter_len": 6, "adapter_layers": 1}),
# (
@@ -403,8 +488,12 @@ def _prepare_finetuning_test(
"full",
"lora-defaults",
"lora-modified-defaults",
"lora_merged",
"lora_not_merged",
"adalora-defaults",
"adalora-modified-defaults",
"adalora_merged",
"adalora_not_merged",
"adaption_prompt-defaults",
"adaption_prompt-modified-defaults",
# "prompt_tuning_init_random",
@@ -445,7 +534,10 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat
],
)
def test_llm_finetuning_strategies_quantized(tmpdir, csv_filename, finetune_strategy, adapter_args, quantization):
if quantization and (not torch.cuda.is_available() or torch.cuda.device_count() == 0):
if (
_finetune_strategy_requires_cuda(finetune_strategy_name=finetune_strategy, quantization_args=quantization)
and not (torch.cuda.is_available() and torch.cuda.device_count() > 0)
):
pytest.skip("Skip: quantization requires GPU and none are available.")

backend = LOCAL_BACKEND
@@ -469,6 +561,66 @@ def test_llm_finetuning_strategies_quantized(tmpdir, csv_filename, finetune_stra
assert preds


@pytest.mark.llm
@pytest.mark.parametrize(
"backend",
[
pytest.param(LOCAL_BACKEND, id="local"),
# TODO: Re-enable once we can run tests on GPUs
# This is because fine-tuning requires Ray with the deepspeed strategy, and deepspeed
# only works with GPUs
# pytest.param(RAY_BACKEND, id="ray"),
],
)
@pytest.mark.parametrize(
"merge_adapter_into_base_model,expected_lora_in_features,expected_lora_out_features",
[
pytest.param(
False,
32,
8,
id="lora_not_merged",
),
pytest.param(
True,
32,
32,
id="lora_merged",
),
],
)
def test_llm_lora_finetuning_merge_and_unload(
tmpdir, csv_filename, backend, merge_adapter_into_base_model, expected_lora_in_features, expected_lora_out_features
):
finetune_strategy: str = "lora"
adapter_args: dict = {
POSTPROCESSOR: {
MERGE_ADAPTER_INTO_BASE_MODEL: merge_adapter_into_base_model,
},
}
train_df, prediction_df, config = _prepare_finetuning_test(
csv_filename=csv_filename, finetune_strategy=finetune_strategy, backend=backend, adapter_args=adapter_args
)

model = LudwigModel(config)
model.train(dataset=train_df, output_directory=str(tmpdir), skip_save_processed_input=False)
assert _verify_lm_lora_finetuning_layers(
attention_layer=model.model.model.base_model.model.transformer.h[1].attn,
merge_adapter_into_base_model=merge_adapter_into_base_model,
expected_lora_in_features=expected_lora_in_features,
expected_lora_out_features=expected_lora_out_features,
)

# Make sure we can load the saved model and verify that the LoRA layers have expected shapes.
model = LudwigModel.load(os.path.join(str(tmpdir), "api_experiment_run", "model"), backend=backend)
assert _verify_lm_lora_finetuning_layers(
attention_layer=model.model.model.base_model.model.transformer.h[1].attn,
merge_adapter_into_base_model=merge_adapter_into_base_model,
expected_lora_in_features=expected_lora_in_features,
expected_lora_out_features=expected_lora_out_features,
)


@pytest.mark.llm
@pytest.mark.parametrize("use_adapter", [True, False], ids=["with_adapter", "without_adapter"])
def test_llm_training_with_gradient_checkpointing(tmpdir, csv_filename, use_adapter):
@@ -628,23 +780,37 @@ def test_load_pretrained_adapter_weights(adapter):

def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
# For a full explanation of this 8-bit workaround, see https://github.com/ludwig-ai/ludwig/pull/3606
def filter_for_weight_format(i):
"""Remove bitsandbytes metadata keys added on state dict creation.

8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict.
These contain strings that are used to reshape quantized tensors, however these have no impact until the state
dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in the
evaluation.
"""
return "weight_format" not in i[0]
# TODO: Uncomment "filter_for_weight_format()" method definition and enable its usage once GPU tests are set up.
# def filter_for_weight_format(i):
# """Remove bitsandbytes metadata keys added on state dict creation.
#
# 8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict.
# These contain strings that are used to reshape quantized tensors, however these have no impact until the state
# dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in
# the evaluation.
# """
# return "weight_format" not in i[0]

model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items())
model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items())
# model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items())
# model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items())

# Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
for key_item_1, key_item_2 in zip(model_1_filtered_state_dict, model_2_filtered_state_dict):

if model_1.__class__.__name__ != model_2.__class__.__name__:
return False

if (
hasattr(model_1, "model")
and hasattr(model_2, "model")
and not _compare_models(model_1=model_1.model, model_2=model_2.model)
):
return False

for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
if not torch.equal(key_item_1[1], key_item_2[1]):
return False

return True


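The shape expectations encoded in `_verify_lm_lora_finetuning_layers` (32x8 and 8x32 LoRA factors before merging, a plain 32x32 projection after merging) can be reproduced with peft alone. The sketch below uses a toy attention-like module with illustrative dimensions; it is not taken from the Ludwig test suite.

```python
import torch
from peft import LoraConfig, get_peft_model


class TinyAttention(torch.nn.Module):
    """Toy stand-in for an attention block with Q/V projections (hidden size 32)."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.q_proj(x) + self.v_proj(x)


peft_model = get_peft_model(TinyAttention(), LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

# Before merging, each targeted Linear is wrapped and carries lora_A (32 -> 8) and lora_B (8 -> 32).
lora_A = dict(peft_model.base_model.model.q_proj.named_children())["lora_A"]["default"]
print(lora_A.in_features, lora_A.out_features)  # 32 8

# After merging, the LoRA update (lora_B @ lora_A, a 32 x 32 matrix) is folded into the base weight,
# leaving a plain Linear layer with a square 32 x 32 weight and no LoRA children.
merged = peft_model.merge_and_unload()
print(merged.q_proj.in_features, merged.q_proj.out_features)  # 32 32
```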
