Commit

Merge branch 'main' into fix-grpo-logits-calc
qgallouedec authored Jan 31, 2025
2 parents f5be3ab + 1c35a48 commit 15e0b7b
Showing 4 changed files with 40 additions and 22 deletions.
6 changes: 3 additions & 3 deletions tests/test_dpo_trainer.py
@@ -1070,7 +1070,7 @@ def test_dpo_loss_js_div_f(self):
)
self.assertTrue(torch.isfinite(losses).cpu().numpy().all())

- def test_dpo_trainer_use_num_logits_to_keep(self):
+ def test_dpo_trainer_use_logits_to_keep(self):
model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
@@ -1087,7 +1087,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
- use_num_logits_to_keep=True,
+ use_logits_to_keep=True,
rpo_alpha=0.5,
report_to="none",
)
@@ -1104,7 +1104,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
eval_dataset=dummy_dataset["test"],
)

- training_args.use_num_logits_to_keep = False
+ training_args.use_logits_to_keep = False
trainer2 = DPOTrainer(
model=model,
ref_model=None,
22 changes: 20 additions & 2 deletions trl/trainer/dpo_config.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+ import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional, Union
@@ -57,7 +58,7 @@ class DPOConfig(TrainingArguments):
this flag to `True`.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model and reference model.
- use_num_logits_to_keep (`bool`, *optional*, defaults to `False`):
+ use_logits_to_keep (`bool`, *optional*, defaults to `False`):
If `True`, only a specified number of logits are computed in the forward pass. This can be useful for
saving memory and speeding up training by not computing the logits for all tokens, especially in
scenarios when working with very long prompts where labels are ignored (-100).
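
For context, a rough sketch of what this flag changes at the model call (not part of the commit; it assumes a transformers version whose causal-LM forward accepts the `logits_to_keep` argument this rename tracks, and it borrows the tiny test model named in the test file above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"  # tiny model used in the tests above
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("a long prompt followed by a short answer", return_tensors="pt")
seq_len = inputs["input_ids"].shape[1]

full = model(**inputs).logits                    # (1, seq_len, vocab): logits for every position
kept = model(**inputs, logits_to_keep=4).logits  # (1, 4, vocab): only the last 4 positions

# The kept slice matches the tail of the full forward; the saving is that the
# other seq_len - 4 rows of the LM-head output are never materialized.
assert kept.shape[1] == 4
assert torch.allclose(full[:, -4:], kept, atol=1e-5)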
@@ -197,7 +198,7 @@ class DPOConfig(TrainingArguments):
default=True,
metadata={"help": "Whether to disable dropout in the model and reference model."},
)
- use_num_logits_to_keep: bool = field(
+ use_logits_to_keep: bool = field(
default=False,
metadata={
"help": "If `True`, only a specified number of logits are computed in the forward pass. This can be "
@@ -384,3 +385,20 @@ class DPOConfig(TrainingArguments):
"Comet during evaluation."
},
)

+ # Deprecated parameters
+ use_num_logits_to_keep: bool = field(
+     default=False,
+     metadata={"help": "Deprecated. Use `use_logits_to_keep` instead."},
+ )
+
+ def __post_init__(self):
+     super().__post_init__()
+
+     if self.use_num_logits_to_keep:
+         warnings.warn(
+             "`use_num_logits_to_keep` is deprecated and will be removed in version 0.17.0. Use "
+             "`use_logits_to_keep` instead.",
+             DeprecationWarning,
+         )
+         self.use_logits_to_keep = self.use_num_logits_to_keep
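
A quick behavioral check of the deprecation shim added above (hypothetical snippet, not part of the commit; assumes a `trl` build with this change installed):

import warnings
from trl import DPOConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    args = DPOConfig(output_dir="tmp", use_num_logits_to_keep=True)

assert args.use_logits_to_keep is True  # the old flag is forwarded to the new one
assert any(issubclass(w.category, DeprecationWarning) for w in caught)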
18 changes: 9 additions & 9 deletions trl/trainer/dpo_trainer.py
@@ -395,7 +395,7 @@ def make_inputs_require_grad(module, input, output):
self.max_length = args.max_length
self.truncation_mode = args.truncation_mode
self.precompute_ref_log_probs = args.precompute_ref_log_probs
- self.use_num_logits_to_keep = args.use_num_logits_to_keep
+ self.use_logits_to_keep = args.use_logits_to_keep

if args.padding_free:
if model.config._attn_implementation != "flash_attention_2":
@@ -1167,14 +1167,14 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
"'keep_start']."
)

- if self.use_num_logits_to_keep:
-     # Compute num_logits_to_keep based on loss_mask pattern:
+ if self.use_logits_to_keep:
+     # Compute logits_to_keep based on loss_mask pattern:
# [[0, 0, 0, x, x, x, x],
# [0, 0, 0, x, x, x, 0]]
# ^ start computing logits from here ([:, -(7-3+1):])
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
-     num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
-     model_kwargs["num_logits_to_keep"] = num_logits_to_keep
+     logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
+     model_kwargs["logits_to_keep"] = logits_to_keep

if self.padding_free:
# Flatten the input_ids, position_ids, and loss_mask
@@ -1194,15 +1194,15 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
labels = torch.roll(input_ids, shifts=-1, dims=1)
loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()

- if self.use_num_logits_to_keep:
+ if self.use_logits_to_keep:
# Align labels with logits
# logits: -, -, [x2, x3, x4, x5, x6]
# ^ --------- ^ after logits[:, :-1, :]
# labels: [y0, y1, y2, y3, y4, y5, y6]
-     # ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
+     # ^ --------- ^ with logits_to_keep=4, [:, -4:]
# loss_mask: [0, 0, 0, 1, 1, 1, 1]
-     labels = labels[:, -num_logits_to_keep:]
-     loss_mask = loss_mask[:, -num_logits_to_keep:]
+     labels = labels[:, -logits_to_keep:]
+     loss_mask = loss_mask[:, -logits_to_keep:]

if logits.shape[:2] != labels.shape[:2]:
# for llava, the returned logits include the image tokens (placed before the text tokens)
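To make the index arithmetic in `concatenated_forward` concrete, here is a standalone recomputation on the toy mask from the comments above (plain torch; not part of the commit):

import torch

# Prompt tokens are masked out (0); completion tokens carry labels (1).
loss_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1],
                          [0, 0, 0, 1, 1, 1, 0]]).bool()

# Earliest column with any label: 3.
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()

# +1 because the logit at position i predicts the token at position i + 1,
# so the first label needs the logit one step earlier: 7 - 3 + 1 = 5.
logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1
assert logits_to_keep == 5  # matches the "[:, -(7-3+1):]" comment

# Label alignment as in the second hunk: next-token targets via a left roll,
# then keep only the positions the restricted forward produced logits for.
input_ids = torch.arange(14).reshape(2, 7)  # stand-in token ids y0..y6 per row
labels = torch.roll(input_ids, shifts=-1, dims=1)[:, -logits_to_keep:]
loss_mask = torch.roll(loss_mask, shifts=-1, dims=1)[:, -logits_to_keep:]
assert labels.shape == loss_mask.shape == (2, logits_to_keep)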
16 changes: 8 additions & 8 deletions trl/trainer/grpo_trainer.py
@@ -439,33 +439,33 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
attention_mask = torch.cat([prompt_mask_repeated, completion_mask], dim=1) # (B*G, P+C)

# Get the per-token log probabilities for the completions for the model and the reference model
- def get_per_token_logps(model, input_ids, attention_mask, num_logits_to_keep):
-     # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
+ def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
+     # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = model(
-         input_ids=input_ids, attention_mask=attention_mask, num_logits_to_keep=num_logits_to_keep + 1
+         input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
-     for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
+     for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)

- num_logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
- per_token_logps = get_per_token_logps(model, prompt_completion_ids, attention_mask, num_logits_to_keep)
+ logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+ per_token_logps = get_per_token_logps(model, prompt_completion_ids, attention_mask, logits_to_keep)

with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = get_per_token_logps(
-         self.ref_model, prompt_completion_ids, attention_mask, num_logits_to_keep
+         self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = get_per_token_logps(
-             model, prompt_completion_ids, attention_mask, num_logits_to_keep
+             model, prompt_completion_ids, attention_mask, logits_to_keep
)

# Compute the KL divergence between the model and the reference model
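And a standalone sanity sketch for the renamed GRPO helper (not part of the commit; the function is lifted verbatim out of `compute_loss`, and the tiny test model is an assumption borrowed from the test file above):

import torch
from transformers import AutoModelForCausalLM

def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
    # +1 because the last logit (the prediction past the sequence end) is dropped below
    logits = model(
        input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
    ).logits
    logits = logits[:, :-1, :]  # logit at position i now predicts input_ids[:, i + 1]
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
ids = torch.randint(0, model.config.vocab_size, (2, 8))  # B=2 rows of P+C=8 tokens
mask = torch.ones_like(ids)
logps = get_per_token_logps(model, ids, mask, logits_to_keep=3)  # C=3 completion tokens

assert logps.shape == (2, 3)  # one log-prob per completion token
assert (logps <= 0).all()     # log-probabilities are never positive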
