Add QLoRA for 4-bit fine-tuning (#3476)
tgaddair authored Jul 24, 2023
1 parent 6d417ef commit 48f7d1f
Showing 8 changed files with 230 additions and 7 deletions.
22 changes: 22 additions & 0 deletions examples/llama2_7b_finetuning_4bit/README.md
@@ -0,0 +1,22 @@
# Llama2-7b Fine-Tuning 4bit (QLoRA)

This example shows how to fine-tune [Llama2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) to follow instructions using 4-bit quantization and LoRA (QLoRA).
Instruction tuning is the first step in adapting a general-purpose Large Language Model into a chatbot.

This example uses no distributed training or big-data functionality; it is designed to run locally on any machine
with a GPU.

## Prerequisites

- [HuggingFace API Token](https://huggingface.co/docs/hub/security-tokens)
- Access approval to [Llama2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- GPU with at least 12 GiB of VRAM (in our tests, we used an Nvidia T4); a quick availability check is sketched below
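
Before starting a run, you can optionally confirm that a suitable GPU is visible. This is a minimal sketch using PyTorch and is not part of the example's scripts; the 12 GiB figure simply mirrors the prerequisite above.

```python
import torch

# Hedged sanity check for the GPU prerequisite; assumes PyTorch is already installed.
if not torch.cuda.is_available():
    raise SystemExit("No CUDA-capable GPU detected; this example requires one.")

total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {total_gib:.1f} GiB")
if total_gib < 12:
    print("Warning: less than 12 GiB of VRAM; 4-bit fine-tuning may run out of memory.")
```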

## Running the example

Set your token environment variable from the terminal, then run the API script:

```bash
export HUGGING_FACE_HUB_TOKEN="<api_token>"
python train_alpaca.py
```
58 changes: 58 additions & 0 deletions examples/llama2_7b_finetuning_4bit/train_alpaca.py
@@ -0,0 +1,58 @@
import logging
import os

import yaml

from ludwig.api import LudwigModel

config = yaml.safe_load(
    """
model_type: llm
base_model: meta-llama/Llama-2-7b-hf
quantization:
  bits: 4
adapter:
  type: lora
input_features:
  - name: instruction
    type: text
output_features:
  - name: output
    type: text
trainer:
  type: finetune
  learning_rate: 0.0003
  batch_size: 2
  gradient_accumulation_steps: 8
  epochs: 3
  learning_rate_scheduler:
    warmup_fraction: 0.01
backend:
  type: local
"""
)

# Define the Ludwig model object that drives model training
model = LudwigModel(config=config, logging_level=logging.INFO)

# initiate model training
(
    train_stats,  # dictionary containing training statistics
    preprocessed_data,  # tuple of Ludwig Dataset objects with the pre-processed training data
    output_directory,  # location of training results stored on disk
) = model.train(
    dataset="ludwig://alpaca",
    experiment_name="alpaca_instruct_4bit",
    model_name="llama2_7b",
)

# list contents of output directory
print("contents of output directory:", output_directory)
for item in os.listdir(output_directory):
    print("\t", item)
19 changes: 18 additions & 1 deletion ludwig/models/llm.py
@@ -91,8 +91,14 @@ def __init__(
        self.model_name = self.config_obj.base_model
        self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)

        self.load_kwargs = {}
        if self.config_obj.quantization:
            # Apply quantization configuration at model load time
            self.load_kwargs["torch_dtype"] = getattr(torch, self.config_obj.quantization.bnb_4bit_compute_dtype)
            self.load_kwargs["quantization_config"] = self.config_obj.quantization.to_bitsandbytes()

        logger.info("Loading large language model...")
        self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model)
        self.model = AutoModelForCausalLM.from_pretrained(self.config_obj.base_model, **self.load_kwargs)

        # Model initially loaded onto cpu
        self.curr_device = torch.device("cpu")
@@ -195,12 +201,20 @@ def initialize_adapter(self):

    def prepare_for_training(self):
        # TODO: this implementation will not work if resuming from a previous checkpoint. Need to fix this.
        if self.config_obj.quantization:
            self.prepare_for_quantized_training()
        self.initialize_adapter()

    def prepare_for_quantized_training(self):
        from peft import prepare_model_for_kbit_training

        self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False)

    def to_device(self, device):
        device = torch.device(device)

        if device == self.curr_device:
            log_once(f"Model already on device '{device}'.")
            return self
        else:
            log_once(f"Moving LLM from '{self.curr_device}' to '{device}'.")
@@ -221,6 +235,9 @@ def to_device(self, device):
                )
            )

            if self.config_obj.quantization:
                model_kwargs["quantization_config"] = self.config_obj.quantization.to_bitsandbytes()

            # we save and reload the weights to ensure that they can be sharded across the GPUs using `from_pretrained`
            with tempfile.TemporaryDirectory() as tmpdir:
                self.model.save_pretrained(tmpdir)
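
Taken together, these `llm.py` changes quantize the base model's weights at load time and then prepare the model for k-bit training before the LoRA adapter is attached. The snippet below is a minimal standalone sketch of that flow using `transformers` and `peft` directly, outside of Ludwig; the model name, dtypes, and LoRA hyperparameters are illustrative assumptions, and a CUDA-capable GPU with `bitsandbytes` installed is assumed.

```python
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization, mirroring the schema defaults added below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the base model with quantized weights (model name is illustrative).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

# Cast norms/embeddings and enable input gradients as needed for k-bit fine-tuning.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

# Attach a LoRA adapter so only a small set of weights is trained (hyperparameters are illustrative).
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```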
89 changes: 89 additions & 0 deletions ludwig/schema/llms/quantization.py
@@ -0,0 +1,89 @@
import warnings

from transformers import BitsAndBytesConfig

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.utils import ludwig_dataclass

warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="bitsandbytes.cuda_setup.main",
)


@DeveloperAPI
@ludwig_dataclass
class QuantizationConfig(schema_utils.BaseMarshmallowConfig):
    bits: int = schema_utils.IntegerOptions(
        options=[4, 8],
        default=4,
        description="The quantization level to apply to weights on load.",
    )

    llm_int8_threshold: float = schema_utils.NonNegativeFloat(
        default=6.0,
        description=(
            "This corresponds to the outlier threshold for outlier detection as described in the `LLM.int8(): 8-bit "
            "Matrix Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339. Any hidden "
            "state value above this threshold is considered an outlier, and the operation on those values is done "
            "in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], "
            "but there are some exceptional systematic outliers that are very differently distributed for large "
            "models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well "
            "for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good "
            "default threshold is 6, but a lower threshold might be needed for more unstable models (small models, "
            "fine-tuning)."
        ),
    )

    llm_int8_has_fp16_weight: bool = schema_utils.Boolean(
        default=False,
        description=(
            "This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do "
            "not have to be converted back and forth for the backward pass."
        ),
    )

    bnb_4bit_compute_dtype: str = schema_utils.StringOptions(
        options=["float32", "float16", "bfloat16"],
        default="float16",
        description=(
            "This sets the computational type, which might be different than the input type. For example, inputs "
            "might be fp32, but computation can be set to bf16 for speedups."
        ),
    )

    bnb_4bit_use_double_quant: bool = schema_utils.Boolean(
        default=True,
        description=(
            "This flag is used for nested quantization, where the quantization constants from the first "
            "quantization are quantized again."
        ),
    )

    bnb_4bit_quant_type: str = schema_utils.StringOptions(
        options=["fp4", "nf4"],
        default="nf4",
        description="This sets the quantization data type in the bnb.nn.Linear4Bit layers.",
    )

    def to_bitsandbytes(self) -> BitsAndBytesConfig:
        return BitsAndBytesConfig(
            load_in_4bit=self.bits == 4,
            load_in_8bit=self.bits == 8,
            llm_int8_threshold=self.llm_int8_threshold,
            llm_int8_has_fp16_weight=self.llm_int8_has_fp16_weight,
            bnb_4bit_compute_dtype=self.bnb_4bit_compute_dtype,
            bnb_4bit_use_double_quant=self.bnb_4bit_use_double_quant,
            bnb_4bit_quant_type=self.bnb_4bit_quant_type,
        )


@DeveloperAPI
class QuantizationConfigField(schema_utils.DictMarshmallowField):
    def __init__(self):
        super().__init__(QuantizationConfig, default_missing=True)

    def _jsonschema_type_mapping(self):
        return schema_utils.unload_jsonschema_from_marshmallow_class(QuantizationConfig)
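
For a sense of how `QuantizationConfig` maps onto the `transformers` quantization API, here is a small hedged sketch; it assumes Ludwig with its LLM extras (including `bitsandbytes`) is installed, and the printed values simply reflect the defaults defined above.

```python
from ludwig.schema.llms.quantization import QuantizationConfig

# Default 4-bit config: NF4 quant type, double quantization, float16 compute dtype.
qconfig_4bit = QuantizationConfig(bits=4)
bnb_4bit = qconfig_4bit.to_bitsandbytes()
print(bnb_4bit.load_in_4bit, bnb_4bit.bnb_4bit_quant_type)  # True nf4

# 8-bit config: the LLM.int8() outlier threshold and fp16-weight flag apply in this mode.
qconfig_8bit = QuantizationConfig(bits=8, llm_int8_threshold=6.0)
bnb_8bit = qconfig_8bit.to_bitsandbytes()
print(bnb_8bit.load_in_8bit, bnb_8bit.llm_int8_threshold)  # True 6.0
```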
2 changes: 2 additions & 0 deletions ludwig/schema/model_types/llm.py
@@ -15,6 +15,7 @@
from ludwig.schema.llms.generation import LLMGenerationConfig, LLMGenerationConfigField
from ludwig.schema.llms.peft import AdapterDataclassField, BaseAdapterConfig
from ludwig.schema.llms.prompt import PromptConfig, PromptConfigField
from ludwig.schema.llms.quantization import QuantizationConfig, QuantizationConfigField
from ludwig.schema.model_types.base import ModelConfig, register_model_type
from ludwig.schema.preprocessing import PreprocessingConfig, PreprocessingField
from ludwig.schema.trainer import LLMTrainerConfig, LLMTrainerDataclassField
@@ -48,3 +49,4 @@ class LLMModelConfig(ModelConfig):
    generation: LLMGenerationConfig = LLMGenerationConfigField().get_default_field()

    adapter: Optional[BaseAdapterConfig] = AdapterDataclassField()
    quantization: Optional[QuantizationConfig] = QuantizationConfigField().get_default_field()
2 changes: 1 addition & 1 deletion requirements_llm.txt
@@ -4,4 +4,4 @@ faiss-cpu
accelerate
loralib
bitsandbytes
peft
peft>=0.4.0
18 changes: 13 additions & 5 deletions tests/integration_tests/test_llm.py
@@ -294,9 +294,9 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_
    ],
)
@pytest.mark.parametrize(
    "finetune_strategy,adapter_args",
    "finetune_strategy,adapter_args,quantization",
    [
        (None, {}),
        (None, {}, None),
        # (
        #     "prompt_tuning",
        #     {
@@ -315,9 +315,10 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_
        # ("prefix_tuning", {"num_virtual_tokens": 8}),
        # ("p_tuning", {"num_virtual_tokens": 8, "encoder_reparameterization_type": "MLP"}),
        # ("p_tuning", {"num_virtual_tokens": 8, "encoder_reparameterization_type": "LSTM"}),
        ("lora", {}),
        ("lora", {}, None),
        ("lora", {}, {"bits": 4}),  # qlora
        # ("adalora", {}),
        ("adaption_prompt", {"adapter_len": 6, "adapter_layers": 1}),
        ("adaption_prompt", {"adapter_len": 6, "adapter_layers": 1}, None),
    ],
    ids=[
        "none",
Expand All @@ -327,11 +328,15 @@ def test_llm_few_shot_classification(tmpdir, backend, csv_filename, ray_cluster_
# "p_tuning_mlp_reparameterization",
# "p_tuning_lstm_reparameterization",
"lora",
"qlora",
# "adalora",
"adaption_prompt",
],
)
def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strategy, adapter_args):
def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strategy, adapter_args, quantization):
if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
pytest.skip("Skip: quantization requires GPU and none are available.")

input_features = [text_feature(name="input", encoder={"type": "passthrough"})]
output_features = [text_feature(name="output")]

@@ -364,6 +369,9 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat
            **adapter_args,
        }

    if quantization is not None:
        config["quantization"] = quantization

    model = LudwigModel(config)
    model.train(dataset=df, output_directory=str(tmpdir), skip_save_processed_input=False)

27 changes: 27 additions & 0 deletions tests/ludwig/schema/test_model_config.py
@@ -1,5 +1,6 @@
import os
from tempfile import TemporaryDirectory
from typing import Optional

import pytest
import yaml
@@ -42,6 +43,7 @@
from ludwig.schema.features.image_feature import AUGMENTATION_DEFAULT_OPERATIONS
from ludwig.schema.features.number_feature import NumberOutputFeatureConfig
from ludwig.schema.features.text_feature import TextOutputFeatureConfig
from ludwig.schema.llms.quantization import QuantizationConfig
from ludwig.schema.model_config import ModelConfig
from ludwig.schema.utils import BaseMarshmallowConfig, convert_submodules

@@ -828,3 +830,28 @@ def test_llm_base_model_config_error(base_model_config):

    with pytest.raises(ConfigValidationError):
        ModelConfig.from_dict(config)


@pytest.mark.parametrize(
    "bits,expected_qconfig",
    [
        (None, None),
        (4, QuantizationConfig(bits=4)),
        (8, QuantizationConfig(bits=8)),
    ],
)
def test_llm_quantization_config(bits: Optional[int], expected_qconfig: Optional[QuantizationConfig]):
    config = {
        MODEL_TYPE: MODEL_LLM,
        BASE_MODEL: "bigscience/bloomz-3b",
        "quantization": {"bits": bits},
        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
    }

    if bits is None:
        del config["quantization"]

    config_obj = ModelConfig.from_dict(config)

    assert config_obj.quantization == expected_qconfig
