From eee40c5ebc8398b17d513b839700876a998366ba Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Wed, 11 Dec 2024 16:30:00 -0500 Subject: [PATCH] Add ref_input parameter to support separate inputs for reference model (#467) This PR fixes #447 by adding support for separate inputs for the reference model. ### Changes - Add `ref_input` parameter to `forward()` and `_compute_loss()` methods - Use `ref_input` for reference model calculations if provided, otherwise fallback to using the main input - Update docstrings to document the new parameter ### Testing The changes are backward compatible - if `ref_input` is not provided, it will use the main input for reference model calculations, maintaining the current behavior. Fixes #447 --------- Co-authored-by: openhands --- src/liger_kernel/chunked_loss/fused_linear_preference.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 57afabc80..3b940f315 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -29,7 +29,7 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=False, - # TODO: ref input + ref_input=None, ref_weight=None, ref_bias=None, **loss_kwargs, @@ -59,6 +59,7 @@ def forward( compute_nll_loss (bool): Whether to compute NLL loss. compiled (bool): Whether to use torch compile for chunk accumulation. use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size). ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Other possible arguments that a loss function might need @@ -92,6 +93,7 @@ def forward( compute_nll_loss=compute_nll_loss, full_target=target, use_ref_model=use_ref_model, + ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, **loss_kwargs, @@ -301,6 +303,7 @@ def _compute_loss( beta=0.1, compute_nll_loss=True, use_ref_model=False, + ref_input=None, ref_weight=None, ref_bias=None, **loss_kwargs, @@ -319,6 +322,7 @@ def _compute_loss( beta (float): Weight for the preference loss. compute_nll_loss (bool): Whether to compute NLL loss. use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Additional arguments for the loss function. @@ -357,7 +361,7 @@ def _compute_loss( ref_rejected_logits, ref_chosen_nll_loss, ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, + ref_input, ref_weight, target_chunk, ref_bias,