From f84f16a79a7f7fe4c4d6f1517bb14e7908b705c0 Mon Sep 17 00:00:00 2001 From: shivam15s Date: Thu, 21 Nov 2024 23:56:38 +0000 Subject: [PATCH] refactor code --- .../chunked_loss/fused_linear_preference.py | 99 +++++++++---------- 1 file changed, 46 insertions(+), 53 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index da636b6c7..312ae3ac7 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -19,28 +19,37 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): raise NotImplementedError("Preference loss function must be implemented.") @staticmethod - def get_ref_logps( - input_chunk, ref_weight, target_chunk, ref_bias=None, ignore_index=-100 + def chunk_forward( + input_chunk, weight, target_chunk, bias=None, ignore_index=-100, compute_nll_loss=True ): - with torch.no_grad(): - ref_logits_chunk = input_chunk @ ref_weight.t() - if ref_bias is not None: - ref_logits_chunk = ref_logits_chunk + ref_bias - ref_log_probs_chunk = F.log_softmax(ref_logits_chunk.float(), dim=-1) + len_chosen_chunk = target_chunk.shape[0] // 2 + logits_chunk = input_chunk @ weight.t() + if bias is not None: + logits_chunk = logits_chunk + bias + log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) + + chosen_nll_loss = 0.0 + if compute_nll_loss: + chosen_nll_loss = F.nll_loss( + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), + reduction="sum", + ignore_index=ignore_index, + ) - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) - ref_per_token_logps = ref_log_probs_chunk.gather( - -1, label_chunk.unsqueeze(-1) - ).squeeze(-1) - ref_average_log_prob = (ref_per_token_logps * loss_mask).sum( - -1 - ) / loss_mask.sum(-1) + per_token_logps = log_probs_chunk.gather( + -1, label_chunk.unsqueeze(-1) + ).squeeze(-1) + average_log_prob = (per_token_logps * loss_mask).sum( + -1 + ) / loss_mask.sum(-1) - ref_chosen_logps = ref_average_log_prob[: input_chunk.shape[0] // 2] - ref_rejected_logps = ref_average_log_prob[input_chunk.shape[0] // 2 :] - return ref_chosen_logps, ref_rejected_logps + chosen_logps = average_log_prob[: len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk :] + return chosen_logps, rejected_logps, chosen_nll_loss @staticmethod def forward( @@ -216,47 +225,31 @@ def _compute_loss( ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Additional arguments for the loss function. """ - len_chosen_chunk = target_chunk.shape[0] // 2 - - logits_chunk = input_chunk @ weight.t() # chunk_size x V - if bias is not None: - logits_chunk = logits_chunk + bias - log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) - - chosen_nll_loss = 0.0 - if compute_nll_loss: - chosen_nll_loss = F.nll_loss( - log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), - reduction="sum", - ignore_index=ignore_index, - ) - chosen_nll_loss = ( + chosen_logps, rejected_logps, chosen_nll_loss = LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + bias=bias, + ignore_index=ignore_index, + compute_nll_loss=compute_nll_loss, + ) + chosen_nll_loss = ( chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() ) - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) - - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( - -1 - ) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] - if use_ref_model: - ref_chosen_logps, ref_rejected_logps = ( - LigerFusedLinearPreferenceBase.get_ref_logps( - input_chunk, - ref_weight, - target_chunk, - ref_bias=ref_bias, - ignore_index=ignore_index, + with torch.no_grad(): + ref_chosen_logps, ref_rejected_logps, _ = ( + LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, + ) ) - ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps