
Commit

softcap + cpo test
ryankert01 committed Dec 12, 2024
1 parent 0df301c commit 6d2a935
Showing 2 changed files with 19 additions and 15 deletions.
src/liger_kernel/chunked_loss/fused_linear_preference.py (12 changes: 3 additions & 9 deletions)
@@ -183,6 +183,9 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
         if compiled:
             fused_fwd_bwd = torch.compile(fused_fwd_bwd)
 
+        if softcap is not None:
+            _input = softcap * torch.tanh(_input / softcap)
+
         len_chosen = target.shape[0] // 2
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
@@ -284,16 +287,11 @@ def chunk_forward(
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
-        softcap=None,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
         logits_chunk = input_chunk @ weight.t()
         if bias is not None:
             logits_chunk = logits_chunk + bias
-        if softcap is not None:
-            logits_chunk = logits_chunk / softcap
-            logits_chunk = torch.tanh(logits_chunk)
-            logits_chunk = logits_chunk * softcap
         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
 
         chosen_nll_loss = 0.0
@@ -343,7 +341,6 @@ def _compute_loss(
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
-        softcap=None,
         **loss_kwargs,
     ):
         """
@@ -362,7 +359,6 @@
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
-            softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
@@ -378,7 +374,6 @@
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
-            softcap=softcap,
         )
         chosen_nll_loss = (
             chosen_nll_loss
@@ -406,7 +401,6 @@ def _compute_loss(
                     ref_bias,
                     ignore_index=ignore_index,
                     compute_nll_loss=False, # We don't need NLL loss for the reference model
-                    softcap=softcap,
                 )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
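For reference, the capping applied above is the usual tanh soft-capping: dividing by softcap before tanh and multiplying back afterwards keeps the result in the open interval (-softcap, +softcap) while leaving values much smaller than softcap nearly unchanged. A minimal standalone sketch, not part of this commit (the soft_cap helper name is illustrative):

import torch

def soft_cap(x: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash into tanh's range and rescale: the output lies in (-softcap, +softcap)
    # and is approximately equal to x when |x| << softcap.
    return softcap * torch.tanh(x / softcap)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(soft_cap(x, softcap=30.0))  # large values saturate near +/-30; small values pass through almost unchanged
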
test/chunked_loss/test_cpo_loss.py (22 changes: 16 additions & 6 deletions)
@@ -1,5 +1,5 @@
 from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed
-from typing import Tuple
+from typing import Tuple, Optional
 
 import pytest
 import torch
@@ -88,6 +88,7 @@ def __init__(
         alpha: float = 1.0,
         loss_type: str = "sigmoid",
         simpo_gamma: float = 0.5,
+        softcap: Optional[float] = None,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(
@@ -99,8 +100,12 @@ def __init__(
             loss_type=loss_type,
             simpo_gamma=simpo_gamma,
         ).get_batch_loss_metrics
+        self.softcap = softcap
 
     def forward(self, x, y):
+        logits = self.lin(x).to(torch.float32)
+        if self.softcap is not None and self.softcap != 0.0:
+            logits = self.softcap * torch.tanh(logits / self.softcap)
         return self.cpo_loss(self.lin.weight, x, y, self.lin.bias)
 
 
@@ -114,13 +119,14 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        softcap: Optional[float] = None,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.cpo_loss = LigerFusedLinearCPOLoss(
-            ignore_index=ignore_index, beta=beta, alpha=alpha
+            ignore_index=ignore_index, beta=beta, alpha=alpha, softcap=softcap
         )
 
     def forward(self, x, y):
@@ -135,18 +141,20 @@ def forward(self, x, y):
     ],
 )
 @pytest.mark.parametrize(
-    "scalar, dtype, atol, rtol",
+    "scalar, dtype, atol, rtol, softcap",
     [
-        (1.0, torch.bfloat16, 5e-3, 5e-3),
-        (1.0, torch.float32, 1e-5, 5e-4),
+        (1.0, torch.bfloat16, 5e-3, 5e-3, None),
+        (1.0, torch.float32, 1e-5, 5e-4, None),
+        (1.0, torch.bfloat16, 5e-3, 5e-3, 30),
+        (1.0, torch.float32, 5e-3, 5e-3, 30),
     ],
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize(
     "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
 )
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha
+    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha, softcap
 ):
     B = 2 * B # cpo loss requires B to be even
 
@@ -157,6 +165,7 @@ def test_correctness(
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        softcap=softcap,
     )
     liger_lm_head_cpo = LigerLMHeadCPO(
         H=H,
@@ -165,6 +174,7 @@
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        softcap=softcap,
     )
 
     torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(
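A note on the new softcap entries in the parametrization above: None leaves the logits untouched, while a float such as 30 bounds them via the same tanh capping. A standalone illustration, not part of the commit (the apply_softcap helper and test name are made up):

import pytest
import torch


def apply_softcap(logits: torch.Tensor, softcap):
    # Mirrors the capping added to the test's reference path above.
    if softcap is not None and softcap != 0.0:
        logits = softcap * torch.tanh(logits / softcap)
    return logits


@pytest.mark.parametrize("softcap", [None, 30.0])
def test_softcap_bounds(softcap):
    logits = torch.randn(4, 8) * 100
    capped = apply_softcap(logits, softcap)
    if softcap is None:
        assert torch.equal(capped, logits)  # no capping requested, logits unchanged
    else:
        assert capped.abs().max() <= softcap  # capped values stay within (-softcap, +softcap)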
