Add QLoRA for 4-bit fine-tuning (#3476)
Showing 8 changed files with 230 additions and 7 deletions.
@@ -0,0 +1,22 @@
# Llama2-7b 4-bit Fine-Tuning (QLoRA)

This example shows how to fine-tune [Llama2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) to follow instructions.
Instruction tuning is the first step in adapting a general-purpose Large Language Model into a chatbot.

This example uses no distributed training or big-data functionality. It is designed to run locally on any machine
with an available GPU.

## Prerequisites

- [HuggingFace API Token](https://huggingface.co/docs/hub/security-tokens)
- Access approval to [Llama2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- A GPU with at least 12 GiB of VRAM (in our tests, we used an Nvidia T4)

## Running the example

Set your token as an environment variable in the terminal, then run the training script:

```bash
export HUGGING_FACE_HUB_TOKEN="<api_token>"
python train_alpaca.py
```
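Before launching, it can help to confirm that a CUDA device is actually visible to PyTorch. This is a minimal sanity check, not part of the example itself; `torch` is already installed as a Ludwig dependency:

```python
# Optional sanity check (not part of the example): confirm a CUDA GPU is visible
import torch

assert torch.cuda.is_available(), "This example expects a CUDA-capable GPU (e.g., an Nvidia T4)"
print(torch.cuda.get_device_name(0))
```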
@@ -0,0 +1,58 @@
```python
import logging
import os

import yaml

from ludwig.api import LudwigModel

config = yaml.safe_load(
    """
model_type: llm
base_model: meta-llama/Llama-2-7b-hf
quantization:
  bits: 4
adapter:
  type: lora
input_features:
  - name: instruction
    type: text
output_features:
  - name: output
    type: text
trainer:
  type: finetune
  learning_rate: 0.0003
  batch_size: 2
  gradient_accumulation_steps: 8
  epochs: 3
  learning_rate_scheduler:
    warmup_fraction: 0.01
backend:
  type: local
"""
)

# Define the Ludwig model object that drives model training
model = LudwigModel(config=config, logging_level=logging.INFO)

# Initiate model training
(
    train_stats,  # dictionary containing training statistics
    preprocessed_data,  # tuple of Ludwig Dataset objects holding the preprocessed training data
    output_directory,  # location of training results stored on disk
) = model.train(
    dataset="ludwig://alpaca",
    experiment_name="alpaca_instruct_4bit",
    model_name="llama2_7b",
)

# List the contents of the output directory
print("contents of output directory:", output_directory)
for item in os.listdir(output_directory):
    print("\t", item)
```
@@ -0,0 +1,89 @@
```python
import warnings

from transformers import BitsAndBytesConfig

from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.utils import ludwig_dataclass

warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="bitsandbytes.cuda_setup.main",
)


@DeveloperAPI
@ludwig_dataclass
class QuantizationConfig(schema_utils.BaseMarshmallowConfig):
    bits: int = schema_utils.IntegerOptions(
        options=[4, 8],
        default=4,
        description="The quantization level to apply to weights on load.",
    )

    llm_int8_threshold: float = schema_utils.NonNegativeFloat(
        default=6.0,
        description=(
            "This corresponds to the outlier threshold for outlier detection as described in the `LLM.int8(): 8-bit "
            "Matrix Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339. Any hidden "
            "state value above this threshold is considered an outlier, and the operation on those values is "
            "done in fp16. Values are usually normally distributed, that is, most values are in the "
            "range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently "
            "distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 "
            "quantization works well for values of magnitude ~5, but beyond that, there is a significant performance "
            "penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models "
            "(small models, fine-tuning)."
        ),
    )

    llm_int8_has_fp16_weight: bool = schema_utils.Boolean(
        default=False,
        description=(
            "This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning, as the weights do "
            "not have to be converted back and forth for the backward pass."
        ),
    )

    bnb_4bit_compute_dtype: str = schema_utils.StringOptions(
        options=["float32", "float16", "bfloat16"],
        default="float16",
        description=(
            "This sets the computational type, which might differ from the input type. For example, inputs "
            "might be fp32, but computation can be set to bf16 for speedups."
        ),
    )

    bnb_4bit_use_double_quant: bool = schema_utils.Boolean(
        default=True,
        description=(
            "This flag enables nested quantization, where the quantization constants from the first quantization "
            "are quantized again."
        ),
    )

    bnb_4bit_quant_type: str = schema_utils.StringOptions(
        options=["fp4", "nf4"],
        default="nf4",
        description="This sets the quantization data type in the bnb.nn.Linear4Bit layers.",
    )

    def to_bitsandbytes(self) -> BitsAndBytesConfig:
        return BitsAndBytesConfig(
            load_in_4bit=self.bits == 4,
            load_in_8bit=self.bits == 8,
            llm_int8_threshold=self.llm_int8_threshold,
            llm_int8_has_fp16_weight=self.llm_int8_has_fp16_weight,
            bnb_4bit_compute_dtype=self.bnb_4bit_compute_dtype,
            bnb_4bit_use_double_quant=self.bnb_4bit_use_double_quant,
            bnb_4bit_quant_type=self.bnb_4bit_quant_type,
        )


@DeveloperAPI
class QuantizationConfigField(schema_utils.DictMarshmallowField):
    def __init__(self):
        super().__init__(QuantizationConfig, default_missing=True)

    def _jsonschema_type_mapping(self):
        return schema_utils.unload_jsonschema_from_marshmallow_class(QuantizationConfig)
```
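For reference, calling `to_bitsandbytes()` with the defaults above (`bits=4`) yields the same settings as constructing a `BitsAndBytesConfig` by hand in `transformers`. A minimal sketch of the equivalent direct usage, assuming a CUDA machine with access approval for the base model:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Equivalent of QuantizationConfig().to_bitsandbytes() with the defaults above (bits=4)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the base model with 4-bit (NF4, double-quantized) weights
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
```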
```diff
@@ -4,4 +4,4 @@ faiss-cpu
 accelerate
 loralib
 bitsandbytes
-peft
+peft>=0.4.0
```
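The `peft` floor is raised presumably because 4-bit (QLoRA-style) fine-tuning support landed in peft 0.4.0. A hypothetical runtime guard for this requirement:

```python
# Hypothetical guard: fail fast if the installed peft predates 4-bit support
from importlib.metadata import version

from packaging.version import Version

assert Version(version("peft")) >= Version("0.4.0"), "peft>=0.4.0 is required for 4-bit (QLoRA) fine-tuning"
```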