diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
index 57afabc80..3b940f315 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_preference.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -29,7 +29,7 @@ def forward(
         compute_nll_loss=True,
         compiled=True,
         use_ref_model=False,
-        # TODO: ref input
+        ref_input=None,
         ref_weight=None,
         ref_bias=None,
         **loss_kwargs,
@@ -59,6 +59,7 @@ def forward(
             compute_nll_loss (bool): Whether to compute NLL loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size).
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Other possible arguments that a loss function might need
@@ -92,6 +93,7 @@ def forward(
             compute_nll_loss=compute_nll_loss,
             full_target=target,
             use_ref_model=use_ref_model,
+            ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
             **loss_kwargs,
@@ -301,6 +303,7 @@ def _compute_loss(
         beta=0.1,
         compute_nll_loss=True,
         use_ref_model=False,
+        ref_input=None,
         ref_weight=None,
         ref_bias=None,
         **loss_kwargs,
@@ -319,6 +322,7 @@ def _compute_loss(
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Additional arguments for the loss function.
@@ -357,7 +361,7 @@ def _compute_loss(
                     ref_rejected_logits,
                     ref_chosen_nll_loss,
                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
-                    input_chunk,
+                    ref_input,
                     ref_weight,
                     target_chunk,
                     ref_bias,
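
For context, a minimal shape sketch (illustrative only; the tensor names and sizes below are assumptions, not the library's API) of how the new ref_input relates to ref_weight and ref_bias per the docstrings above. The last hunk's point is that reference logits must be computed from the reference model's own hidden states (ref_input), not from the policy model's input_chunk.

    import torch

    # Illustrative sizes only (assumptions, not taken from the diff).
    batch_size, seq_len, hidden_size, vocab_size = 4, 128, 64, 1000

    # Reference-model hidden states: the new `ref_input` argument.
    # Shape per the docstring: (batch_size, seq_len, hidden_size).
    ref_input = torch.randn(batch_size, seq_len, hidden_size)

    # Reference lm_head parameters, per the docstring shapes.
    ref_weight = torch.randn(vocab_size, hidden_size)  # (vocab_size, hidden_size)
    ref_bias = torch.randn(vocab_size)                 # (vocab_size,)

    # Reference logits come from the reference model's hidden states,
    # not from the policy model's input_chunk (the bug the last hunk fixes).
    ref_logits = ref_input @ ref_weight.T + ref_bias   # (batch_size, seq_len, vocab_size)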