Merge remote-tracking branch 'upstream/main' into kto-trainer
kashif committed Jan 11, 2024
2 parents 1e1a5bb + baf3c1c commit 85b328f
Showing 15 changed files with 430 additions and 110 deletions.
64 changes: 62 additions & 2 deletions docs/source/dpo_trainer.mdx
@@ -101,7 +101,7 @@ While training and evaluating we record the following reward metrics:
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards (a short sketch of both metrics follows)
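
For intuition, these two metrics can be computed from the implicit chosen/rejected rewards roughly as follows. This is a minimal sketch with placeholder values, not the trainer's internal code:

```python
import torch

# Placeholder implicit rewards for three chosen/rejected pairs.
chosen_rewards = torch.tensor([0.8, 0.3, 1.2])
rejected_rewards = torch.tensor([0.1, 0.5, 0.4])

# rewards/accuracies: fraction of pairs where the chosen reward beats the rejected one.
accuracies = (chosen_rewards > rejected_rewards).float().mean()

# rewards/margins: mean gap between chosen and rejected rewards.
margins = (chosen_rewards - rejected_rewards).mean()

print(f"rewards/accuracies={accuracies:.3f}, rewards/margins={margins:.3f}")
```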

### Accelerate DPO fine-tuning using `unsloth`
## Accelerate DPO fine-tuning using `unsloth`

You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full fine-tuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library, which is compatible with `DPOTrainer`. Currently `unsloth` supports only the Llama (including Yi and TinyLlama) and Mistral architectures.
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, incorporating unsloth into your workflow is simple: instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows:
@@ -156,6 +156,66 @@ dpo_trainer.train()

The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).

## Reference model considerations with PEFT

Assuming the model you would like to further enhance with DPO was tuned with (Q)LoRA, you have three main options (plus several variants) for how the reference model is handled when using PEFT:

1. Simply create two instances of the model, each loading your adapter. This works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case `DPOTrainer` will unload the adapter for reference inference. This is efficient, but has potential downsides discussed below; a minimal sketch of this approach follows the list.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter. This is slightly less efficient than option 2 (roughly the adapter's size in extra VRAM), but avoids the pitfalls.
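
As a rough guide, a minimal sketch of option 2 could look like the following. The model and adapter paths here are placeholders, and the remaining `DPOTrainer` arguments (args, tokenizer, datasets) are elided as in the full example further below:

```python
import torch
from peft import LoraConfig, PeftModel
from transformers import AutoModelForCausalLM
from trl import DPOTrainer

# Merge the existing (Q)LoRA adapter into the base model.
base_model = AutoModelForCausalLM.from_pretrained("path/to/base-model", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, "path/to/peft-adapter").merge_and_unload()

# Train a fresh adapter on top of the merged model via peft_config.
peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,  # the trainer disables the adapter to recover the reference model
    peft_config=peft_config,
    # args=..., tokenizer=..., train_dataset=..., eval_dataset=... as usual
)
```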

### Downsides to merging QLoRA before DPO (approach 2)

As suggested by [Tim Dettmers](https://twitter.com/Tim_Dettmers/status/1694654191325573456), the best option for merging QLoRA adapters is to first quantize the base model, merge the adapter, then convert back to bf16, similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).

You can also merge the adapters the standard way without quantizing the base model, but this typically costs 1-2% in performance (and evidently causes more issues with empty responses).

If you use the recommended approach and quantize the model, then to use QLoRA for DPO you will need to either re-quantize the merged model or use an unquantized merge, with lower overall performance.
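
For illustration, re-quantizing a merged checkpoint before further QLoRA-based DPO might look roughly like this (the merged-model path is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Re-quantize the merged bf16 checkpoint to 4-bit NF4 before attaching a new adapter.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "path/to/merged-model",  # placeholder: the merged, dequantized checkpoint
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```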

### Using option 3 - load the adapter twice

To avoid the downsides of option 2, at the expense of slightly increased VRAM, you can load your fine-tuned adapter into the model twice under different names, and set the model/ref adapter names in `DPOTrainer`.

For example:
```python
# Imports needed for this example.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOTrainer

# Load the base model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1",
    load_in_4bit=True,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False

# Load the adapter.
model = PeftModel.from_pretrained(
    model,
    "/path/to/peft",
    is_trainable=True,
    adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")

# Initialize the trainer, without a ref_model param.
dpo_trainer = DPOTrainer(
    model,
    ...
    model_adapter_name="train",
    ref_adapter_name="reference",
)
```

## DPOTrainer

[[autodoc]] DPOTrainer
[[autodoc]] DPOTrainer
34 changes: 17 additions & 17 deletions docs/source/ppo_trainer.mdx
@@ -115,22 +115,22 @@ We can then loop over all examples in the dataset and generate a response for ea

```py
from tqdm import tqdm

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from SFTModel
    response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute reward score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = reward_model(texts)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]
        #### Get response from SFTModel
        response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
        #### Compute reward score
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        pipe_outputs = reward_model(texts)
        rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
ppo_trainer.save_model("my_ppo_model")
@@ -148,4 +148,4 @@ While training and evaluating we log the following metrics:

[[autodoc]] PPOTrainer

[[autodoc]] PPOConfig
[[autodoc]] PPOConfig
2 changes: 1 addition & 1 deletion setup.py
@@ -57,7 +57,7 @@
from setuptools import find_packages, setup


__version__ = "0.7.8.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.7.10.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

REQUIRED_PKGS = [
"torch>=1.4.0",
29 changes: 27 additions & 2 deletions tests/test_data_collator_completion_only.py
@@ -31,11 +31,14 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
        self.instruction_template = "\n### User:"
        self.response_template = "\n### Assistant:"

        # GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25]
        # GPT2Tokenizer: [198, 21017, 11787, 25] -> [21017, 11787, 25]
        # Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901]
        # Note: If this test is ever switched to Llama2Tokenizer, this should be double checked,
        # and possibly switched back to [2:] instead of [1:].
        # With GPT2Tokenizer, [1:] is correct - we want the 21017 token included, which is ###.
        self.tokenized_instruction_w_context = self.tokenizer.encode(
            self.instruction_template, add_special_tokens=False
        )[2:]
        )[1:]

        # GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25]
        # Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901]
@@ -57,6 +60,28 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
        )
        self.collator.torch_call([self.tokenized_instruction])

        # Test for PR #1185
        # We pass in a string where the first user template is different than the rest.
        # Usually this would happen due to context-sensitive tokenization, but here we
        # explicitly change the template to test the fix.
        self.instruction = """## User: First instruction

### Assistant: First response

### User: Second instruction

### Assistant: Second response"""
        self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
        self.collator = DataCollatorForCompletionOnlyLM(
            self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer
        )
        collator_output = self.collator.torch_call([self.tokenized_instruction])
        collator_text = self.tokenizer.decode(
            collator_output["labels"][torch.where(collator_output["labels"] != -100)]
        )
        expected_text = " First response\n\n Second response"
        self.assertEqual(collator_text, expected_text)

    def test_data_collator_handling_of_long_sequences(self):
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
        self.instruction = """### System: You are a helpful assistant.
125 changes: 124 additions & 1 deletion tests/test_dpo_trainer.py
@@ -22,7 +22,7 @@

from trl import DPOTrainer

from .testing_utils import require_no_wandb, require_peft
from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft


class DPOTrainerTester(unittest.TestCase):
@@ -313,3 +313,126 @@ def test_dpo_lora_save(self):
                AutoModelForCausalLM.from_pretrained(tmp_dir)
            except OSError:
                self.fail("Loading the saved peft adapter failed")

    @require_peft
    @require_bitsandbytes
    @mark.peft_test
    def test_dpo_lora_bf16_autocast_llama(self):
        # Note this test only works on compute capability > 7 GPU devices
        from peft import LoraConfig

        model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        # lora model
        model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=4,
                learning_rate=9e-1,
                evaluation_strategy="steps",
                bf16=True,
            )

            dummy_dataset = self._init_dummy_dataset()

            # dpo train lora model with a lora config
            trainer = DPOTrainer(
                model=model,
                ref_model=None,
                beta=0.1,
                args=training_args,
                tokenizer=tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
                peft_config=lora_config,
                generate_during_eval=True,
            )

            # train the model
            trainer.train()

            # save peft adapter
            trainer.save_model()

    @parameterized.expand(
        [
            ["gpt2", "sigmoid", False, False],
            ["gpt2", "sigmoid", False, True],
            ["gpt2", "sigmoid", True, False],
            ["gpt2", "sigmoid", True, True],
            ["gpt2", "ipo", False, False],
            ["gpt2", "ipo", False, True],
            ["gpt2", "ipo", True, False],
            ["gpt2", "ipo", True, True],
            ["gpt2", "kto_pair", False, False],
            ["gpt2", "kto_pair", False, True],
            ["gpt2", "kto_pair", True, False],
            ["gpt2", "kto_pair", True, True],
        ]
    )
    @require_bitsandbytes
    @require_peft
    @mark.peft_test
    def test_dpo_lora_bf16_autocast(self, name, loss_type, pre_compute, gen_during_eval):
        # Note this test only works on compute capability > 7 GPU devices
        from peft import LoraConfig

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        # lora model
        model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=4,
                learning_rate=9e-1,
                evaluation_strategy="steps",
                bf16=True,
            )

            dummy_dataset = self._init_dummy_dataset()

            # dpo train lora model with a lora config
            trainer = DPOTrainer(
                model=model,
                ref_model=None,
                beta=0.1,
                args=training_args,
                tokenizer=self.tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
                peft_config=lora_config,
                generate_during_eval=gen_during_eval,
                loss_type=loss_type,
                precompute_ref_log_probs=pre_compute,
            )

            # train the model
            trainer.train()

            # save peft adapter
            trainer.save_model()
2 changes: 1 addition & 1 deletion tests/test_ppo_trainer.py
@@ -579,7 +579,7 @@ def test_loss_trainer(self):
        logits = torch.exp(all_logprobs)
        vpreds = values + 0.1

        score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
        score, non_score, kls = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
        values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask)

        # just make sure a dummy loss is computed
28 changes: 16 additions & 12 deletions tests/testing_utils.py
@@ -15,7 +15,13 @@

import torch

from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
from trl import (
    is_bitsandbytes_available,
    is_diffusers_available,
    is_peft_available,
    is_wandb_available,
    is_xpu_available,
)


def require_peft(test_case):
@@ -27,6 +33,15 @@ def require_peft(test_case):
    return test_case


def require_bitsandbytes(test_case):
    """
    Decorator marking a test that requires bnb. Skips the test if bnb is not available.
    """
    if not is_bitsandbytes_available():
        test_case = unittest.skip("test requires bnb")(test_case)
    return test_case


def require_diffusers(test_case):
"""
Decorator marking a test that requires diffusers. Skips the test if diffusers is not available.
@@ -55,17 +70,6 @@ def require_no_wandb(test_case):
    return require_wandb(test_case, required=False)


def require_bitsandbytes(test_case):
    """
    Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available.
    """
    try:
        import bitsandbytes  # noqa: F401
    except ImportError:
        test_case = unittest.skip("test requires bitsandbytes")(test_case)
    return test_case


def require_torch_multi_gpu(test_case):
"""
Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs.
3 changes: 2 additions & 1 deletion trl/__init__.py
@@ -1,11 +1,12 @@
# flake8: noqa

__version__ = "0.7.8.dev0"
__version__ = "0.7.10.dev0"

from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
    is_bitsandbytes_available,
    is_diffusers_available,
    is_npu_available,
    is_peft_available,
5 changes: 4 additions & 1 deletion trl/import_utils.py
@@ -63,7 +63,10 @@ def is_diffusers_available() -> bool:


def is_bitsandbytes_available() -> bool:
    return importlib.util.find_spec("bitsandbytes") is not None
    import torch

    # bnb can be imported without GPU but is not usable.
    return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available()


def is_torchvision_available() -> bool:
