diff --git a/ludwig/api.py b/ludwig/api.py index 58053a87625..96b9bbd80a1 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -429,8 +429,9 @@ def train( will contain the training statistics, TensorBoard logs, the saved model and the training progress files. :param random_seed: (int, default: `42`) a random seed that will be - used anywhere there is a call to a random number generator: data - splitting, parameter initialization and training set shuffling + used anywhere there is a call to a random number generator: data + splitting, parameter initialization and training set shuffling + :param kwargs: (dict, default: {}) a dictionary of optional parameters. # Return @@ -645,6 +646,9 @@ def on_epoch_end(self, trainer, progress_tracker, save_path): ) (self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats + # For an LLM model trained with a LoRA adapter, handle merge and unload postprocessing directives. + self._merge_and_unload() + # Calibrates output feature probabilities on validation set if calibration is enabled. # Must be done after training, and before final model parameters are saved. if self.backend.is_coordinator(): @@ -807,6 +811,21 @@ def train_online( self.model = self._online_trainer.train_online(training_dataset) + def _merge_and_unload(self) -> None: + """For an LLM model trained with a LoRA adapter, handle merge and unload postprocessing directives. + + First, check that the model is of the "llm" type. Then check if the "adapter" configuration section contains + the "postprocessor" subsection and apply the "merge_adapter_into_base_model" and "progressbar" directives. + """ + if ( + self.config_obj.model_type == "llm" + and self.config_obj.adapter is not None + and self.config_obj.adapter.postprocessor is not None + and self.config_obj.adapter.postprocessor.merge_adapter_into_base_model + and hasattr(self.model, "merge_and_unload") + ): + self.model.merge_and_unload(progressbar=self.config_obj.adapter.postprocessor.progressbar) + def _tune_batch_size(self, trainer, dataset, random_seed: int = default_random_seed): """Sets AUTO batch-size-related parameters based on the trainer, backend type, and number of workers. @@ -1643,6 +1662,9 @@ def load( # load model weights ludwig_model.load_weights(model_dir) + # The LoRA layers appear to be loaded again (perhaps due to a potential bug); hence, we merge and unload again. + ludwig_model._merge_and_unload() + # load train set metadata ludwig_model.training_set_metadata = backend.broadcast_return( lambda: load_metadata(os.path.join(model_dir, TRAIN_SET_METADATA_FILE_NAME)) diff --git a/ludwig/constants.py b/ludwig/constants.py index d2cc455df24..0bf5eebe9e8 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -283,6 +283,8 @@ PROMPT = "prompt" ADAPTER = "adapter" PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights" +MERGE_ADAPTER_INTO_BASE_MODEL = "merge_adapter_into_base_model" +PROGRESSBAR = "progressbar" # CrossEntropyLoss for LLMs IGNORE_INDEX_TOKEN_ID = -100 diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index d74a16e56c0..73800259f99 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -451,6 +451,21 @@ def generate( return outputs + def merge_and_unload(self, progressbar: bool = False) -> None: + """This method merges the LoRa layers into the base model. This is needed if someone wants to use the base + model as a standalone model. The implementation calls merge_and_unload() of the underlying LoraModel class + (in peft). 
+
+        Args:
+            progressbar (bool): whether to show a progress bar indicating the unload and merge process
+        """
+        from peft import LoraModel
+
+        if isinstance(self.model.base_model, LoraModel):
+            self.model.base_model.merge_and_unload(progressbar=progressbar)
+        else:
+            raise ValueError("This operation requires an LLM model trained with a LoRA adapter.")
+
     def _unpack_inputs(
         self,
         inputs: Union[
diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py
index 3ce30aeb07c..b7a5afc8c37 100644
--- a/ludwig/schema/llms/peft.py
+++ b/ludwig/schema/llms/peft.py
@@ -25,6 +25,32 @@ def wrap(config: BaseAdapterConfig):
     return wrap
 
 
+@DeveloperAPI
+@ludwig_dataclass
+class LoraPostprocessorConfig(schema_utils.BaseMarshmallowConfig):
+    """This dataclass is a schema for the nested postprocessor config under an adapter of type "lora"."""
+
+    merge_adapter_into_base_model: bool = schema_utils.Boolean(
+        default=False,
+        description="""Instructs whether or not the fine-tuned LoRA weights are to be merged into the base LLM model so
+that the complete fine-tuned model is available to be used and/or persisted, and then reused upon loading as a single
+model (rather than having to load base and fine-tuned models separately).""",
+    )
+    progressbar: bool = schema_utils.Boolean(
+        default=False,
+        description="Instructs whether or not to show a progress bar indicating the unload and merge process.",
+    )
+
+
+@DeveloperAPI
+class LoraPostprocessorConfigField(schema_utils.DictMarshmallowField):
+    def __init__(self):
+        super().__init__(LoraPostprocessorConfig)
+
+    def _jsonschema_type_mapping(self):
+        return schema_utils.unload_jsonschema_from_marshmallow_class(LoraPostprocessorConfig, title="LoraPostprocessor")
+
+
 @DeveloperAPI
 @ludwig_dataclass
 class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
@@ -34,6 +60,8 @@ class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
         default=None, description="Path to pretrained weights.", allow_none=True
     )
 
+    postprocessor: LoraPostprocessorConfig = LoraPostprocessorConfigField().get_default_field()
+
     @abstractmethod
     def to_config(self, **kwargs) -> "PeftConfig":
         pass
diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 215da2f079e..b1981afdf32 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -15,11 +15,14 @@
     EPOCHS,
     GENERATION,
     INPUT_FEATURES,
+    MERGE_ADAPTER_INTO_BASE_MODEL,
     MODEL_LLM,
     MODEL_TYPE,
     OUTPUT_FEATURES,
+    POSTPROCESSOR,
     PREPROCESSING,
     PRETRAINED_ADAPTER_WEIGHTS,
+    PROGRESSBAR,
     PROMPT,
     TRAINER,
     TYPE,
@@ -352,6 +355,84 @@ def _prepare_finetuning_test(
     return train_df, prediction_df, config
 
 
+def _finetune_strategy_requires_cuda(finetune_strategy_name: str, quantization_args: Union[dict, None]) -> bool:
+    """This function returns whether or not a given finetune strategy requires CUDA.
+
+    For all finetune strategies except "qlora", the decision is based solely on the name of the finetune strategy. In
+    the case of qlora, if the quantization dictionary is non-empty (i.e., it contains quantization specifications),
+    then the original finetune strategy name of "lora" is interpreted as "qlora" and used in the lookup against the
+    list of finetune strategies requiring CUDA.
+    """
+    cuda_only_finetune_strategy_names: list[str] = [
+        "prompt_tuning",
+        "prefix_tuning",
+        "p_tuning",
+        "qlora",
+    ]
+
+    if finetune_strategy_name == "lora" and quantization_args:
+        finetune_strategy_name = "qlora"
+
+    return finetune_strategy_name in cuda_only_finetune_strategy_names
+
+
+def _verify_lm_lora_finetuning_layers(
+    attention_layer: torch.nn.Module,
+    merge_adapter_into_base_model: bool,
+    expected_lora_in_features: int,
+    expected_lora_out_features: int,
+) -> bool:
+    """This function verifies that the LoRA finetuning layers have the correct types and shapes, depending on whether
+    or not the optional "model.merge_and_unload()" method (based on the "merge_adapter_into_base_model" directive) was
+    executed.
+
+    If merge_adapter_into_base_model is True, then both LoRA projection layers, V and Q, in the attention layer must
+    contain square weight matrices (with the dimensions expected_lora_in_features by expected_lora_in_features).
+    However, if merge_adapter_into_base_model is False, then the LoRA part of the attention layer must include Lora_A
+    and Lora_B child layers for each of the V and Q projections, such that the product of the Lora_B and Lora_A weight
+    matrices is a square matrix (expected_lora_in_features by expected_lora_in_features) for each projection.
+    """
+    success: bool = True
+    success = success and isinstance(attention_layer.v_proj, torch.nn.Linear)
+    success = success and isinstance(attention_layer.q_proj, torch.nn.Linear)
+    if merge_adapter_into_base_model:
+        success = success and (attention_layer.v_proj.in_features, attention_layer.v_proj.out_features) == (
+            expected_lora_in_features,
+            expected_lora_out_features,
+        )
+        success = success and (attention_layer.q_proj.in_features, attention_layer.q_proj.out_features) == (
+            expected_lora_in_features,
+            expected_lora_out_features,
+        )
+        success = success and not list(attention_layer.v_proj.children())
+        success = success and not list(attention_layer.q_proj.children())
+    else:
+        v_proj_named_children: dict[str, torch.nn.Module] = dict(attention_layer.v_proj.named_children())
+        assert isinstance(v_proj_named_children["lora_A"]["default"], torch.nn.Linear)
+        assert (
+            v_proj_named_children["lora_A"]["default"].in_features,
+            v_proj_named_children["lora_A"]["default"].out_features,
+        ) == (expected_lora_in_features, expected_lora_out_features)
+        assert isinstance(v_proj_named_children["lora_B"]["default"], torch.nn.Linear)
+        assert (
+            v_proj_named_children["lora_B"]["default"].in_features,
+            v_proj_named_children["lora_B"]["default"].out_features,
+        ) == (expected_lora_out_features, expected_lora_in_features)
+        q_proj_named_children: dict[str, torch.nn.Module] = dict(attention_layer.q_proj.named_children())
+        assert isinstance(q_proj_named_children["lora_A"]["default"], torch.nn.Linear)
+        assert (
+            q_proj_named_children["lora_A"]["default"].in_features,
+            q_proj_named_children["lora_A"]["default"].out_features,
+        ) == (expected_lora_in_features, expected_lora_out_features)
+        assert isinstance(q_proj_named_children["lora_B"]["default"], torch.nn.Linear)
+        assert (
+            q_proj_named_children["lora_B"]["default"].in_features,
+            q_proj_named_children["lora_B"]["default"].out_features,
+        ) == (expected_lora_out_features, expected_lora_in_features)
+
+    return success
+
+
 # TODO(arnav): p-tuning and prefix tuning have errors when enabled that seem to stem from DDP:
 #
 # prefix tuning:
@@ -376,8 +457,12 @@ def _prepare_finetuning_test(
         (None, {}),
         ("lora", {}),
         ("lora", {"r": 4, "dropout": 0.1}),
+        ("lora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: True, PROGRESSBAR: True}}),
+        ("lora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: False}}),
         ("adalora", {}),
         ("adalora", {"init_r": 8, "beta1": 0.8}),
+        ("adalora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: True, PROGRESSBAR: True}}),
+        ("adalora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: False}}),
         ("adaption_prompt", {}),
         ("adaption_prompt", {"adapter_len": 6, "adapter_layers": 1}),
         # (
@@ -403,8 +488,12 @@ def _prepare_finetuning_test(
         "full",
         "lora-defaults",
         "lora-modified-defaults",
+        "lora_merged",
+        "lora_not_merged",
         "adalora-defaults",
         "adalora-modified-defaults",
+        "adalora_merged",
+        "adalora_not_merged",
         "adaption_prompt-defaults",
         "adaption_prompt-modified-defaults",
         # "prompt_tuning_init_random",
@@ -445,7 +534,10 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat
     ],
 )
 def test_llm_finetuning_strategies_quantized(tmpdir, csv_filename, finetune_strategy, adapter_args, quantization):
-    if quantization and (not torch.cuda.is_available() or torch.cuda.device_count() == 0):
+    if (
+        _finetune_strategy_requires_cuda(finetune_strategy_name=finetune_strategy, quantization_args=quantization)
+        and not (torch.cuda.is_available() and torch.cuda.device_count() > 0)
+    ):
         pytest.skip("Skip: quantization requires GPU and none are available.")
 
     backend = LOCAL_BACKEND
@@ -469,6 +561,66 @@ def test_llm_finetuning_strategies_quantized(tmpdir, csv_filename, finetune_stra
     assert preds
 
 
+@pytest.mark.llm
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param(LOCAL_BACKEND, id="local"),
+        # TODO: Re-enable once we can run tests on GPUs
+        # This is because fine-tuning requires Ray with the deepspeed strategy, and deepspeed
+        # only works with GPUs
+        # pytest.param(RAY_BACKEND, id="ray"),
+    ],
+)
+@pytest.mark.parametrize(
+    "merge_adapter_into_base_model,expected_lora_in_features,expected_lora_out_features",
+    [
+        pytest.param(
+            False,
+            32,
+            8,
+            id="lora_not_merged",
+        ),
+        pytest.param(
+            True,
+            32,
+            32,
+            id="lora_merged",
+        ),
+    ],
+)
+def test_llm_lora_finetuning_merge_and_unload(
+    tmpdir, csv_filename, backend, merge_adapter_into_base_model, expected_lora_in_features, expected_lora_out_features
+):
+    finetune_strategy: str = "lora"
+    adapter_args: dict = {
+        POSTPROCESSOR: {
+            MERGE_ADAPTER_INTO_BASE_MODEL: merge_adapter_into_base_model,
+        },
+    }
+    train_df, prediction_df, config = _prepare_finetuning_test(
+        csv_filename=csv_filename, finetune_strategy=finetune_strategy, backend=backend, adapter_args=adapter_args
+    )
+
+    model = LudwigModel(config)
+    model.train(dataset=train_df, output_directory=str(tmpdir), skip_save_processed_input=False)
+    assert _verify_lm_lora_finetuning_layers(
+        attention_layer=model.model.model.base_model.model.transformer.h[1].attn,
+        merge_adapter_into_base_model=merge_adapter_into_base_model,
+        expected_lora_in_features=expected_lora_in_features,
+        expected_lora_out_features=expected_lora_out_features,
+    )
+
+    # Make sure we can load the saved model and verify that the LoRA layers have the expected shapes.
+ model = LudwigModel.load(os.path.join(str(tmpdir), "api_experiment_run", "model"), backend=backend) + assert _verify_lm_lora_finetuning_layers( + attention_layer=model.model.model.base_model.model.transformer.h[1].attn, + merge_adapter_into_base_model=merge_adapter_into_base_model, + expected_lora_in_features=expected_lora_in_features, + expected_lora_out_features=expected_lora_out_features, + ) + + @pytest.mark.llm @pytest.mark.parametrize("use_adapter", [True, False], ids=["with_adapter", "without_adapter"]) def test_llm_training_with_gradient_checkpointing(tmpdir, csv_filename, use_adapter): @@ -628,23 +780,37 @@ def test_load_pretrained_adapter_weights(adapter): def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool: # For a full explanation of this 8-bit workaround, see https://github.com/ludwig-ai/ludwig/pull/3606 - def filter_for_weight_format(i): - """Remove bitsandbytes metadata keys added on state dict creation. - 8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict. - These contain strings that are used to reshape quantized tensors, however these have no impact until the state - dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in the - evaluation. - """ - return "weight_format" not in i[0] + # TODO: Uncomment "filter_for_weight_format()" method definition and enable its usage once GPU tests are set up. + # def filter_for_weight_format(i): + # """Remove bitsandbytes metadata keys added on state dict creation. + # + # 8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict. + # These contain strings that are used to reshape quantized tensors, however these have no impact until the state + # dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in + # the evaluation. + # """ + # return "weight_format" not in i[0] - model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items()) - model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items()) + # model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items()) + # model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items()) # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6 - for key_item_1, key_item_2 in zip(model_1_filtered_state_dict, model_2_filtered_state_dict): + + if model_1.__class__.__name__ != model_2.__class__.__name__: + return False + + if ( + hasattr(model_1, "model") + and hasattr(model_2, "model") + and not _compare_models(model_1=model_1.model, model_2=model_2.model) + ): + return False + + for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): if not torch.equal(key_item_1[1], key_item_2[1]): return False + return True
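
Usage sketch (not part of the patch above): the snippet below shows how the new "adapter.postprocessor" section defined
in ludwig/schema/llms/peft.py is expected to be driven from a user config, based on the schema and the tests in this
diff. The base model name, feature names, trainer settings, and dataset are illustrative assumptions rather than values
taken from the patch; only the "adapter.postprocessor" keys come from this change.

# A minimal sketch, assuming a tiny causal LM and a toy dataframe.
import pandas as pd

from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "hf-internal-testing/tiny-random-GPTJForCausalLM",  # assumed placeholder model
    "input_features": [{"name": "instruction", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
    "adapter": {
        "type": "lora",
        # New in this patch: after training, LudwigModel._merge_and_unload() calls LLM.merge_and_unload(),
        # folding the trained LoRA weights into the base model; "progressbar" controls the peft progress bar.
        "postprocessor": {
            "merge_adapter_into_base_model": True,
            "progressbar": True,
        },
    },
    "trainer": {"type": "finetune", "epochs": 1, "batch_size": 1},  # assumed toy trainer settings
}

# Toy dataset (assumption); any dataframe with the configured input/output columns works.
df = pd.DataFrame({"instruction": ["say hi"], "response": ["hi"]})

model = LudwigModel(config)
model.train(dataset=df)  # the merge-and-unload postprocessing runs at the end of train()

# LudwigModel.load() applies _merge_and_unload() again after loading weights (see the api.py hunk above).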