Skip to content

Commit

Permalink
trick to keep messages
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Feb 7, 2025
1 parent 85f1c5a commit 4e5363e
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ def __init__(
):
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

def _prepare_dataset(self, dataset, *args):
# SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we
# need to keep the messages column as it is. We use the following workaround to keep the messages column.
dataset = dataset.add_column("_messages", dataset["messages"])
dataset = super()._prepare_dataset(dataset, *args)
dataset = dataset.rename_column("_messages", "messages")
return dataset

@staticmethod
def generalized_jsd_loss(
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
Expand Down

0 comments on commit 4e5363e

Please sign in to comment.