Commit: refactor code
shivam15s committed Nov 21, 2024
1 parent 15cb383 commit f84f16a
Showing 1 changed file with 46 additions and 53 deletions.
src/liger_kernel/chunked_loss/fused_linear_preference.py (99 changes: 46 additions & 53 deletions)
@@ -19,28 +19,37 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
         raise NotImplementedError("Preference loss function must be implemented.")

     @staticmethod
-    def get_ref_logps(
-        input_chunk, ref_weight, target_chunk, ref_bias=None, ignore_index=-100
+    def chunk_forward(
+        input_chunk, weight, target_chunk, bias=None, ignore_index=-100, compute_nll_loss=True
     ):
-        with torch.no_grad():
-            ref_logits_chunk = input_chunk @ ref_weight.t()
-            if ref_bias is not None:
-                ref_logits_chunk = ref_logits_chunk + ref_bias
-            ref_log_probs_chunk = F.log_softmax(ref_logits_chunk.float(), dim=-1)
+        len_chosen_chunk = target_chunk.shape[0] // 2
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

-            loss_mask = target_chunk != ignore_index
-            label_chunk = torch.where(loss_mask, target_chunk, 0)
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)

-            ref_per_token_logps = ref_log_probs_chunk.gather(
-                -1, label_chunk.unsqueeze(-1)
-            ).squeeze(-1)
-            ref_average_log_prob = (ref_per_token_logps * loss_mask).sum(
-                -1
-            ) / loss_mask.sum(-1)
+        per_token_logps = log_probs_chunk.gather(
+            -1, label_chunk.unsqueeze(-1)
+        ).squeeze(-1)
+        average_log_prob = (per_token_logps * loss_mask).sum(
+            -1
+        ) / loss_mask.sum(-1)

-            ref_chosen_logps = ref_average_log_prob[: input_chunk.shape[0] // 2]
-            ref_rejected_logps = ref_average_log_prob[input_chunk.shape[0] // 2 :]
-            return ref_chosen_logps, ref_rejected_logps
+        chosen_logps = average_log_prob[: len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk :]
+        return chosen_logps, rejected_logps, chosen_nll_loss

     @staticmethod
     def forward(
@@ -216,47 +225,31 @@ def _compute_loss(
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-        len_chosen_chunk = target_chunk.shape[0] // 2
-
-        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
-        if bias is not None:
-            logits_chunk = logits_chunk + bias
-        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
-        chosen_nll_loss = 0.0
-        if compute_nll_loss:
-            chosen_nll_loss = F.nll_loss(
-                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
-                reduction="sum",
-                ignore_index=ignore_index,
-            )
+        chosen_logps, rejected_logps, chosen_nll_loss = LigerFusedLinearPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+            compute_nll_loss=compute_nll_loss,
+        )
         chosen_nll_loss = (
             chosen_nll_loss
             / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
         )

-        loss_mask = target_chunk != ignore_index
-        label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
-        )
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-
-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
-
         if use_ref_model:
-            ref_chosen_logps, ref_rejected_logps = (
-                LigerFusedLinearPreferenceBase.get_ref_logps(
-                    input_chunk,
-                    ref_weight,
-                    target_chunk,
-                    ref_bias=ref_bias,
-                    ignore_index=ignore_index,
-                )
-            )
+            with torch.no_grad():
+                ref_chosen_logps, ref_rejected_logps, _ = (
+                    LigerFusedLinearPreferenceBase.chunk_forward(
+                        input_chunk,
+                        ref_weight,
+                        target_chunk,
+                        ref_bias,
+                        ignore_index=ignore_index,
+                        compute_nll_loss=False,
+                    )
+                )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
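
For orientation, a minimal usage sketch of the new chunk_forward helper on toy tensors. This is not part of the commit: the import path is inferred from the file path shown above, and the shapes and sizes are illustrative assumptions; only the class name, the chunk_forward signature, and its three return values are taken from the diff.

# Usage sketch (illustrative, not part of the commit): exercises the new chunk_forward
# helper on toy tensors. The import path is inferred from the file path above; shapes are
# assumed to be (batch, seq_len, hidden), with the first half of the batch the "chosen"
# sequences and the second half the "rejected" ones.
import torch

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)

B, T, H, V = 4, 8, 16, 32                   # 2 chosen + 2 rejected sequences, toy sizes
input_chunk = torch.randn(B, T, H)          # hidden states for one chunk
weight = torch.randn(V, H)                  # lm_head weight (vocab_size x hidden_size)
target_chunk = torch.randint(0, V, (B, T))  # token ids in [0, V); -100 would be ignored

# Policy pass: per-sequence average log-probs plus the NLL loss on the chosen half.
chosen_logps, rejected_logps, chosen_nll_loss = LigerFusedLinearPreferenceBase.chunk_forward(
    input_chunk, weight, target_chunk, bias=None, ignore_index=-100, compute_nll_loss=True
)

# Reference-model pass: same helper without gradients or NLL loss.
with torch.no_grad():
    ref_chosen_logps, ref_rejected_logps, _ = LigerFusedLinearPreferenceBase.chunk_forward(
        input_chunk, weight, target_chunk, compute_nll_loss=False
    )

The second call mirrors the new use_ref_model branch in _compute_loss: the same helper runs under torch.no_grad() with compute_nll_loss=False for the reference model, so the policy and reference passes share one code path.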
