diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 903bb719ca..886022a612 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -480,6 +480,11 @@ def make_inputs_require_grad(module, input, output):
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
 
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
         # Add tags for models that have been loaded with the correct transformers version
         if hasattr(self.model, "add_model_tags"):
             self.model.add_model_tags(self._tag_names)
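
For reviewers, a minimal sketch of the behavior the comment describes, assuming the parent Trainer only divides the per-step loss by the number of gradient accumulation steps when model_accepts_loss_kwargs is falsy; the constant and helper below are hypothetical illustrations, not TRL or transformers APIs:

# Hedged sketch, not the actual transformers implementation: it only illustrates
# the assumed scaling rule that motivates forcing self.model_accepts_loss_kwargs = False.

GRADIENT_ACCUMULATION_STEPS = 4  # hypothetical config value


def scale_loss_for_accumulation(loss: float, model_accepts_loss_kwargs: bool) -> float:
    # Assumption: the parent class skips the division when the model is expected to
    # have normalized the loss itself via loss-related kwargs. DPOTrainer computes its
    # own loss, so forcing the flag to False makes the division happen again.
    if not model_accepts_loss_kwargs:
        return loss / GRADIENT_ACCUMULATION_STEPS
    return loss


print(scale_loss_for_accumulation(2.0, model_accepts_loss_kwargs=False))  # 0.5
print(scale_loss_for_accumulation(2.0, model_accepts_loss_kwargs=True))   # 2.0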