From 1c35a48b50f54b92c6b820437aaf75c4e3d777ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
 <45557362+qgallouedec@users.noreply.github.com>
Date: Fri, 31 Jan 2025 20:19:39 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=8F=B0=20`num=5Flogits=5Fto=5Fkeep`=20to?=
 =?UTF-8?q?=20`logits=5Fto=5Fkeep`=20(#2721)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/test_dpo_trainer.py   |  6 +++---
 trl/trainer/dpo_config.py   | 22 ++++++++++++++++++++--
 trl/trainer/dpo_trainer.py  | 18 +++++++++---------
 trl/trainer/grpo_trainer.py | 16 ++++++++--------
 4 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index c4a0232ee3..4e1a72ce9a 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -1070,7 +1070,7 @@ def test_dpo_loss_js_div_f(self):
         )
         self.assertTrue(torch.isfinite(losses).cpu().numpy().all())
 
-    def test_dpo_trainer_use_num_logits_to_keep(self):
+    def test_dpo_trainer_use_logits_to_keep(self):
         model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
@@ -1087,7 +1087,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
             learning_rate=9e-1,
             eval_strategy="steps",
             beta=0.1,
-            use_num_logits_to_keep=True,
+            use_logits_to_keep=True,
             rpo_alpha=0.5,
             report_to="none",
         )
@@ -1104,7 +1104,7 @@ def test_dpo_trainer_use_num_logits_to_keep(self):
             eval_dataset=dummy_dataset["test"],
         )
 
-        training_args.use_num_logits_to_keep = False
+        training_args.use_logits_to_keep = False
         trainer2 = DPOTrainer(
             model=model,
             ref_model=None,
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 55c6ecc7c8..09b6e35dea 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Callable, Optional, Union
@@ -57,7 +58,7 @@ class DPOConfig(TrainingArguments):
             this flag to `True`.
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether to disable dropout in the model and reference model.
-        use_num_logits_to_keep (`bool`, *optional*, defaults to `False`):
+        use_logits_to_keep (`bool`, *optional*, defaults to `False`):
             If `True`, only a specified number of logits are computed in the forward pass. This can be useful for
             saving memory and speeding up training by not computing the logits for all tokens, especially in
             scenarios when working with very long prompts where labels are ignored (-100).
@@ -197,7 +198,7 @@ class DPOConfig(TrainingArguments):
         default=True,
         metadata={"help": "Whether to disable dropout in the model and reference model."},
     )
-    use_num_logits_to_keep: bool = field(
+    use_logits_to_keep: bool = field(
         default=False,
         metadata={
             "help": "If `True`, only a specified number of logits are computed in the forward pass. This can be "
@@ -384,3 +385,20 @@ class DPOConfig(TrainingArguments):
             "Comet during evaluation."
         },
     )
+
+    # Deprecated parameters
+    use_num_logits_to_keep: bool = field(
+        default=False,
+        metadata={"help": "Deprecated. Use `use_logits_to_keep` instead."},
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        if self.use_num_logits_to_keep:
+            warnings.warn(
+                "`use_num_logits_to_keep` is deprecated and will be removed in version 0.17.0. Use "
+                "`use_logits_to_keep` instead.",
+                DeprecationWarning,
+            )
+            self.use_logits_to_keep = self.use_num_logits_to_keep
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 886022a612..cd3c3b4dca 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -395,7 +395,7 @@ def make_inputs_require_grad(module, input, output):
         self.max_length = args.max_length
         self.truncation_mode = args.truncation_mode
         self.precompute_ref_log_probs = args.precompute_ref_log_probs
-        self.use_num_logits_to_keep = args.use_num_logits_to_keep
+        self.use_logits_to_keep = args.use_logits_to_keep
 
         if args.padding_free:
             if model.config._attn_implementation != "flash_attention_2":
@@ -1167,14 +1167,14 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
                 "'keep_start']."
             )
 
-        if self.use_num_logits_to_keep:
-            # Compute num_logits_to_keep based on loss_mask pattern:
+        if self.use_logits_to_keep:
+            # Compute logits_to_keep based on loss_mask pattern:
             # [[0, 0, 0, x, x, x, x],
             #  [0, 0, 0, x, x, x, 0]]
             #         ^ start computing logits from here ([:, -(7-3+1):])
             first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
-            num_logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
-            model_kwargs["num_logits_to_keep"] = num_logits_to_keep
+            logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1  # +1 for the first label
+            model_kwargs["logits_to_keep"] = logits_to_keep
 
         if self.padding_free:
             # Flatten the input_ids, position_ids, and loss_mask
@@ -1194,15 +1194,15 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
         labels = torch.roll(input_ids, shifts=-1, dims=1)
         loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()
 
-        if self.use_num_logits_to_keep:
+        if self.use_logits_to_keep:
             # Align labels with logits
             # logits:    -,  -, [x2, x3, x4, x5, x6]
             #                     ^ --------- ^       after logits[:, :-1, :]
             # labels:   [y0, y1, y2, y3, y4, y5, y6]
-            #                     ^ --------- ^       with num_logits_to_keep=4, [:, -4:]
+            #                     ^ --------- ^       with logits_to_keep=4, [:, -4:]
             # loss_mask: [0, 0, 0, 1, 1, 1, 1]
-            labels = labels[:, -num_logits_to_keep:]
-            loss_mask = loss_mask[:, -num_logits_to_keep:]
+            labels = labels[:, -logits_to_keep:]
+            loss_mask = loss_mask[:, -logits_to_keep:]
 
         if logits.shape[:2] != labels.shape[:2]:
             # for llava, the returned logits include the image tokens (placed before the text tokens)
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 9e10328072..f130d5155f 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -427,28 +427,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         completion_ids = prompt_completion_ids[:, prompt_length:]
 
         # Get the per-token log probabilities for the completions for the model and the reference model
-        def get_per_token_logps(model, input_ids, num_logits_to_keep):
-            # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
-            logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits  # (B, L, V)
+        def get_per_token_logps(model, input_ids, logits_to_keep):
+            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+            logits = model(input_ids, logits_to_keep=logits_to_keep + 1).logits  # (B, L, V)
             logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
 
             # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = [] - for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]): + for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]): log_probs = logits_row.log_softmax(dim=-1) token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) per_token_logps.append(token_log_prob) return torch.stack(per_token_logps) - num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + per_token_logps = get_per_token_logps(model, prompt_completion_ids, logits_to_keep) with torch.inference_mode(): if self.ref_model is not None: - ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, num_logits_to_keep) + ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, logits_to_keep) else: with self.accelerator.unwrap_model(model).disable_adapter(): - ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep) + ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, logits_to_keep) # Compute the KL divergence between the model and the reference model per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
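
For illustration only (this note is not part of the patch): a minimal sketch of what the renamed keyword does at the model level, assuming a transformers release recent enough to accept `logits_to_keep` in `forward` (older releases expose the same behaviour under `num_logits_to_keep`). The tiny model id is the one used in the DPO test above; any causal LM would do.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"  # tiny test model, as in the test above
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

input_ids = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt").input_ids

with torch.no_grad():
    full_logits = model(input_ids).logits                    # (1, seq_len, vocab): every position
    tail_logits = model(input_ids, logits_to_keep=4).logits  # (1, 4, vocab): last 4 positions only

# The kept logits are simply the tail of the full computation, which is why the trainers can
# skip positions whose labels are ignored (-100) without changing the loss.
print(torch.allclose(full_logits[:, -4:], tail_logits, atol=1e-5))

On the TRL side, users only toggle `use_logits_to_keep=True` in `DPOConfig`; the old `use_num_logits_to_keep` name still works for now but emits a `DeprecationWarning`, as implemented in `__post_init__` above.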