💔 Decouple loss computing and generation in GRPO #2762
Conversation
attention_mask = torch.cat([prompt_mask_repeated, completion_mask], dim=1)  # (B*G, P+C)

# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
converted to method
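For readers following along, here is a rough sketch of what such a helper computes. This is illustrative only, not the exact body from the PR, and it assumes a Hugging Face-style model whose output exposes `.logits`:

```python
import torch

def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
    # Run the model over prompt + completion tokens.
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Logits at position t predict token t+1, so keep the logits that predict
    # the last `logits_to_keep` tokens (the completion).
    logits = logits[:, -logits_to_keep - 1 : -1, :]
    labels = input_ids[:, -logits_to_keep:]          # the completion tokens themselves
    log_probs = torch.log_softmax(logits, dim=-1)
    # Gather the log-probability of each generated token: shape (B*G, C).
    return log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
```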
)

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
this is later computed in compute_loss
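For context, this per-token quantity is a non-negative estimator of the KL divergence that is unbiased when tokens are sampled from the policy. A tiny sanity check on a toy categorical distribution (purely illustrative, not part of the PR):

```python
import torch

# exp(ref_logp - logp) - (ref_logp - logp) - 1 is >= 0 everywhere and, in
# expectation over tokens sampled from the policy, equals KL(pi || pi_ref).
torch.manual_seed(0)
pi = torch.softmax(torch.randn(6), dim=-1)        # toy policy distribution
ref = torch.softmax(torch.randn(6), dim=-1)       # toy reference distribution

exact_kl = (pi * (pi.log() - ref.log())).sum()

delta = ref.log() - pi.log()                      # ref_per_token_logps - per_token_logps
estimator = delta.exp() - delta - 1               # always >= 0
expected_estimate = (pi * estimator).sum()        # expectation under the policy

print(exact_kl.item(), expected_estimate.item())  # the two values match
```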
# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
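As a side note, a minimal standalone illustration of the `x - x.detach()` trick (not from the PR): the ratio always evaluates to 1, but its gradient with respect to the log-probs is 1, so the resulting loss has the usual policy-gradient form, advantage times the gradient of the log-probability.

```python
import torch

# exp(x - x.detach()) evaluates to 1 everywhere, but
# d/dx exp(x - x.detach()) = exp(x - x.detach()) * 1 = 1,
# so gradients flow through x unchanged while the loss value stays easy to read.
x = torch.tensor([0.5, -1.2], requires_grad=True)
ratio = torch.exp(x - x.detach())
print(ratio)            # tensor([1., 1.], grad_fn=<ExpBackward0>)
ratio.sum().backward()
print(x.grad)           # tensor([1., 1.])
```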
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
this is later computed in compute_loss
@@ -366,32 +366,41 @@ def _set_signature_columns_if_needed(self):
    if self._signature_columns is None:
        self._signature_columns = ["prompt"]

# Trainer "prepares" the inputs before calling `compute_loss`: it converts them to tensors and moves them to the device.
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
in this method, we now (rough sketch below):
- generate
- compute reward
- compute ref log probs
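For completeness, a self-contained toy illustration of the decoupling (all names and models below are made-up stand-ins, not TRL's actual API): the gradient-free work (generation, rewards, reference log-probs) happens once in `prepare_inputs`, and `compute_loss` only re-runs the trainable policy on the cached tensors.

```python
import torch
import torch.nn as nn

vocab, seq_len = 16, 4
policy = nn.Linear(1, vocab)       # stand-in "policy": per-step logits
ref_policy = nn.Linear(1, vocab)   # stand-in frozen reference model

def per_token_logps(model, tokens):
    logits = model(torch.ones(tokens.shape[0], tokens.shape[1], 1))
    return torch.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

def prepare_inputs(batch_size):
    # Everything here is gradient-free and done once per batch.
    with torch.no_grad():
        tokens = torch.randint(vocab, (batch_size, seq_len))     # "generation"
        rewards = tokens.float().mean(dim=1)                     # toy "reward"
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
        ref_logps = per_token_logps(ref_policy, tokens)          # ref log probs
    return {"tokens": tokens, "advantages": advantages, "ref_logps": ref_logps}

def compute_loss(batch, beta=0.04):
    # Only the trainable policy forward pass happens here.
    logps = per_token_logps(policy, batch["tokens"])
    kl = torch.exp(batch["ref_logps"] - logps) - (batch["ref_logps"] - logps) - 1
    pg = torch.exp(logps - logps.detach()) * batch["advantages"].unsqueeze(1)
    return (-(pg - beta * kl)).mean()

loss = compute_loss(prepare_inputs(batch_size=8))
loss.backward()
```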
Nice refactor, LGTM!
LGTM, looking forward to the next PR!
The motivation behind this PR is to decouple everything related to generation and to the computation of rewards and reference log probs, on the one hand, from the computation of the loss, on the other. It is a preparatory PR for the implementation of:
1. minibatching within the same group (to reduce memory requirements)
2. the possibility of performing multiple optimization steps.