Skip to content

Commit

Permalink
Fix Preference Loss and Refactor for Readability (#484)
Browse files Browse the repository at this point in the history
## Summary

Thanks to @winglian and @shivam15s, who noticed and fixed this in
#481.

This PR suggests negating the preference loss terms to align with the
formulas in the docstrings, while maintaining the base preference
structure as `nll_loss + preference_loss`. This would make our loss
computations more consistent since both terms would represent losses to
be minimized.

[UPDATE: It seems this is already being addressed
[here](3205342#diff-3048cb37b97e27515852c200994f3257b8ae33a465421d05184713377c0895b1R150).]
This PR also tightens the test tolerance so that a similar issue would be
caught in the future.

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: Shivam Sahni <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2024
1 parent 79e2b02 commit 15a2f58
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/liger_kernel/chunked_loss/cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def preference_loss_fn(
"""
logits = beta * (chosen_logps - rejected_logps)
loss = (
F.logsigmoid(logits) * (1 - label_smoothing)
+ F.logsigmoid(-logits) * label_smoothing
- F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

return loss
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/fused_linear_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def _compute_loss(
else:
preference_loss, aux_outputs = preference_loss_outputs, []

loss = alpha * chosen_nll_loss - preference_loss
loss = alpha * chosen_nll_loss + preference_loss
return_vars = (
chosen_logps,
rejected_logps,
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
- torch.log1p(-torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)
loss = beta * ratio.sum() / (full_target.shape[0] // 2)
loss = -beta * ratio.sum() / (full_target.shape[0] // 2)

chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps
Expand Down
4 changes: 2 additions & 2 deletions src/liger_kernel/chunked_loss/simpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def preference_loss_fn(
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
loss = (
F.logsigmoid(logits) * (1 - label_smoothing)
+ F.logsigmoid(-logits) * label_smoothing
- F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

return loss
Expand Down
8 changes: 4 additions & 4 deletions test/chunked_loss/test_cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def alignment_loss(
if self.loss_type == "sigmoid":
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = (
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "simpo":
logits = logits - (self.simpo_gamma / self.beta)
losses = (
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion test/chunked_loss/test_orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def alignment_loss(
- torch.log1p(-torch.exp(policy_rejected_logps))
)
ratio = F.logsigmoid(log_odds)
losses = self.beta * ratio
losses = -self.beta * ratio

chosen_rewards = self.beta * policy_chosen_logps
rejected_rewards = self.beta * policy_rejected_logps
Expand Down
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def get_batch_loss_metrics(
else:
losses, aggregated_aux_outputs = alignment_loss_outputs, []
# full loss
loss = policy_nll_loss * self.alpha - losses.mean()
loss = policy_nll_loss * self.alpha + losses.mean()
return_vars = (
policy_chosen_logps,
policy_rejected_logps,
Expand Down

0 comments on commit 15a2f58

Please sign in to comment.