Commit: log completions

qgallouedec committed Feb 5, 2025
1 parent af4ad47 · commit 82750d3
Showing 2 changed files with 21 additions and 0 deletions.
11 changes: 11 additions & 0 deletions trl/trainer/grpo_config.py
@@ -99,6 +99,11 @@ class GRPOConfig(TrainingArguments):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
> Parameters that control the logging
log_completions (`bool`, *optional*, defaults to `False`):
Whether to log the completions during training.
"""

# Parameters that control the model and reference model
@@ -233,3 +238,9 @@ class GRPOConfig(TrainingArguments):
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
},
)

# Parameters that control the logging
log_completions: bool = field(
default=False,
metadata={"help": "Whether to log the completions during training."},
)
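For context, a minimal sketch of how the new option would be enabled from user code. The model name, dataset, and reward function below are illustrative assumptions, not part of this commit; only `log_completions` and `logging_steps` are the parameters this change concerns.

# Minimal usage sketch (assumptions: model, dataset, and reward function
# are placeholders; only `log_completions`/`logging_steps` relate to this commit).
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters long.
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-demo",
    logging_steps=10,      # completions are printed every `logging_steps` steps
    log_completions=True,  # the flag added in this commit
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
)
trainer.train()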
10 changes: 10 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -268,6 +268,7 @@ def data_collator(features): # No data collation is needed in GRPO

# Initialize the metrics
self._metrics = defaultdict(list)
self.log_completions = args.log_completions

super().__init__(
model=model,
@@ -536,6 +537,15 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

if (
self.accelerator.is_main_process
and self.log_completions
and self.state.global_step % self.args.logging_steps == 0
):
for prompt, completion, reward in zip(prompts, completions, rewards):
text = f"---\n\033[31m{prompt}\033[34m{completion}\033[0m\n--- \033[32mReward: {reward}\033[0m ---"
self.accelerator.print(text)

return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
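The completions are only printed on the main process, and only on steps where `global_step` is a multiple of `logging_steps`. The escape sequences in the format string are standard ANSI SGR color codes; a small sketch of what they render (the prompt, completion, and reward values below are made up, assuming a terminal that supports ANSI colors):

# Standard ANSI SGR color codes used by the new logging block
# (illustrative prompt/completion/reward values, not real training output):
RED, BLUE, GREEN, RESET = "\033[31m", "\033[34m", "\033[32m", "\033[0m"
prompt, completion, reward = "What is 2+2?", " 2+2 equals 4.", 1.0
# Prompt prints in red, completion in blue, reward in green:
print(f"---\n{RED}{prompt}{BLUE}{completion}{RESET}\n--- {GREEN}Reward: {reward}{RESET} ---")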
