
💔 Decouple loss computing and generation in GRPO #2762

Merged
1 commit merged into main on Feb 4, 2025

Conversation

@qgallouedec (Member) commented on Feb 4, 2025

The motivation behind this PR is to decouple generation, reward computation, and reference log-prob computation on one side from the loss computation on the other. It's a preparatory PR for the implementation of:

1. minibatching within the same group (to reduce memory requirements)
2. the possibility of multiple optimization steps

A sketch of the resulting two-phase flow is given below.
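
All names below (the helpers, the minibatch splitter, the loop structure) are illustrative assumptions rather than the merged implementation; the point is only that the expensive generation/scoring phase is computed once and cached, after which the loss can be evaluated repeatedly:

def grpo_like_update(prepare_inputs, compute_loss, split_into_minibatches,
                     policy, optimizer, prompts, num_optimization_steps):
    # Phase 1: run once per batch of prompts -- generation, reward scoring and
    # reference log-probs (no gradients needed, results are cached in `batch`).
    batch = prepare_inputs(prompts)

    # Phase 2: the loss can now be evaluated repeatedly over the cached batch...
    for _ in range(num_optimization_steps):               # multiple optimization steps
        for mini_batch in split_into_minibatches(batch):  # ...and over slices within a group
            loss = compute_loss(policy, mini_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()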

@qgallouedec changed the title from "Decouple loss computing and generation in GRPO" to "💔 Decouple loss computing and generation in GRPO" on Feb 4, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

attention_mask = torch.cat([prompt_mask_repeated, completion_mask], dim=1) # (B*G, P+C)

# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
@qgallouedec (Member, Author) commented:

converted to method
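
For context, here is a self-contained sketch of what such a per-token log-prob helper computes, assuming a Hugging Face causal LM whose forward pass returns `.logits`; it is an illustration consistent with the signature above, not necessarily the exact body in the diff:

import torch
import torch.nn.functional as F

def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
    """Log-prob of each of the last `logits_to_keep` tokens (the completion) under `model`."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits  # (B, L, V)
    # The logit at position i predicts token i+1, so drop the last position and
    # keep only the positions that predict the completion tokens.
    logits = logits[:, :-1, :][:, -logits_to_keep:, :]                          # (B, C, V)
    target_ids = input_ids[:, -logits_to_keep:]                                 # (B, C)
    log_probs = F.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, dim=2, index=target_ids.unsqueeze(-1)).squeeze(-1)  # (B, C)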

)

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
@qgallouedec (Member, Author) commented:

this is later computed in compute_loss

Comment on lines -524 to -531
# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
@qgallouedec (Member, Author) commented:

this is later computed in compute_loss
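
Two identities in the removed block are easy to miss (explanatory only, not part of the diff): `torch.exp(x - x.detach())` always evaluates to 1 while keeping the gradient of `x`, so the first term of `per_token_loss` has the value of the advantage but the gradient of the usual policy-gradient term; and `exp(r) - r - 1` with `r = ref_per_token_logps - per_token_logps` (computed above) is a non-negative estimator of the per-token KL to the reference model. A standalone check of the first identity:

import torch

x = torch.tensor([0.3, -1.2], requires_grad=True)
y = torch.exp(x - x.detach())   # forward value is exactly 1 for every element
y.sum().backward()
print(y.detach())  # tensor([1., 1.])
print(x.grad)      # tensor([1., 1.]): d/dx exp(x - c) evaluated at x = c is 1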

@@ -366,32 +366,41 @@ def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
self._signature_columns = ["prompt"]

# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
@qgallouedec (Member, Author) commented:

In this method, we now:

  • generate
  • compute the rewards
  • compute the ref log probs

A rough skeleton of this flow is sketched below.
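
Under the assumption that each of those steps is delegated to a helper (all helper names below are placeholders, not the ones in the diff), the reorganised method could look roughly like this, shown as a free-standing sketch:

import torch

def _prepare_inputs(self, inputs):
    # 1. generate: sample G completions per prompt (no gradients needed here)
    prompt_ids, completion_ids, completion_mask = self._generate(inputs)          # placeholder helper
    # 2. compute reward: score completions and turn rewards into group-relative advantages
    advantages = self._score_completions(prompt_ids, completion_ids)              # placeholder helper
    # 3. compute ref log probs: per-token log-probs under the frozen reference model
    with torch.no_grad():
        ref_per_token_logps = self._ref_per_token_logps(completion_ids, completion_mask)  # placeholder
    # Return tensors only, ready to be consumed (possibly several times) by compute_loss
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "advantages": advantages,
        "ref_per_token_logps": ref_per_token_logps,
    }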

@lewtun (Member) left a comment:

Nice refactor, LGTM!

@edbeeching (Collaborator) left a comment:

LGTM, looking forward to the next PR!

@qgallouedec qgallouedec merged commit 1f344c9 into main Feb 4, 2025
13 of 14 checks passed
@qgallouedec qgallouedec deleted the decouple-generation-and-loss branch February 4, 2025 12:21