Commit 0fda3b8

checkstyle
shivam15s committed Nov 21, 2024
1 parent f84f16a commit 0fda3b8
Showing 1 changed file with 29 additions and 24 deletions.
src/liger_kernel/chunked_loss/fused_linear_preference.py (53 changes: 29 additions & 24 deletions)
@@ -20,35 +20,38 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):

    @staticmethod
    def chunk_forward(
-        input_chunk, weight, target_chunk, bias=None, ignore_index=-100, compute_nll_loss=True
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+        compute_nll_loss=True,
    ):
        len_chosen_chunk = target_chunk.shape[0] // 2
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

        chosen_nll_loss = 0.0
        if compute_nll_loss:
            chosen_nll_loss = F.nll_loss(
                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
                target_chunk[:len_chosen_chunk].view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        loss_mask = target_chunk != ignore_index
        label_chunk = torch.where(loss_mask, target_chunk, 0)

-        per_token_logps = log_probs_chunk.gather(
-            -1, label_chunk.unsqueeze(-1)
-        ).squeeze(-1)
-        average_log_prob = (per_token_logps * loss_mask).sum(
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
            -1
-        ) / loss_mask.sum(-1)
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

-        chosen_logps = average_log_prob[: len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk :]
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
        return chosen_logps, rejected_logps, chosen_nll_loss

    @staticmethod
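For context: the reformatted chunk_forward above projects an input chunk through the linear layer, takes a log-softmax, and averages per-token log-probabilities separately over the chosen rows (first half of the chunk) and the rejected rows (second half). The sketch below mirrors that computation on a toy batch; the shapes and tensors are hypothetical and this is not the library API.

# Minimal sketch of what chunk_forward computes, on a toy batch.
# Shapes and values are hypothetical and for illustration only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, H, V = 4, 6, 8, 16  # rows (first half chosen, second half rejected), seq len, hidden, vocab
input_chunk = torch.randn(B, T, H)
weight = torch.randn(V, H)
target_chunk = torch.randint(0, V, (B, T))
ignore_index = -100

len_chosen = target_chunk.shape[0] // 2
logits = input_chunk @ weight.t()                  # (B, T, V)
log_probs = F.log_softmax(logits.float(), dim=-1)

loss_mask = target_chunk != ignore_index
labels = torch.where(loss_mask, target_chunk, 0)
per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)       # (B, T)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)   # (B,)

chosen_logps = average_log_prob[:len_chosen]
rejected_logps = average_log_prob[len_chosen:]
print(chosen_logps.shape, rejected_logps.shape)    # torch.Size([2]) torch.Size([2])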
@@ -225,18 +228,20 @@ def _compute_loss(
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            loss_kwargs (dict): Additional arguments for the loss function.
        """
-        chosen_logps, rejected_logps, chosen_nll_loss = LigerFusedLinearPreferenceBase.chunk_forward(
-            input_chunk,
-            weight,
-            target_chunk,
-            bias=bias,
-            ignore_index=ignore_index,
-            compute_nll_loss=compute_nll_loss,
+        chosen_logps, rejected_logps, chosen_nll_loss = (
+            LigerFusedLinearPreferenceBase.chunk_forward(
+                input_chunk,
+                weight,
+                target_chunk,
+                bias=bias,
+                ignore_index=ignore_index,
+                compute_nll_loss=compute_nll_loss,
+            )
        )
        chosen_nll_loss = (
            chosen_nll_loss
            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
        )

        if use_ref_model:
            with torch.no_grad():
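For reference: in _compute_loss, the summed NLL loss is divided by the number of non-ignored target tokens in the chosen half of full_target (the expression kept unchanged in this diff). A toy illustration of that denominator, with made-up tensors that are not taken from the library:

# Toy illustration (hypothetical values) of the normalization denominator.
import torch

ignore_index = -100
# full_target stacks chosen rows first, then rejected rows.
full_target = torch.tensor(
    [
        [5, 7, ignore_index],              # chosen
        [2, ignore_index, ignore_index],   # chosen
        [1, 1, 1],                         # rejected
        [3, 3, ignore_index],              # rejected
    ]
)
num_chosen_tokens = (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
print(num_chosen_tokens)  # tensor(3): non-ignored tokens in the chosen half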