Add paper link and formula for preference loss (#449)
## Summary
Add the paper link and the loss formula to the `preference_loss_fn` docstrings of the chunked CPO, DPO, ORPO, and SimPO losses.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
ByronHsu authored Dec 9, 2024
1 parent 24bdb2c commit 8bcb859
Showing 4 changed files with 68 additions and 42 deletions.
26 changes: 16 additions & 10 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -11,11 +11,25 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
"""
Compute CPO (Contrastive Preference Optimization) loss.
Paper: https://arxiv.org/pdf/2401.08417
Formula:
L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
Where:
- π_θ(y|x): Policy (model) probability
- y_w: Chosen sequence
- y_l: Rejected sequence
- σ: Sigmoid function
- β: Temperature parameter
- E: Expected value over the dataset D
- D: Dataset of preferences
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
beta (float): Weight for the odds ratio loss.
full_target (torch.Tensor): Non chunked full target tensor
beta (float): Weight for the CPO loss
"""
logits = beta * (chosen_logps - rejected_logps)
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -34,12 +48,6 @@ def forward(
compute_nll_loss=True,
compiled=True,
):
"""
Fused linear layer with CPO (Contrastive Preference Optimization) loss.
Handles both the forward and backward pass of the final linear layer with CPO loss.
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
"""

return LigerFusedLinearPreferenceBase.forward(
ctx,
_input,
Expand All @@ -56,9 +64,7 @@ def forward(

@staticmethod
def backward(ctx, *grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None, None


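On cpo_loss.py above: for readers who want to check the documented formula outside the fused kernel, here is a minimal standalone sketch. The function name and the batch-mean reduction are illustrative, not part of the library; the fused version shown in the diff instead sums the log-sigmoids and divides by `full_target.shape[0] // 2` (which appears to count preference pairs when chosen and rejected targets are stacked along the batch dimension), and the sign plus any remaining bookkeeping are presumably handled by the chunked base class.

```python
import torch
import torch.nn.functional as F


def cpo_loss_reference(chosen_logps, rejected_logps, beta=0.1):
    """Direct reading of L(pi_theta; U) = -E[log sigma(beta * (log p_w - log p_l))]."""
    # beta-scaled margin between the average log-probs of chosen and rejected sequences
    logits = beta * (chosen_logps - rejected_logps)
    # negative log-sigmoid of the margin, averaged over the preference pairs
    return -F.logsigmoid(logits).mean()


# Example: two preference pairs (average per-token log-probs)
chosen = torch.tensor([-0.8, -0.5])
rejected = torch.tensor([-1.6, -1.5])
print(cpo_loss_reference(chosen, rejected))  # positive scalar; shrinks as the chosen margin grows
```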
32 changes: 20 additions & 12 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -18,14 +18,28 @@ def preference_loss_fn(
beta=0.1,
):
"""
Compute DPO loss (Direct Preference Optimization).
Paper: https://arxiv.org/pdf/2305.18290
Formula:
L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
Where:
- π(y|x): Policy (model) probability
- π_ref(y|x): Reference model probability
- y_w: Chosen sequence
- y_l: Rejected sequence
- β: Weight for the direct preference loss
- E: Expected value over the dataset
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
beta (float): Weight for the direct preference loss.
chosen_logps: Log probabilities of chosen tokens (batch_size,)
rejected_logps: Log probabilities of rejected tokens (batch_size,)
full_target: Non chunked full target tensor
ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
beta: Weight for the direct preference loss
"""

if ref_chosen_logps is None:
ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
if ref_rejected_logps is None:
@@ -53,10 +67,6 @@ def forward(
compiled=True,
use_ref_model=True,
):
"""
Fused linear layer with DPO (Direct Preference Optimization) loss.
Handles both the forward and backward pass of the final linear layer with DPO loss.
"""
return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
_input=_input,
Expand All @@ -75,9 +85,7 @@ def forward(

@staticmethod
def backward(ctx, *grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None, None, None, None


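On dpo_loss.py above: a similarly hedged standalone sketch of the documented DPO formula, with the same caveats (illustrative name, batch mean instead of the fused pair-count normalization). Defaulting the reference log-probs to zero mirrors the `ref_*_logps is None` branches visible in the diff, which reduce the log-ratio to the policy log-prob when no reference model is used.

```python
import torch
import torch.nn.functional as F


def dpo_loss_reference(chosen_logps, rejected_logps,
                       ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1):
    """Direct reading of L_DPO = -E[log sigma(beta * ((log p_w - log p_ref_w) - (log p_l - log p_ref_l)))]."""
    if ref_chosen_logps is None:
        # no reference model: the log-ratio collapses to the policy log-prob
        ref_chosen_logps = torch.zeros_like(chosen_logps)
    if ref_rejected_logps is None:
        ref_rejected_logps = torch.zeros_like(rejected_logps)
    chosen_logratio = chosen_logps - ref_chosen_logps
    rejected_logratio = rejected_logps - ref_rejected_logps
    # negative log-sigmoid of the beta-scaled difference of log-ratios, averaged over pairs
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()
```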
24 changes: 15 additions & 9 deletions src/liger_kernel/chunked_loss/orpo_loss.py
@@ -11,10 +11,24 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
"""
Compute odds-ratio loss.
Paper: https://arxiv.org/pdf/2403.07691
Formula:
Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
Where:
- P_θ(y|x): Policy (model) probability
- y_w: Chosen sequence
- y_l: Rejected sequence
- σ: Sigmoid function
- β: Weight for the odds ratio loss
- odds_θ: Odds function for the policy
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
full_target (torch.Tensor): Non chunked full target tensor
beta (float): Weight for the odds ratio loss.
"""
log_odds = (chosen_logps - rejected_logps) - (
@@ -44,12 +58,6 @@ def forward(
compute_nll_loss=True,
compiled=True,
):
"""
Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
Handles both the forward and backward pass of the final linear layer with ORPO loss.
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
"""

return LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
_input=_input,
Expand All @@ -65,9 +73,7 @@ def forward(

@staticmethod
def backward(ctx, *grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None


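On orpo_loss.py above: the odds-ratio term can also be written out directly. This is a reference sketch under stated assumptions, not the fused code: it expands odds_θ(y|x) = P_θ(y|x) / (1 − P_θ(y|x)) in log space using `log1p(-exp(·))` for stability, applies β as the docstring's weight for the odds-ratio term, and assumes the average log-probs are strictly negative. The truncated `log_odds = ...` line in the diff suggests the fused kernel computes a similar quantity.

```python
import torch
import torch.nn.functional as F


def orpo_odds_ratio_loss_reference(chosen_logps, rejected_logps, beta=0.1):
    """Direct reading of L_OR = -log sigma(log(odds_w / odds_l)), odds(y|x) = P / (1 - P)."""
    # log odds(y|x) = log P - log(1 - P); given log P, log(1 - P) = log1p(-exp(log P)).
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    # beta-weighted negative log-sigmoid of the log-odds ratio, averaged over pairs
    return -beta * F.logsigmoid(log_odds).mean()
```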
28 changes: 17 additions & 11 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -13,12 +13,26 @@ def preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
):
"""
Compute SimPO (Simple Preference Optimization) loss.
Paper: https://arxiv.org/pdf/2405.14734
Formula:
L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
Where:
- π_θ(y|x): Policy (model) probability
- y_w: Chosen sequence
- y_l: Rejected sequence
- |y_w|, |y_l|: Sequence lengths
- σ: Sigmoid function
- β: beta weight
- γ: gamma margin term
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
beta (float): Weight for the odds ratio loss.
gamma (float): The SimPO gamma margin term.
full_target: Non chunked full target tensor
beta (float): beta weight
gamma (float): gamma margin term
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -38,12 +52,6 @@ def forward(
compiled=True,
gamma=0.5,
):
"""
Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
Handles both the forward and backward pass of the final linear layer with SimPO loss.
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
"""

return LigerFusedLinearPreferenceBase.forward(
ctx,
_input,
Expand All @@ -61,9 +69,7 @@ def forward(

@staticmethod
def backward(ctx, *grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None, None, None


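On simpo_loss.py above: a final standalone sketch of the documented SimPO objective under the same assumptions (illustrative name, batch-mean reduction instead of the fused pair-count normalization). Because `chosen_logps` and `rejected_logps` are already length-averaged, the β/|y| scaling in the paper reduces to a single β factor, and γ enters as a fixed margin subtracted from the logits.

```python
import torch
import torch.nn.functional as F


def simpo_loss_reference(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
    """Direct reading of L_SimPO = -E[log sigma(beta * avg_logp_w - beta * avg_logp_l - gamma)]."""
    # beta-scaled margin between length-averaged log-probs, shifted down by the gamma margin
    logits = beta * (chosen_logps - rejected_logps) - gamma
    # negative log-sigmoid of the shifted margin, averaged over preference pairs
    return -F.logsigmoid(logits).mean()


# Example: gamma penalizes pairs whose scaled margin stays below the target margin
print(simpo_loss_reference(torch.tensor([-0.4]), torch.tensor([-1.9])))
```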
