From 48f7d1ffc5736607fd1935222ee32b92f8079ad5 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 24 Jul 2023 09:18:05 -0700
Subject: [PATCH] Add QLoRA for 4-bit fine-tuning (#3476)

---
 examples/llama2_7b_finetuning_4bit/README.md  | 22 +++++
 .../llama2_7b_finetuning_4bit/train_alpaca.py | 58 ++++++++++++
 ludwig/models/llm.py                          | 19 +++-
 ludwig/schema/llms/quantization.py            | 89 +++++++++++++++++++
 ludwig/schema/model_types/llm.py              |  2 +
 requirements_llm.txt                          |  2 +-
 tests/integration_tests/test_llm.py           | 18 ++--
 tests/ludwig/schema/test_model_config.py      | 27 ++++++
 8 files changed, 230 insertions(+), 7 deletions(-)
 create mode 100644 examples/llama2_7b_finetuning_4bit/README.md
 create mode 100644 examples/llama2_7b_finetuning_4bit/train_alpaca.py
 create mode 100644 ludwig/schema/llms/quantization.py

diff --git a/examples/llama2_7b_finetuning_4bit/README.md b/examples/llama2_7b_finetuning_4bit/README.md
new file mode 100644
index 00000000000..9af187fbcf0
--- /dev/null
+++ b/examples/llama2_7b_finetuning_4bit/README.md
@@ -0,0 +1,22 @@
+# Llama2-7b Fine-Tuning 4bit (QLoRA)
+
+This example shows how to fine-tune [Llama2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) to follow instructions.
+Instruction tuning is the first step in adapting a general-purpose Large Language Model into a chatbot.
+
+This example uses no distributed training or big data functionality. It is designed to run locally on any machine
+with GPU availability.
+
+## Prerequisites
+
+- [HuggingFace API Token](https://huggingface.co/docs/hub/security-tokens)
+- Access approval to [Llama2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
+- GPU with at least 12 GiB of VRAM (in our tests, we used an Nvidia T4)
+
+## Running the example
+
+Set your token environment variable from the terminal, then run the API script:
+
+```bash
+export HUGGING_FACE_HUB_TOKEN=""
+python train_alpaca.py
+```
diff --git a/examples/llama2_7b_finetuning_4bit/train_alpaca.py b/examples/llama2_7b_finetuning_4bit/train_alpaca.py
new file mode 100644
index 00000000000..587ad21d1c8
--- /dev/null
+++ b/examples/llama2_7b_finetuning_4bit/train_alpaca.py
@@ -0,0 +1,58 @@
+import logging
+import os
+
+import yaml
+
+from ludwig.api import LudwigModel
+
+config = yaml.safe_load(
+    """
+model_type: llm
+base_model: meta-llama/Llama-2-7b-hf
+
+quantization:
+  bits: 4
+
+adapter:
+  type: lora
+
+input_features:
+  - name: instruction
+    type: text
+
+output_features:
+  - name: output
+    type: text
+
+trainer:
+  type: finetune
+  learning_rate: 0.0003
+  batch_size: 2
+  gradient_accumulation_steps: 8
+  epochs: 3
+  learning_rate_scheduler:
+    warmup_fraction: 0.01
+
+backend:
+  type: local
+"""
+)
+
+# Define the Ludwig model object that drives model training
+model = LudwigModel(config=config, logging_level=logging.INFO)
+
+# initiate model training
+(
+    train_stats,  # dictionary containing training statistics
+    preprocessed_data,  # tuple of Ludwig Dataset objects of pre-processed training data
+    output_directory,  # location of training results stored on disk
+) = model.train(
+    dataset="ludwig://alpaca",
+    experiment_name="alpaca_instruct_4bit",
+    model_name="llama2_7b",
+)
+
+# list contents of output directory
+print("contents of output directory:", output_directory)
+for item in os.listdir(output_directory):
+    print("\t", item)
diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py
index 06c25c19693..92a2557c6ee 100644
--- a/ludwig/models/llm.py
+++ b/ludwig/models/llm.py
@@ -91,8 +91,14 @@ def __init__(
         self.model_name = self.config_obj.base_model
         self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)
 
+        self.load_kwargs = {}
+        if self.config_obj.quantization:
+            # Apply quantization configuration at model load time
+            self.load_kwargs["torch_dtype"] = getattr(torch, self.config_obj.quantization.bnb_4bit_compute_dtype)
+            self.load_kwargs["quantization_config"] = self.config_obj.quantization.to_bitsandbytes()
+
         logger.info("Loading large language model...")
-        self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model)
+        self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model, **self.load_kwargs)
 
         # Model initially loaded onto cpu
         self.curr_device = torch.device("cpu")
@@ -195,12 +201,20 @@ def initialize_adapter(self):
 
     def prepare_for_training(self):
         # TODO: this implementation will not work if resuming from a previous checkpoint. Need to fix this.
+        if self.config_obj.quantization:
+            self.prepare_for_quantized_training()
         self.initialize_adapter()
 
+    def prepare_for_quantized_training(self):
+        from peft import prepare_model_for_kbit_training
+
+        self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False)
+
     def to_device(self, device):
         device = torch.device(device)
 
         if device == self.curr_device:
+            log_once(f"Model already on device '{device}'.")
             return self
         else:
             log_once(f"Moving LLM from '{self.curr_device}' to '{device}'.")
@@ -221,6 +235,9 @@ def to_device(self, device):
                 )
             )
 
+            if self.config_obj.quantization:
+                model_kwargs["quantization_config"] = self.config_obj.quantization.to_bitsandbytes()
+
             # we save and reload the weights to ensure that they can be sharded across the GPUs using `from_pretrained`
             with tempfile.TemporaryDirectory() as tmpdir:
                 self.model.save_pretrained(tmpdir)
diff --git a/ludwig/schema/llms/quantization.py b/ludwig/schema/llms/quantization.py
new file mode 100644
index 00000000000..04eae3380b3
--- /dev/null
+++ b/ludwig/schema/llms/quantization.py
@@ -0,0 +1,89 @@
+import warnings
+
+from transformers import BitsAndBytesConfig
+
+from ludwig.api_annotations import DeveloperAPI
+from ludwig.schema import utils as schema_utils
+from ludwig.schema.utils import ludwig_dataclass
+
+warnings.filterwarnings(
+    action="ignore",
+    category=UserWarning,
+    module="bitsandbytes.cuda_setup.main",
+)
+
+
+@DeveloperAPI
+@ludwig_dataclass
+class QuantizationConfig(schema_utils.BaseMarshmallowConfig):
+    bits: int = schema_utils.IntegerOptions(
+        options=[4, 8],
+        default=4,
+        description="The quantization level to apply to weights on load.",
+    )
+
+    llm_int8_threshold: float = schema_utils.NonNegativeFloat(
+        default=6.0,
+        description=(
+            "This corresponds to the outlier threshold for outlier detection as described in the `LLM.int8(): 8-bit "
+            "Matrix Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339. Any hidden "
+            "states value that is above this threshold will be considered an outlier and the operation on those "
+            "values will be done in fp16. Values are usually normally distributed, that is, most values are in the "
+            "range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently "
+            "distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 "
+            "quantization works well for values of magnitude ~5, but beyond that, there is a significant performance "
+            "penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models "
+            "(small models, fine-tuning)."
+        ),
+    )
+
+    llm_int8_has_fp16_weight: bool = schema_utils.Boolean(
+        default=False,
+        description=(
+            "This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do "
+            "not have to be converted back and forth for the backward pass."
+        ),
+    )
+
+    bnb_4bit_compute_dtype: str = schema_utils.StringOptions(
+        options=["float32", "float16", "bfloat16"],
+        default="float16",
+        description=(
+            "This sets the computational type which might be different than the input type. For example, inputs "
+            "might be fp32, but computation can be set to bf16 for speedups."
+        ),
+    )
+
+    bnb_4bit_use_double_quant: bool = schema_utils.Boolean(
+        default=True,
+        description=(
+            "This flag is used for nested quantization where the quantization constants from the first quantization "
+            "are quantized again."
+        ),
+    )
+
+    bnb_4bit_quant_type: str = schema_utils.StringOptions(
+        options=["fp4", "nf4"],
+        default="nf4",
+        description="This sets the quantization data type in the bnb.nn.Linear4Bit layers.",
+    )
+
+    def to_bitsandbytes(self) -> BitsAndBytesConfig:
+        return BitsAndBytesConfig(
+            load_in_4bit=self.bits == 4,
+            load_in_8bit=self.bits == 8,
+            llm_int8_threshold=self.llm_int8_threshold,
+            llm_int8_has_fp16_weight=self.llm_int8_has_fp16_weight,
+            bnb_4bit_compute_dtype=self.bnb_4bit_compute_dtype,
+            bnb_4bit_use_double_quant=self.bnb_4bit_use_double_quant,
+            bnb_4bit_quant_type=self.bnb_4bit_quant_type,
+        )
+
+
+@DeveloperAPI
+class QuantizationConfigField(schema_utils.DictMarshmallowField):
+    def __init__(self):
+        super().__init__(QuantizationConfig, default_missing=True)
+
+    def _jsonschema_type_mapping(self):
+        return schema_utils.unload_jsonschema_from_marshmallow_class(QuantizationConfig)
diff --git a/ludwig/schema/model_types/llm.py b/ludwig/schema/model_types/llm.py
index 9bc1a1e8367..e9b735352c5 100644
--- a/ludwig/schema/model_types/llm.py
+++ b/ludwig/schema/model_types/llm.py
@@ -15,6 +15,7 @@
 from ludwig.schema.llms.generation import LLMGenerationConfig, LLMGenerationConfigField
 from ludwig.schema.llms.peft import AdapterDataclassField, BaseAdapterConfig
 from ludwig.schema.llms.prompt import PromptConfig, PromptConfigField
+from ludwig.schema.llms.quantization import QuantizationConfig, QuantizationConfigField
 from ludwig.schema.model_types.base import ModelConfig, register_model_type
 from ludwig.schema.preprocessing import PreprocessingConfig, PreprocessingField
 from ludwig.schema.trainer import LLMTrainerConfig, LLMTrainerDataclassField
@@ -48,3 +49,4 @@ class LLMModelConfig(ModelConfig):
     generation: LLMGenerationConfig = LLMGenerationConfigField().get_default_field()
 
     adapter: Optional[BaseAdapterConfig] = AdapterDataclassField()
+    quantization: Optional[QuantizationConfig] = QuantizationConfigField().get_default_field()
diff --git a/requirements_llm.txt b/requirements_llm.txt
index 6983cd0c1c4..4e31e98ea0f 100644
--- a/requirements_llm.txt
+++ b/requirements_llm.txt
@@ -4,4 +4,4 @@ faiss-cpu
 accelerate
 loralib
 bitsandbytes
-peft
+peft>=0.4.0
diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 3c5108f556d..239b42eb981 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -294,9 +294,9 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_
     ],
 )
 @pytest.mark.parametrize(
-    "finetune_strategy,adapter_args",
+    "finetune_strategy,adapter_args,quantization",
     [
-        (None, {}),
+        (None, {}, None),
         # (
         #     "prompt_tuning",
         #     {
@@ -315,9 +315,10 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_
         # ("prefix_tuning", {"num_virtual_tokens": 8}),
         # ("p_tuning", {"num_virtual_tokens": 8, "encoder_reparameterization_type": "MLP"}),
         # ("p_tuning", {"num_virtual_tokens": 8, "encoder_reparameterization_type": "LSTM"}),
-        ("lora", {}),
+        ("lora", {}, None),
+        ("lora", {}, {"bits": 4}),  # qlora
         # ("adalora", {}),
-        ("adaption_prompt", {"adapter_len": 6, "adapter_layers": 1}),
+        ("adaption_prompt", {"adapter_len": 6, "adapter_layers": 1}, None),
     ],
     ids=[
         "none",
@@ -327,11 +328,15 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_
         # "p_tuning_mlp_reparameterization",
         # "p_tuning_lstm_reparameterization",
         "lora",
+        "qlora",
         # "adalora",
         "adaption_prompt",
     ],
 )
-def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strategy, adapter_args):
+def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strategy, adapter_args, quantization):
+    if quantization is not None and (not torch.cuda.is_available() or torch.cuda.device_count() == 0):
+        pytest.skip("Skip: quantization requires GPU and none are available.")
+
     input_features = [text_feature(name="input", encoder={"type": "passthrough"})]
     output_features = [text_feature(name="output")]
 
@@ -364,6 +369,9 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat
         **adapter_args,
     }
 
+    if quantization is not None:
+        config["quantization"] = quantization
+
     model = LudwigModel(config)
     model.train(dataset=df, output_directory=str(tmpdir), skip_save_processed_input=False)
 
diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py
index 4b63982cc96..35cc0468db8 100644
--- a/tests/ludwig/schema/test_model_config.py
+++ b/tests/ludwig/schema/test_model_config.py
@@ -1,5 +1,6 @@
 import os
 from tempfile import TemporaryDirectory
+from typing import Optional
 
 import pytest
 import yaml
@@ -42,6 +43,7 @@
 from ludwig.schema.features.image_feature import AUGMENTATION_DEFAULT_OPERATIONS
 from ludwig.schema.features.number_feature import NumberOutputFeatureConfig
 from ludwig.schema.features.text_feature import TextOutputFeatureConfig
+from ludwig.schema.llms.quantization import QuantizationConfig
 from ludwig.schema.model_config import ModelConfig
 from ludwig.schema.utils import BaseMarshmallowConfig, convert_submodules
 
@@ -828,3 +830,28 @@ def test_llm_base_model_config_error(base_model_config):
 
     with pytest.raises(ConfigValidationError):
         ModelConfig.from_dict(config)
+
+
+@pytest.mark.parametrize(
+    "bits,expected_qconfig",
+    [
+        (None, None),
+        (4, QuantizationConfig(bits=4)),
+        (8, QuantizationConfig(bits=8)),
+    ],
+)
+def test_llm_quantization_config(bits: Optional[int], expected_qconfig: Optional[QuantizationConfig]):
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: "bigscience/bloomz-3b",
+        "quantization": {"bits": bits},
+        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
+        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
+    }
+
+    if bits is None:
+        del config["quantization"]
+
+    config_obj = ModelConfig.from_dict(config)
+
+    assert config_obj.quantization == expected_qconfig
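
For readers who want to see what the new `quantization` section drives under the hood, the sketch below mirrors the load path this patch adds in `ludwig/models/llm.py` (the `load_kwargs` built in `__init__` plus `prepare_for_quantized_training()`) using the raw `transformers`/`peft` APIs. It is illustrative only and not part of the patch: it assumes a CUDA GPU, `bitsandbytes` and `peft>=0.4.0` are installed, and a `HUGGING_FACE_HUB_TOKEN` with access to the same Llama2-7b base model as the example above.

```python
import torch
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Roughly equivalent to QuantizationConfig(bits=4).to_bitsandbytes() with the schema defaults
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Roughly equivalent to the load_kwargs applied in LLM.__init__ when `quantization` is set
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,  # mirrors getattr(torch, bnb_4bit_compute_dtype)
    quantization_config=bnb_config,
)

# Roughly equivalent to prepare_for_quantized_training(), which runs before the adapter is attached
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
```

In Ludwig itself these steps are driven entirely by the `quantization` and `adapter` sections of the config; the LoRA adapter is attached by `initialize_adapter()` immediately after this preparation step.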