✂️ Reintroduce truncation_mode in DPOTrainer (#2551)
* reintroduce truncation mode in DPOTrainer

* move truncation_mode in dataset.map invocation

* truncate full sequence

* "." [ci skip]

* Empty commit

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
3 people authored Jan 22, 2025
1 parent a9b54a8 commit fe4b5ef
Showing 2 changed files with 30 additions and 16 deletions.
20 changes: 11 additions & 9 deletions trl/trainer/dpo_config.py
@@ -71,14 +71,15 @@ class DPOConfig(TrainingArguments):
             Padding value to use. If `None`, the padding value of the tokenizer is used.
         label_pad_token_id (`int`, *optional*, defaults to `-100`):
             Padding value to use for labels.
-        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
-            Truncation mode to use when the prompt is too long, either `keep_end` or `keep_start`.
         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
             Maximum length of the prompt.
         max_completion_length (`int` or `None`, *optional*, defaults to `None`):
             Maximum length of the completion.
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the full sequence (prompt + completion).
+        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+            Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
+            `"keep_start"`.
         padding_free (`bool`, *optional*, defaults to `False`):
             Whether forward passes are performed without padding by flattening all sequences in the batch
             into a single continuous sequence. This approach requires associating a `position_ids` vector to track
@@ -219,13 +220,6 @@ class DPOConfig(TrainingArguments):
         default=-100,
         metadata={"help": "Padding value to use for labels."},
     )
-    truncation_mode: str = field(
-        default="keep_end",
-        metadata={
-            "help": "Truncation mode to use when the prompt is too long.",
-            "choices": ["keep_end", "keep_start"],
-        },
-    )
     max_prompt_length: Optional[int] = field(
         default=512,
         metadata={"help": "Maximum length of the prompt."},
@@ -238,6 +232,14 @@ class DPOConfig(TrainingArguments):
         default=1024,
         metadata={"help": "Maximum length of the full sequence (prompt + completion)."},
     )
+    truncation_mode: str = field(
+        default="keep_end",
+        metadata={
+            "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
+            "and `'keep_start'`.",
+            "choices": ["keep_end", "keep_start"],
+        },
+    )
     padding_free: bool = field(
         default=False,
         metadata={
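For context, a minimal sketch of how a user opts into the relocated option, assuming only the `DPOConfig` fields visible in this diff; `output_dir` comes from the inherited `TrainingArguments`, and all values here are illustrative:

```python
from trl import DPOConfig

# Illustrative values; per this diff the defaults are max_length=1024 and
# truncation_mode="keep_end".
training_args = DPOConfig(
    output_dir="dpo-model",  # hypothetical output directory
    max_prompt_length=512,
    max_length=1024,
    truncation_mode="keep_start",  # keep the first max_length tokens instead of the last
)
```

The `"choices"` entry in the field metadata constrains values passed through argument parsing, and the trainer change below additionally raises a `ValueError` at forward time for any other value.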
26 changes: 19 additions & 7 deletions trl/trainer/dpo_trainer.py
@@ -388,12 +388,12 @@ def make_inputs_require_grad(module, input, output):
         if self.ref_model is not None:
             disable_dropout_in_model(self.ref_model)
 
-        self.max_length = args.max_length
         self.generate_during_eval = args.generate_during_eval
         self.label_pad_token_id = args.label_pad_token_id
         self.max_prompt_length = args.max_prompt_length
-        self.truncation_mode = args.truncation_mode
         self.max_completion_length = args.max_completion_length
+        self.max_length = args.max_length
+        self.truncation_mode = args.truncation_mode
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
         self.use_num_logits_to_keep = args.use_num_logits_to_keep
 
@@ -595,7 +595,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l
         >>> from transformers import GPT2Tokenizer
         >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
         >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
-        >>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False)
+        >>> DPOTrainer.tokenize_row(
+        ...     features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False
+        ... )
         {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
         ```
         """
@@ -1145,10 +1147,20 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
             attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
 
         # Truncate right
-        if self.args.max_length is not None:
-            input_ids = input_ids[:, : self.args.max_length]
-            attention_mask = attention_mask[:, : self.args.max_length]
-            loss_mask = loss_mask[:, : self.args.max_length]
+        if self.max_length is not None:
+            if self.truncation_mode == "keep_end":
+                input_ids = input_ids[:, -self.max_length :]
+                attention_mask = attention_mask[:, -self.max_length :]
+                loss_mask = loss_mask[:, -self.max_length :]
+            elif self.truncation_mode == "keep_start":
+                input_ids = input_ids[:, : self.max_length]
+                attention_mask = attention_mask[:, : self.max_length]
+                loss_mask = loss_mask[:, : self.max_length]
+            else:
+                raise ValueError(
+                    f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
+                    "'keep_start']."
+                )
 
         if self.use_num_logits_to_keep:
             # Compute num_logits_to_keep based on loss_mask pattern:
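To make the two modes concrete, a self-contained sketch of the slicing the new branch performs, assuming only `torch`; the batch contents and `max_length` value are made up:

```python
import torch

max_length = 4
# A batch of two already-padded sequences, 6 tokens each.
input_ids = torch.arange(12).reshape(2, 6)

# "keep_end": drop tokens from the front, as in input_ids[:, -max_length:]
keep_end = input_ids[:, -max_length:]
# "keep_start": drop tokens from the back, as in input_ids[:, :max_length]
keep_start = input_ids[:, :max_length]

print(keep_end.tolist())    # [[2, 3, 4, 5], [8, 9, 10, 11]]
print(keep_start.tolist())  # [[0, 1, 2, 3], [6, 7, 8, 9]]
```

In the trainer the same slice is applied to `input_ids`, `attention_mask`, and `loss_mask` together, so after `flush_left` all three tensors keep a consistent window of the sequence.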
