From d57d0f9ca46a63d370b91791352edda0154576f5 Mon Sep 17 00:00:00 2001 From: maneandrea <38288005+maneandrea@users.noreply.github.com> Date: Sun, 7 Jan 2024 21:43:34 -0700 Subject: [PATCH 01/17] Address issue #1122 (#1174) * Address issue #1122 Issue [#1122](https://github.com/huggingface/trl/issues/1122) takes care of an inconsistency between `_prepare_packed_dataloader` and `_prepare_non_packed_dataloader` * made attention_mask field in ConstantLengthDataset a tensor --- trl/trainer/sft_trainer.py | 6 +++++- trl/trainer/utils.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index b46d4d3d76..4061dfe88b 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -402,7 +402,11 @@ def tokenize(element): else: self._dataset_sanity_checked = True - return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + return { + "input_ids": outputs["input_ids"], + "labels": outputs["input_ids"], + "attention_mask": outputs["attention_mask"], + } tokenized_dataset = dataset.map( tokenize, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index ca09318289..e7ac8225df 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -452,6 +452,7 @@ def __iter__(self): yield { "input_ids": torch.LongTensor(example), "labels": torch.LongTensor(example), + "attention_mask": torch.ones(len(example)), } From ad597dbcb39e1f8531de46af098a2313a3a70a5d Mon Sep 17 00:00:00 2001 From: Jfhseh <67591670+Jfhseh@users.noreply.github.com> Date: Sun, 7 Jan 2024 23:50:00 -0500 Subject: [PATCH 02/17] Fix misleading variable "epoch" from the training loop from PPOTrainer Doc. (#1171) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix misleading variable "epoch" from PPOTrainer Doc. The usage of the variable “epoch” is misleading in the original Doc, the dataloader does not contain the data for ALL epochs, but 1 only, thus "for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader))" is misleading and does not actually stores the epoch #. The correct version comes from the TRL PPO notebook tutorial (https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment-control.ipynb), which uses an outer loop to capture the epochs. I posted also the question on forum: https://discuss.huggingface.co/t/confusing-and-possibly-misleading-ppo-trainer-code-from-trl-api-doc-tutorial/67531 * Remove batch_id --- docs/source/ppo_trainer.mdx | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/source/ppo_trainer.mdx b/docs/source/ppo_trainer.mdx index 0c86f3b912..14484d14e6 100644 --- a/docs/source/ppo_trainer.mdx +++ b/docs/source/ppo_trainer.mdx @@ -115,22 +115,22 @@ We can then loop over all examples in the dataset and generate a response for ea ```py from tqdm import tqdm - -for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): - query_tensors = batch["input_ids"] - - #### Get response from SFTModel - response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs) - batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] - - #### Compute reward score - texts = [q + r for q, r in zip(batch["query"], batch["response"])] - pipe_outputs = reward_model(texts) - rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] - - #### Run PPO step - stats = ppo_trainer.step(query_tensors, response_tensors, rewards) - ppo_trainer.log_stats(stats, batch, rewards) +for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "): + for batch in tqdm(ppo_trainer.dataloader): + query_tensors = batch["input_ids"] + + #### Get response from SFTModel + response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs) + batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + #### Compute reward score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_model(texts) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + + #### Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) #### Save model ppo_trainer.save_model("my_ppo_model") @@ -148,4 +148,4 @@ While training and evaluating we log the following metrics: [[autodoc]] PPOTrainer -[[autodoc]] PPOConfig \ No newline at end of file +[[autodoc]] PPOConfig From 104a02d207b63a4a062882aaff68f2d275493399 Mon Sep 17 00:00:00 2001 From: mgerstgrasser Date: Sun, 7 Jan 2024 21:09:10 -0800 Subject: [PATCH 03/17] SFTTrainer: follow args.remove_unused_columns (#1188) --- trl/trainer/sft_trainer.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 4061dfe88b..8eed33b241 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -258,6 +258,7 @@ def make_inputs_require_grad(module, input, output): formatting_func, num_of_sequences, chars_per_token, + remove_unused_columns=args.remove_unused_columns if args is not None else True, **dataset_kwargs, ) if eval_dataset is not None: @@ -273,6 +274,7 @@ def make_inputs_require_grad(module, input, output): formatting_func, num_of_sequences, chars_per_token, + remove_unused_columns=args.remove_unused_columns if args is not None else True, **dataset_kwargs, ) if not _multiple: @@ -348,6 +350,7 @@ def _prepare_dataset( formatting_func, num_of_sequences, chars_per_token, + remove_unused_columns=True, append_concat_token=True, add_special_tokens=True, ): @@ -360,7 +363,13 @@ def _prepare_dataset( if not packing: return self._prepare_non_packed_dataloader( - tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func, add_special_tokens + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + formatting_func, + add_special_tokens, + remove_unused_columns, ) else: @@ -377,7 +386,14 @@ def _prepare_dataset( ) def _prepare_non_packed_dataloader( - self, tokenizer, dataset, dataset_text_field, max_seq_length, formatting_func=None, add_special_tokens=True + self, + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + formatting_func=None, + add_special_tokens=True, + remove_unused_columns=True, ): use_formatting_func = formatting_func is not None and dataset_text_field is None self._dataset_sanity_checked = False @@ -411,7 +427,7 @@ def tokenize(element): tokenized_dataset = dataset.map( tokenize, batched=True, - remove_columns=dataset.column_names, + remove_columns=dataset.column_names if remove_unused_columns else None, num_proc=self.dataset_num_proc, batch_size=self.dataset_batch_size, ) From d5910b0ff50e60fd89b278c36eef6d07485bcdf0 Mon Sep 17 00:00:00 2001 From: Pablo Vicente Date: Mon, 8 Jan 2024 09:15:53 +0100 Subject: [PATCH 04/17] Handle last token from generation prompt (#1153) * Handle last token from generation prompt * Remove prints * Reformat dpo_trainer file --- trl/trainer/dpo_trainer.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 8e7555d86d..ba3bb3a32e 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -593,6 +593,29 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) raise ValueError(f"rejected should be an str but got {type(rejected)}") rejected_tokens = self.build_tokenized_answer(prompt, rejected) + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])] + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + # add BOS token to head of prompt prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"] chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"] From dbcb2f00217a24f23d13bedf1ca495ea6af672a7 Mon Sep 17 00:00:00 2001 From: Jon Durbin Date: Mon, 8 Jan 2024 04:26:40 -0500 Subject: [PATCH 05/17] Allow separate devices for target/ref models. (#1190) * Allow separate devices for target/ref models. * Remove original/duplicate. * Cleanup original, black formatting. --------- Co-authored-by: Jon Durbin --- trl/trainer/dpo_trainer.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index ba3bb3a32e..6ab8fba5a4 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -818,6 +818,8 @@ def dpo_loss( else: ref_logratios = reference_chosen_logps - reference_rejected_logps + pi_logratios = pi_logratios.to(self.accelerator.device) + ref_logratios = ref_logratios.to(self.accelerator.device) logits = pi_logratios - ref_logratios # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. @@ -853,8 +855,19 @@ def dpo_loss( f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']" ) - chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() - rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + chosen_rewards = ( + self.beta + * ( + policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device) + ).detach() + ) + rejected_rewards = ( + self.beta + * ( + policy_rejected_logps.to(self.accelerator.device) + - reference_rejected_logps.to(self.accelerator.device) + ).detach() + ) return losses, chosen_rewards, rejected_rewards From 3267be0fcd424c8d224d7f28a45a64358c857ed0 Mon Sep 17 00:00:00 2001 From: Jon Durbin Date: Mon, 8 Jan 2024 10:12:45 -0500 Subject: [PATCH 06/17] Allow swapping PEFT adapters for target/ref model. (#1193) * Allow swapping PEFT adapters for target/ref model. * Update DPOTrainer docs. * python format * isort * Update docs/source/dpo_trainer.mdx Co-authored-by: Kashif Rasul * Update docs/source/dpo_trainer.mdx Co-authored-by: Kashif Rasul * Update docs/source/dpo_trainer.mdx Co-authored-by: Kashif Rasul * Update docs/source/dpo_trainer.mdx Co-authored-by: Kashif Rasul * Update docs/source/dpo_trainer.mdx Co-authored-by: Kashif Rasul --------- Co-authored-by: Kashif Rasul --- docs/source/dpo_trainer.mdx | 64 +++++++++++++++++++++++++++++++++++-- trl/trainer/dpo_trainer.py | 30 +++++++++++++---- 2 files changed, 86 insertions(+), 8 deletions(-) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index e19fb1915b..422361c826 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -101,7 +101,7 @@ While training and evaluating we record the following reward metrics: * `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards * `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards -### Accelerate DPO fine-tuning using `unsloth` +## Accelerate DPO fine-tuning using `unsloth` You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `DPOTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures. First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows: @@ -156,6 +156,66 @@ dpo_trainer.train() The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). +## Reference model considerations with PEFT + +You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA. + +1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient. +2. Merge the adapter into the base model, create another adapter on top, then leave the `model_ref` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below. +3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls. + +### Downsides to merging QLoRA before DPO (approach 2) + +As suggested by [Tim Dettmers](https://twitter.com/Tim_Dettmers/status/1694654191325573456), the best option for merging QLoRA adapters is to first quantize the base model, merge the adapter, then convert back to bf16. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py) + +You can also just merge the adapters the standard way without quantizing the base model, but then you have 1-2% reduced performance (and evidently, more issues with empty responses). + +If you use the recommended approach, which quantizes the model, you're now in a situation where to use QLoRA for DPO, you will need to re-quantize the merged model again or use an unquantized merge with lower overall performance. + +### Using option 3 - load the adapter twice + +To avoid the downsides with option 2, at the expense of slightly increased VRAM, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in DPOTrainer. + +For example: +```python +# Load the base model. +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", +) +model = AutoModelForCausalLM.from_pretrained( + "mistralai/mixtral-8x7b-v0.1", + load_in_4bit=True, + quantization_config=bnb_config, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + device_map="auto", +) +model.config.use_cache = False + +# Load the adapter. +model = PeftModel.from_pretrained( + model, + "/path/to/peft", + is_trainable=True, + adapter_name="train", +) +# Load the adapter a second time, with a different name, which will be our reference model. +model.load_adapter("/path/to/peft", adapter_name="reference") + +# Initialize the trainer, without a ref_model param. +dpo_trainer = DPOTrainer( + model, + ... + model_adapter_name="train", + ref_adapter_name="reference", +) +``` + ## DPOTrainer -[[autodoc]] DPOTrainer \ No newline at end of file +[[autodoc]] DPOTrainer diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 6ab8fba5a4..aa1e5efa35 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -16,7 +16,7 @@ import random import warnings from collections import defaultdict -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from copy import deepcopy from functools import wraps from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -126,6 +126,10 @@ class DPOTrainer(Trainer): Dict of Optional kwargs to pass when instantiating the model from a string ref_model_init_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when instantiating the ref model from a string + model_adapter_name (`str`, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. """ _tag_names = ["trl", "dpo"] @@ -160,6 +164,8 @@ def __init__( precompute_ref_log_probs: bool = False, model_init_kwargs: Optional[Dict] = None, ref_model_init_kwargs: Optional[Dict] = None, + model_adapter_name: str = None, + ref_adapter_name: str = None, ): if model_init_kwargs is None: model_init_kwargs = {} @@ -253,6 +259,8 @@ def make_inputs_require_grad(module, input, output): self.is_encoder_decoder = is_encoder_decoder self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name if ref_model: self.ref_model = ref_model @@ -704,14 +712,24 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) return batch + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with self.accelerator.unwrap_model( + self.model + ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" # compute reference logps with torch.no_grad(): if self.ref_model is None: - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model else nullcontext(): + with self.null_ref_context(): ( reference_chosen_logps, reference_rejected_logps, @@ -976,7 +994,7 @@ def get_batch_loss_metrics( else: with torch.no_grad(): if self.ref_model is None: - with self.accelerator.unwrap_model(self.model).disable_adapter(): + with self.null_ref_context(): ( reference_chosen_logps, reference_rejected_logps, @@ -1048,7 +1066,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[ reference_output = batch["reference_output"] else: if self.ref_model is None: - with self.accelerator.unwrap_model(self.model).disable_adapter(): + with self.null_ref_context(): reference_output = self.model.generate( input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], From 384b868fe65b0b92342dc9a514dbca25bb1678a7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Jan 2024 05:13:26 +0100 Subject: [PATCH 07/17] Release: v0.7.8 (#1200) --- setup.py | 2 +- trl/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 0f6394768b..4d26590e88 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ from setuptools import find_packages, setup -__version__ = "0.7.8.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.7.8" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "torch>=1.4.0", diff --git a/trl/__init__.py b/trl/__init__.py index 3331e64327..0445535a37 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa -__version__ = "0.7.8.dev0" +__version__ = "0.7.8" from .core import set_seed from .environment import TextEnvironment, TextHistory From b21ed0ddbc7bab5acd21c82376d5d1fe9b360239 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Jan 2024 05:19:10 +0100 Subject: [PATCH 08/17] set dev version (#1201) --- setup.py | 2 +- trl/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 4d26590e88..102701a774 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ from setuptools import find_packages, setup -__version__ = "0.7.8" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.7.9.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "torch>=1.4.0", diff --git a/trl/__init__.py b/trl/__init__.py index 0445535a37..4d2211d0ab 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa -__version__ = "0.7.8" +__version__ = "0.7.9.dev0" from .core import set_seed from .environment import TextEnvironment, TextHistory From 4ae35afdd65cec55321455eb37fc072f5d438a4a Mon Sep 17 00:00:00 2001 From: mgerstgrasser Date: Mon, 8 Jan 2024 21:41:53 -0800 Subject: [PATCH 09/17] Fix instruction token masking (#1185) * Fix instruction token masking Fix instruction token masking if the first instruction is tokenized differently than the others, or in general if no instruction is detected before the first response. * Bugfix for edge case (in case either of the templates isn't found at all, ...idxs[0] might not exist) * Add test for instruction masking fix --- tests/test_data_collator_completion_only.py | 29 +++++++++++++++++++-- trl/trainer/utils.py | 7 +++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/test_data_collator_completion_only.py b/tests/test_data_collator_completion_only.py index c895a616e1..544230fe26 100644 --- a/tests/test_data_collator_completion_only.py +++ b/tests/test_data_collator_completion_only.py @@ -31,11 +31,14 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self): self.instruction_template = "\n### User:" self.response_template = "\n### Assistant:" - # GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25] + # GPT2Tokenizer: [198, 21017, 11787, 25] -> [21017, 11787, 25] # Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901] + # Note: If this test is ever switched to Llama2Tokenizer, this should be double checked, + # and possibly switched back to [2:] instead of [1:]. + # With GPT2Tokenizer, [1:] is correct - we want the 21017 token included, which is ###. self.tokenized_instruction_w_context = self.tokenizer.encode( self.instruction_template, add_special_tokens=False - )[2:] + )[1:] # GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25] # Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901] @@ -57,6 +60,28 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self): ) self.collator.torch_call([self.tokenized_instruction]) + # Test for PR #1185 + # We pass in a string where the first user template is different than the rest. + # Usually this would happen due to context-sensitive tokenization, but here we + # explicitly change the template to test the fix. + self.instruction = """## User: First instruction + +### Assistant: First response + +### User: Second instruction + +### Assistant: Second response""" + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + self.collator = DataCollatorForCompletionOnlyLM( + self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer + ) + collator_output = self.collator.torch_call([self.tokenized_instruction]) + collator_text = self.tokenizer.decode( + collator_output["labels"][torch.where(collator_output["labels"] != -100)] + ) + expected_text = " First response\n\n Second response" "" + self.assertEqual(collator_text, expected_text) + def test_data_collator_handling_of_long_sequences(self): self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab") self.instruction = """### System: You are a helpful assistant. diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index e7ac8225df..0b2c2cb5e4 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -176,6 +176,13 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D ) batch["labels"][i, :] = self.ignore_index + if ( + len(human_token_ids_idxs) > 0 + and len(response_token_ids_idxs) > 0 + and human_token_ids_idxs[0] > response_token_ids_idxs[0] + ): + human_token_ids_idxs = [0] + human_token_ids_idxs + for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): # Make pytorch loss function ignore all non response tokens if idx != 0: From a236c5750f5b9f470191355726aed406f8bfff18 Mon Sep 17 00:00:00 2001 From: mgerstgrasser Date: Mon, 8 Jan 2024 21:48:25 -0800 Subject: [PATCH 10/17] Fix reported KL in PPO trainer (#1180) * Fix reported KL in PPO trainer previously this was always reporting the estimated KL, even when using `kl_penalty = 'full'` (or `abs`, etc). Now we return the actual KL calculated in `compute_rewards()`, and report that. * fix test --- tests/test_ppo_trainer.py | 2 +- trl/trainer/ppo_trainer.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 0af091fc3c..b5b31f512e 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -579,7 +579,7 @@ def test_loss_trainer(self): logits = torch.exp(all_logprobs) vpreds = values + 0.1 - score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask) + score, non_score, kls = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask) values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask) # just make sure a dummy loss is computed diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 3194ff292c..e64973c7be 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -733,11 +733,11 @@ def step( active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False) ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False) - rewards, non_score_reward = self.compute_rewards( + rewards, non_score_reward, kls = self.compute_rewards( scores, active_full_logprobs, ref_full_logprobs, masks ) else: - rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) + rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) timing["time/ppo/compute_rewards"] = time.time() - t t = time.time() @@ -831,6 +831,7 @@ def step( masks=masks, queries=queries, responses=responses, + kls=kls, ) # Gather/Reduce stats from all processes if self.is_distributed: @@ -1091,11 +1092,17 @@ def compute_rewards( Log probabilities of the model, shape (`batch_size`, `response_length`) ref_logprobs (`torch.FloatTensor`): Log probabilities of the reference model, shape (`batch_size`, `response_length`) + + Returns: + `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`) + `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`) + `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`) """ - rewards, non_score_rewards = [], [] + rewards, non_score_rewards, kls = [], [], [] for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks): # compute KL penalty (from difference in logprobs) kl = self._kl_penalty(logprob, ref_logprob) + kls.append(kl) non_score_reward = -self.kl_ctl.value * kl non_score_rewards.append(non_score_reward) reward = non_score_reward.clone() @@ -1104,7 +1111,7 @@ def compute_rewards( # reward is preference model score + KL penalty reward[last_non_masked_index] += score rewards.append(reward) - return torch.stack(rewards), torch.stack(non_score_rewards) + return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls) def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor: if self.config.kl_penalty == "kl": @@ -1256,7 +1263,8 @@ def record_step_stats(self, kl_coef: float, **data): """ mask = data.pop("masks") - kl_list = ((data["logprobs"] - data["ref_logprobs"]) * mask).sum(axis=-1) + kls = data.pop("kls") + kl_list = ((kls) * mask).sum(axis=-1) mean_kl = kl_list.mean() mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean() From d116887ed4321467844f57752f9d7dce1640ee5c Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Jan 2024 09:35:50 +0100 Subject: [PATCH 11/17] [`DPOTrainer`] Fix peft + DPO + bf16 if one uses `generate_during_eval` or pre-computed logits (#1203) * fix peft + DPO + bf16 * fix * revert old behaviour * fix tests * fix * fix * fix * fix --- tests/test_dpo_trainer.py | 125 ++++++++++++++++++++++++++++++++++++- tests/testing_utils.py | 28 +++++---- trl/__init__.py | 1 + trl/import_utils.py | 5 +- trl/trainer/dpo_trainer.py | 69 ++++++++++++-------- trl/trainer/utils.py | 4 +- 6 files changed, 191 insertions(+), 41 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index e60fd70220..4c0b1d1412 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -22,7 +22,7 @@ from trl import DPOTrainer -from .testing_utils import require_no_wandb, require_peft +from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft class DPOTrainerTester(unittest.TestCase): @@ -313,3 +313,126 @@ def test_dpo_lora_save(self): AutoModelForCausalLM.from_pretrained(tmp_dir) except OSError: self.fail("Loading the saved peft adapter failed") + + @require_peft + @require_bitsandbytes + @mark.peft_test + def test_dpo_lora_bf16_autocast_llama(self): + # Note this test only works on compute capability > 7 GPU devices + from peft import LoraConfig + + model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + bf16=True, + ) + + dummy_dataset = self._init_dummy_dataset() + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=lora_config, + generate_during_eval=True, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + @parameterized.expand( + [ + ["gpt2", "sigmoid", False, False], + ["gpt2", "sigmoid", False, True], + ["gpt2", "sigmoid", True, False], + ["gpt2", "sigmoid", True, True], + ["gpt2", "ipo", False, False], + ["gpt2", "ipo", False, True], + ["gpt2", "ipo", True, False], + ["gpt2", "ipo", True, True], + ["gpt2", "kto_pair", False, False], + ["gpt2", "kto_pair", False, True], + ["gpt2", "kto_pair", True, False], + ["gpt2", "kto_pair", True, True], + ] + ) + @require_bitsandbytes + @require_peft + @mark.peft_test + def test_dpo_lora_bf16_autocast(self, name, loss_type, pre_compute, gen_during_eval): + # Note this test only works on compute capability > 7 GPU devices + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + bf16=True, + ) + + dummy_dataset = self._init_dummy_dataset() + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + beta=0.1, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=lora_config, + generate_during_eval=gen_during_eval, + loss_type=loss_type, + precompute_ref_log_probs=pre_compute, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() diff --git a/tests/testing_utils.py b/tests/testing_utils.py index f3988de4c9..96e22568b5 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -15,7 +15,13 @@ import torch -from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available +from trl import ( + is_bitsandbytes_available, + is_diffusers_available, + is_peft_available, + is_wandb_available, + is_xpu_available, +) def require_peft(test_case): @@ -27,6 +33,15 @@ def require_peft(test_case): return test_case +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bnb. Skips the test if bnb is not available. + """ + if not is_bitsandbytes_available(): + test_case = unittest.skip("test requires bnb")(test_case) + return test_case + + def require_diffusers(test_case): """ Decorator marking a test that requires diffusers. Skips the test if diffusers is not available. @@ -55,17 +70,6 @@ def require_no_wandb(test_case): return require_wandb(test_case, required=False) -def require_bitsandbytes(test_case): - """ - Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available. - """ - try: - import bitsandbytes # noqa: F401 - except ImportError: - test_case = unittest.skip("test requires bitsandbytes")(test_case) - return test_case - - def require_torch_multi_gpu(test_case): """ Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs. diff --git a/trl/__init__.py b/trl/__init__.py index 4d2211d0ab..da7ac999c8 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -6,6 +6,7 @@ from .environment import TextEnvironment, TextHistory from .extras import BestOfNSampler from .import_utils import ( + is_bitsandbytes_available, is_diffusers_available, is_npu_available, is_peft_available, diff --git a/trl/import_utils.py b/trl/import_utils.py index 6d2388ec84..88a04f7d15 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -63,7 +63,10 @@ def is_diffusers_available() -> bool: def is_bitsandbytes_available() -> bool: - return importlib.util.find_spec("bitsandbytes") is not None + import torch + + # bnb can be imported without GPU but is not usable. + return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available() def is_torchvision_available() -> bool: diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index aa1e5efa35..2b2f51049c 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -193,6 +193,10 @@ def __init__( ) ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + if not is_peft_available() and peft_config is not None: raise ValueError( "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" @@ -230,6 +234,8 @@ def make_inputs_require_grad(module, input, output): model = get_peft_model(model, peft_config) if args.bf16 and getattr(model, "is_loaded_in_4bit", False): peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True # For models that use gradient_checkpoiting, we need to attach a hook that enables input # to explicitly have `requires_grad=True`, otherwise training will either silently @@ -726,8 +732,10 @@ def null_ref_context(self): def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" + compte_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + # compute reference logps - with torch.no_grad(): + with torch.no_grad(), compte_ref_context_manager(): if self.ref_model is None: with self.null_ref_context(): ( @@ -1040,7 +1048,11 @@ def compute_loss( "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" ) - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with compute_loss_context_manager(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") # force log the metrics if self.accelerator.is_main_process: @@ -1053,35 +1065,40 @@ def compute_loss( def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.tokenizer.pad_token_id, - ) + # If one uses `generate_during_eval` with peft + bf16, we need to explictly call generate with + # the torch cuda amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast + + with generate_context_manager(): + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) - # if reference_output in batch use that otherwise use the reference model - if "reference_output" in batch: - reference_output = batch["reference_output"] - else: - if self.ref_model is None: - with self.null_ref_context(): - reference_output = self.model.generate( + # if reference_output in batch use that otherwise use the reference model + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id, ) - else: - reference_output = self.ref_model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.tokenizer.pad_token_id, - ) policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id) policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) @@ -1109,7 +1126,9 @@ def prediction_step( else: ignore_keys = [] - with torch.no_grad(): + prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with torch.no_grad(), prediction_context_manager(): loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") # force log the metrics diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 0b2c2cb5e4..172b607b57 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -641,9 +641,9 @@ def peft_module_casting_to_bf16(model): for name, module in model.named_modules(): if isinstance(module, BaseTunerLayer): module = module.to(torch.bfloat16) - if "norm" in name: + elif isinstance(module, torch.nn.LayerNorm) or "norm" in name: module = module.to(torch.float32) - if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): + elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): if hasattr(module, "weight"): if module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) From d1715514de30108bd41b11de849d0017eb94b26d Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Jan 2024 10:20:50 +0100 Subject: [PATCH 12/17] Revert "Address issue #1122 (#1174)" (#1205) This reverts commit d57d0f9ca46a63d370b91791352edda0154576f5. --- trl/trainer/sft_trainer.py | 6 +----- trl/trainer/utils.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 8eed33b241..9c06b102ff 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -418,11 +418,7 @@ def tokenize(element): else: self._dataset_sanity_checked = True - return { - "input_ids": outputs["input_ids"], - "labels": outputs["input_ids"], - "attention_mask": outputs["attention_mask"], - } + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} tokenized_dataset = dataset.map( tokenize, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 172b607b57..5c646c3876 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -459,7 +459,6 @@ def __iter__(self): yield { "input_ids": torch.LongTensor(example), "labels": torch.LongTensor(example), - "attention_mask": torch.ones(len(example)), } From 7a95cc86966ec5e59a9347384ba572f93e931827 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:02:31 +0100 Subject: [PATCH 13/17] release: v0.7.9 (#1206) --- setup.py | 2 +- trl/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 102701a774..fb88d57204 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ from setuptools import find_packages, setup -__version__ = "0.7.9.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.7.9" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "torch>=1.4.0", diff --git a/trl/__init__.py b/trl/__init__.py index da7ac999c8..ce88f6216e 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa -__version__ = "0.7.9.dev0" +__version__ = "0.7.9" from .core import set_seed from .environment import TextEnvironment, TextHistory From d6cc88ab2cff35c00caab682807a39399bced3e4 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:06:30 +0100 Subject: [PATCH 14/17] set dev version (#1207) --- setup.py | 2 +- trl/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index fb88d57204..dbbfeb4db9 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ from setuptools import find_packages, setup -__version__ = "0.7.9" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +__version__ = "0.7.10.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ "torch>=1.4.0", diff --git a/trl/__init__.py b/trl/__init__.py index ce88f6216e..8a6b789be0 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa -__version__ = "0.7.9" +__version__ = "0.7.10.dev0" from .core import set_seed from .environment import TextEnvironment, TextHistory From 26da9e80cb667681337b544e2e42205165380611 Mon Sep 17 00:00:00 2001 From: Pablo Vicente Date: Tue, 9 Jan 2024 08:10:22 -0500 Subject: [PATCH 15/17] Check tokenize params on DPOTrainer (#1197) * Check if tokenizer and max len params are None * Update warning messages for missing parameters --- trl/trainer/dpo_trainer.py | 50 ++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 2b2f51049c..7509484887 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -276,34 +276,32 @@ def make_inputs_require_grad(module, input, output): else: self.ref_model = create_reference_model(model) - if data_collator is None: - if tokenizer is None: - raise ValueError( - "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" - ) - if max_length is None: - warnings.warn( - "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" - " it will be set to `512` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_length = 512 - if max_prompt_length is None: - warnings.warn( - "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_prompt_length = 128 + if tokenizer is None: + raise ValueError("tokenizer must be specified to tokenize a DPO dataset.") + if max_length is None: + warnings.warn( + "`max_length` is not set in the DPOTrainer's init" + " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the DPOTrainer's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 - if max_target_length is None and self.is_encoder_decoder: - warnings.warn( - "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_target_length = 128 + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + if data_collator is None: data_collator = DPODataCollatorWithPadding( pad_token_id=tokenizer.pad_token_id, label_pad_token_id=label_pad_token_id, From b181e401a73ee943382a9bdcdbc0a1000325b8e0 Mon Sep 17 00:00:00 2001 From: yuta <122957026+yuta0x89@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:24:41 +0900 Subject: [PATCH 16/17] Fix shape descriptions in calculate_loss method (#1204) --- trl/trainer/ddpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py index f91421487c..9fab11120f 100644 --- a/trl/trainer/ddpo_trainer.py +++ b/trl/trainer/ddpo_trainer.py @@ -343,11 +343,11 @@ def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages Args: latents (torch.Tensor): - The latents sampled from the diffusion model, shape: [batch_size, num_steps, ...] + The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] timesteps (torch.Tensor): The timesteps sampled from the diffusion model, shape: [batch_size] next_latents (torch.Tensor): - The next latents sampled from the diffusion model, shape: [batch_size, num_steps, ...] + The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] log_probs (torch.Tensor): The log probabilities of the latents, shape: [batch_size] advantages (torch.Tensor): From baf3c1c2939986b08660cb023da12310ee293da6 Mon Sep 17 00:00:00 2001 From: mgerstgrasser Date: Tue, 9 Jan 2024 09:21:23 -0800 Subject: [PATCH 17/17] Fix FSDP error (#1196) * Fix FSDP error Fixes error when `loss` field of model output is non-empty, and indexing as [0] returns loss instead of logits. Can happen with FSDP. * Apply suggestions from code review force return_dict Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- trl/trainer/reward_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 44a5b79223..f2af50b634 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -220,11 +220,13 @@ def compute_loss( rewards_chosen = model( input_ids=inputs["input_ids_chosen"], attention_mask=inputs["attention_mask_chosen"], - )[0] + return_dict=True, + )["logits"] rewards_rejected = model( input_ids=inputs["input_ids_rejected"], attention_mask=inputs["attention_mask_rejected"], - )[0] + return_dict=True, + )["logits"] # calculate loss, optionally modulate with margin if "margin" in inputs: loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()