Skip to content

Commit

Permalink
📠 Log completions for GRPO (#2772)
Browse files Browse the repository at this point in the history
* log completions

* typo

* wandb

* Fix completions

* Fix style?

* Remove double import

* Revert

* group logging

---------

Co-authored-by: lewtun <[email protected]>
  • Loading branch information
qgallouedec and lewtun authored Feb 7, 2025
1 parent 84d73fd commit 82d12eb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
11 changes: 11 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ class GRPOConfig(TrainingArguments):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
> Parameters that control the logging
log_completions (`bool`, *optional*, defaults to `False`):
Whether to log the completions during training.
"""

# Parameters that control the model and reference model
Expand Down Expand Up @@ -227,3 +232,9 @@ class GRPOConfig(TrainingArguments):
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
},
)

# Parameters that control the logging
log_completions: bool = field(
default=False,
metadata={"help": "Whether to log the completions during training."},
)
26 changes: 24 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def data_collator(features): # No data collation is needed in GRPO

# Initialize the metrics
self._metrics = defaultdict(list)
self.log_completions = args.log_completions

super().__init__(
model=model,
Expand Down Expand Up @@ -534,9 +535,11 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
)

# Decode the generated completions
- completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
- completions = [[{"role": "assistant", "content": completion}] for completion in completions]
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
+ else:
+ completions = completions_text

rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
Expand Down Expand Up @@ -596,6 +599,25 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
self._metrics["reward"].append(rewards.mean().item())
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())

if (
self.log_completions
and self.state.global_step % self.args.logging_steps == 0
and "wandb" in self.args.report_to
):
import pandas as pd

# For logging
table = {
"step": [str(self.state.global_step)] * len(rewards),
"prompt": gather_object(prompts_text),
"completion": gather_object(completions_text),
"reward": rewards.tolist(),
}
df = pd.DataFrame(table)

if wandb.run is not None and self.accelerator.is_main_process:
wandb.log({"completions": wandb.Table(dataframe=df)})

return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
Expand Down

0 comments on commit 82d12eb

Please sign in to comment.