Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files: browse the repository at this point in the history
  • Loading branch information
lancerts authored Dec 11, 2024
2 parents c84c56f + eee40c5 commit 43dfb73
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,7 +29,7 @@ def forward(
compute_nll_loss=True,
compiled=True,
use_ref_model=False,
# TODO: ref input
ref_input=None,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
Expand Down Expand Up @@ -59,6 +59,7 @@ def forward(
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size).
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
loss_kwargs (dict): Other possible arguments that a loss function might need
Expand Down Expand Up @@ -92,6 +93,7 @@ def forward(
compute_nll_loss=compute_nll_loss,
full_target=target,
use_ref_model=use_ref_model,
ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
**loss_kwargs,
Expand Down Expand Up @@ -301,6 +303,7 @@ def _compute_loss(
beta=0.1,
compute_nll_loss=True,
use_ref_model=False,
ref_input=None,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
Expand All @@ -319,6 +322,7 @@ def _compute_loss(
beta (float): Weight for the preference loss.
compute_nll_loss (bool): Whether to compute NLL loss.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
loss_kwargs (dict): Additional arguments for the loss function.
Expand Down Expand Up @@ -357,7 +361,7 @@ def _compute_loss(
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
ref_input,
ref_weight,
target_chunk,
ref_bias,
Expand Down

0 comments on commit 43dfb73

Please sign in to comment.