From fe4b5efe4e23f4331ba9c5b0c8bd92dc8302c287 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Wed, 22 Jan 2025 15:33:50 +0100
Subject: [PATCH] ✂️ Reintroduce `truncation_mode` in `DPOTrainer` (#2551)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* reintroduce truncation mode in DPOTrainer

* move truncation_mode in dataset.map invocation

* truncate full sequence

* "." [ci skip]

* Empty commit

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec
---
 trl/trainer/dpo_config.py  | 20 +++++++++++---------
 trl/trainer/dpo_trainer.py | 26 +++++++++++++++++++-------
 2 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index b7c18e11cc..a3cdc28d28 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -71,14 +71,15 @@ class DPOConfig(TrainingArguments):
             Padding value to use. If `None`, the padding value of the tokenizer is used.
         label_pad_token_id (`int`, *optional*, defaults to `-100`):
             Padding value to use for labels.
-        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
-            Truncation mode to use when the prompt is too long, either `keep_end` or `keep_start`.
         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
             Maximum length of the prompt.
         max_completion_length (`int` or `None`, *optional*, defaults to `None`):
             Maximum length of the completion.
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the full sequence (prompt + completion).
+        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+            Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
+            `"keep_start"`.
         padding_free (`bool`, *optional*, defaults to `False`):
             Whether forward passes are performed without padding by flattening all sequences in the batch into
             a single continuous sequence. This approach requires associating a `position_ids` vector to track
@@ -219,13 +220,6 @@ class DPOConfig(TrainingArguments):
         default=-100,
         metadata={"help": "Padding value to use for labels."},
     )
-    truncation_mode: str = field(
-        default="keep_end",
-        metadata={
-            "help": "Truncation mode to use when the prompt is too long.",
-            "choices": ["keep_end", "keep_start"],
-        },
-    )
     max_prompt_length: Optional[int] = field(
         default=512,
         metadata={"help": "Maximum length of the prompt."},
     )
@@ -238,6 +232,14 @@ class DPOConfig(TrainingArguments):
         default=1024,
         metadata={"help": "Maximum length of the full sequence (prompt + completion)."},
     )
+    truncation_mode: str = field(
+        default="keep_end",
+        metadata={
+            "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
+            "and `'keep_start'`.",
+            "choices": ["keep_end", "keep_start"],
+        },
+    )
     padding_free: bool = field(
         default=False,
         metadata={
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 218f1af5a9..903bb719ca 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -388,12 +388,12 @@ def make_inputs_require_grad(module, input, output):
         if self.ref_model is not None:
             disable_dropout_in_model(self.ref_model)

-        self.max_length = args.max_length
         self.generate_during_eval = args.generate_during_eval
         self.label_pad_token_id = args.label_pad_token_id
         self.max_prompt_length = args.max_prompt_length
-        self.truncation_mode = args.truncation_mode
         self.max_completion_length = args.max_completion_length
+        self.max_length = args.max_length
+        self.truncation_mode = args.truncation_mode
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
         self.use_num_logits_to_keep = args.use_num_logits_to_keep

@@ -595,7 +595,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l
         >>> from transformers import GPT2Tokenizer
         >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
         >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
-        >>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False)
+        >>> DPOTrainer.tokenize_row(
+        ...     features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False
+        ... )
         {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
         ```
         """
@@ -1145,10 +1147,20 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
             attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

             # Truncate right
-            if self.args.max_length is not None:
-                input_ids = input_ids[:, : self.args.max_length]
-                attention_mask = attention_mask[:, : self.args.max_length]
-                loss_mask = loss_mask[:, : self.args.max_length]
+            if self.max_length is not None:
+                if self.truncation_mode == "keep_end":
+                    input_ids = input_ids[:, -self.max_length :]
+                    attention_mask = attention_mask[:, -self.max_length :]
+                    loss_mask = loss_mask[:, -self.max_length :]
+                elif self.truncation_mode == "keep_start":
+                    input_ids = input_ids[:, : self.max_length]
+                    attention_mask = attention_mask[:, : self.max_length]
+                    loss_mask = loss_mask[:, : self.max_length]
+                else:
+                    raise ValueError(
+                        f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
+                        "'keep_start']."
+                    )

             if self.use_num_logits_to_keep:
                 # Compute num_logits_to_keep based on loss_mask pattern:
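
For readers who only skim the diff, here is a minimal standalone sketch of the truncation behavior added in `concatenated_forward`: `"keep_end"` keeps the last `max_length` columns of the padded batch, `"keep_start"` keeps the first `max_length` columns, and any other value raises a `ValueError`. This is plain PyTorch, not TRL code; the `truncate` helper is hypothetical.

```python
import torch

def truncate(x: torch.Tensor, max_length: int, truncation_mode: str) -> torch.Tensor:
    # Mirrors the branch added above: "keep_end" keeps the last max_length
    # tokens of each row, "keep_start" keeps the first max_length tokens.
    if truncation_mode == "keep_end":
        return x[:, -max_length:]
    elif truncation_mode == "keep_start":
        return x[:, :max_length]
    raise ValueError(f"Unknown truncation mode: '{truncation_mode}'.")

batch = torch.arange(12).reshape(2, 6)    # two sequences of length 6
print(truncate(batch, 4, "keep_end"))     # keeps columns 2..5 of each row
print(truncate(batch, 4, "keep_start"))   # keeps columns 0..3 of each row
```

In user code the choice is exposed through the relocated config field, e.g. `DPOConfig(truncation_mode="keep_end")` (the default) or `DPOConfig(truncation_mode="keep_start")`, which the trainer reads as `self.truncation_mode`.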