diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 312ae3ac7..ccf74ca04 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -20,35 +20,38 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): @staticmethod def chunk_forward( - input_chunk, weight, target_chunk, bias=None, ignore_index=-100, compute_nll_loss=True + input_chunk, + weight, + target_chunk, + bias=None, + ignore_index=-100, + compute_nll_loss=True, ): len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() if bias is not None: logits_chunk = logits_chunk + bias log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) - + chosen_nll_loss = 0.0 if compute_nll_loss: chosen_nll_loss = F.nll_loss( - log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), - reduction="sum", - ignore_index=ignore_index, - ) + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), + reduction="sum", + ignore_index=ignore_index, + ) loss_mask = target_chunk != ignore_index label_chunk = torch.where(loss_mask, target_chunk, 0) - per_token_logps = log_probs_chunk.gather( - -1, label_chunk.unsqueeze(-1) - ).squeeze(-1) - average_log_prob = (per_token_logps * loss_mask).sum( + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( -1 - ) / loss_mask.sum(-1) + ) + average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_logps = average_log_prob[: len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk :] + chosen_logps = average_log_prob[:len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk:] return chosen_logps, rejected_logps, chosen_nll_loss @staticmethod @@ -225,18 +228,20 @@ def _compute_loss( ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Additional arguments for the loss function. """ - chosen_logps, rejected_logps, chosen_nll_loss = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, - weight, - target_chunk, - bias=bias, - ignore_index=ignore_index, - compute_nll_loss=compute_nll_loss, + chosen_logps, rejected_logps, chosen_nll_loss = ( + LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + bias=bias, + ignore_index=ignore_index, + compute_nll_loss=compute_nll_loss, + ) ) chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) if use_ref_model: with torch.no_grad():