From 40c238395e345e6013f899b3768b53c73e60844b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 23 Jan 2025 12:12:06 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=A5=9E=20Fix=20DPO=20gradient=20accumulat?= =?UTF-8?q?ion=20loss=20scaling=20(#2615)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix DPO for gradient accumulation * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py * Update trl/trainer/dpo_trainer.py --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 903bb719ca..886022a612 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -480,6 +480,11 @@ def make_inputs_require_grad(module, input, output): preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + # Add tags for models that have been loaded with the correct transformers version if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names)