From 21bacccd107b40146e599e2c27af46d7d157f174 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Dec 2024 22:43:02 +0100 Subject: [PATCH 1/7] [Transformer] fix ORPO loss for MOE models (#479) ## Summary Add missing MOE loss when specified in the trainer. - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --- src/liger_kernel/transformers/trainer/orpo_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py index 3605b9f1b..184430ac1 100644 --- a/src/liger_kernel/transformers/trainer/orpo_trainer.py +++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py @@ -125,6 +125,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels): outputs.last_hidden_state, concatenated_batch["concatenated_labels"], ) + # if aux_loss_enabled, add the aux_loss to the orpo_loss + if self.aux_loss_enabled: + orpo_loss += self.aux_loss_coef * outputs.aux_loss + return orpo_loss, aux_outputs def get_batch_loss_metrics( From ac5667471e24434c378781c5400b19d595d05fd8 Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Mon, 16 Dec 2024 22:01:19 -0800 Subject: [PATCH 2/7] fix: correct typos in docstrings (#482) - Fix 'transfomers' to 'transformers' in mixtral.py - Fix 'Emebedding' to 'Embedding' in orpo_trainer.py ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: byhsu@linkedin.com --- src/liger_kernel/transformers/model/mixtral.py | 2 +- src/liger_kernel/transformers/trainer/orpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py index 22fea53da..145bc78cd 100644 --- a/src/liger_kernel/transformers/model/mixtral.py +++ b/src/liger_kernel/transformers/model/mixtral.py @@ -38,7 +38,7 @@ def lce_forward_deprecated( cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" - Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy + Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy Args: diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py index 184430ac1..04391fa5f 100644 --- a/src/liger_kernel/transformers/trainer/orpo_trainer.py +++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py @@ -17,7 +17,7 @@ class _FSDPForwardRedirection: This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`) - will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of + will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just the `lm_head` part of a model, we need this trick too to properly get its params all-gathered. From 61eefe9a4429459351979dc7fe1de746fd7ca86f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 18 Dec 2024 23:19:38 +0100 Subject: [PATCH 3/7] fix chosen_nll_loss in chunked losses (#486) ## Summary Fix the nll loss in the the chunked loses when the model is a decoder only model, by shifting the logits and targets - Hardware Type: - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/cpo_loss.py | 11 ++- src/liger_kernel/chunked_loss/dpo_loss.py | 8 +- .../chunked_loss/fused_linear_preference.py | 92 ++++++++++++------- src/liger_kernel/chunked_loss/orpo_loss.py | 10 +- test/utils.py | 17 +++- 5 files changed, 101 insertions(+), 37 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 2b8052e25..1d771753e 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -47,6 +47,7 @@ def forward( alpha=1.0, compute_nll_loss=True, compiled=True, + is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx, @@ -60,12 +61,13 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, + is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): @@ -80,11 +82,16 @@ def __init__( alpha: float = 1.0, compute_nll_loss: bool = True, compiled: bool = True, + is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + alpha (float): Weight for the NLL loss. + compute_nll_loss (bool): Whether to compute NLL loss. + compiled (bool): Whether to compile the loss function. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index @@ -92,6 +99,7 @@ def __init__( self.alpha = alpha self.compute_nll_loss = compute_nll_loss self.compiled = compiled + self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCPOFunction.apply( @@ -104,4 +112,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.alpha, self.compute_nll_loss, self.compiled, + self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 5f1b17cf5..082036eb5 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -67,6 +67,7 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=True, + is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -83,12 +84,13 @@ def forward( ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, + is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -103,6 +105,7 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, use_ref_model: bool = False, + is_encoder_decoder: bool = False, ): """ Args: @@ -111,6 +114,7 @@ def __init__( compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index @@ -118,6 +122,7 @@ def __init__( self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model + self.is_encoder_decoder = is_encoder_decoder def forward( self, @@ -142,4 +147,5 @@ def forward( self.compute_nll_loss, self.compiled, self.use_ref_model, + self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index fff0791ec..1ede7aca8 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -26,6 +26,7 @@ def forward( ignore_index=-100, alpha=1.0, beta=0.1, + is_encoder_decoder=False, compute_nll_loss=True, compiled=True, use_ref_model=False, @@ -56,6 +57,7 @@ def forward( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. compiled (bool): Whether to use torch compile for chunk accumulation. use_ref_model (bool): Whether to use a reference model for the alignment loss. @@ -94,6 +96,7 @@ def forward( use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, + is_encoder_decoder=is_encoder_decoder, **loss_kwargs, ) @@ -282,33 +285,48 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, + is_encoder_decoder=False, ): - len_chosen_chunk = target_chunk.shape[0] // 2 + # Calculate logits and log probabilities logits_chunk = input_chunk @ weight.t() if bias is not None: - logits_chunk = logits_chunk + bias + logits_chunk += bias log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) + # Split chunk into chosen and rejected portions + len_chosen_chunk = target_chunk.shape[0] // 2 + + # Handle sequence shifting for non-encoder-decoder models + if not is_encoder_decoder: + logits_chunk = logits_chunk[:, :-1] + log_probs_chunk = log_probs_chunk[:, :-1] + target_chunk = target_chunk[:, 1:] + + # Calculate NLL loss for chosen sequences chosen_nll_loss = 0.0 if compute_nll_loss: + chosen_probs = log_probs_chunk[:len_chosen_chunk] + chosen_targets = target_chunk[:len_chosen_chunk] chosen_nll_loss = F.nll_loss( - log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), + chosen_probs.reshape(-1, chosen_probs.shape[-1]), + chosen_targets.reshape(-1), reduction="sum", ignore_index=ignore_index, ) + # Calculate per-token log probabilities loss_mask = target_chunk != ignore_index label_chunk = torch.where(loss_mask, target_chunk, 0) - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( -1 ) average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] - + # Split results for chosen and rejected + chosen_logps, rejected_logps = ( + average_log_prob[:len_chosen_chunk], + average_log_prob[len_chosen_chunk:], + ) chosen_logits = logits_chunk[:len_chosen_chunk] rejected_logits = logits_chunk[len_chosen_chunk:] @@ -331,6 +349,7 @@ def _compute_loss( ignore_index=-100, alpha=1.0, beta=0.1, + is_encoder_decoder=False, compute_nll_loss=True, use_ref_model=False, ref_input_chunk=None, @@ -350,6 +369,7 @@ def _compute_loss( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). @@ -369,33 +389,43 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, + is_encoder_decoder=is_encoder_decoder, ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) + if not is_encoder_decoder: + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2, 1:] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] + ) + else: + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) if use_ref_model: with torch.no_grad(): - ( - ref_chosen_logps, - ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - ref_chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - ref_input_chunk, - ref_weight, - target_chunk, - ref_bias, - ignore_index=ignore_index, - compute_nll_loss=False, # We don't need NLL loss for the reference model + (ref_chosen_logps, ref_rejected_logps, _, _, _) = ( + LigerFusedLinearPreferenceBase.chunk_forward( + ref_input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, # We don't need NLL loss for the reference model + is_encoder_decoder=is_encoder_decoder, # assume the ref model is the same family + ) ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index c860d4bd9..7dae8057e 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -57,6 +57,7 @@ def forward( beta=0.1, compute_nll_loss=True, compiled=True, + is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -69,12 +70,13 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, + is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None + return *grads, None, None, None, None, None class LigerFusedLinearORPOLoss(torch.nn.Module): @@ -88,17 +90,22 @@ def __init__( beta: float = 0.1, compute_nll_loss: bool = True, compiled: bool = True, + is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute NLL loss. + compiled (bool): Whether to compile the loss function. + is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index self.beta = beta self.compute_nll_loss = compute_nll_loss self.compiled = compiled + self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearORPOFunction.apply( @@ -110,4 +117,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.beta, self.compute_nll_loss, self.compiled, + self.is_encoder_decoder, ) diff --git a/test/utils.py b/test/utils.py index 3d3799ad0..fc114d163 100644 --- a/test/utils.py +++ b/test/utils.py @@ -350,11 +350,13 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, + is_encoder_decoder: bool = False, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model + self.is_encoder_decoder = is_encoder_decoder @abstractmethod def alignment_loss(self): @@ -372,7 +374,6 @@ def get_batch_logps( logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length) average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. - is_encoder_decoder: Whether the model is an encoder-decoder model. Returns: A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. """ @@ -381,6 +382,9 @@ def get_batch_logps( "Logits (batch and sequence length dim) and labels must have the same shape." ) + if not self.is_encoder_decoder: + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() loss_mask = labels != self.ignore_index # dummy token; we'll ignore the losses on these tokens later @@ -440,6 +444,9 @@ def concatenated_forward( def cross_entropy_loss(logits, labels): # Flatten the tokens loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + if not self.is_encoder_decoder: + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() logits = logits.view(-1, logits.shape[-1]) labels = labels.view(-1) # Enable model parallelism @@ -461,8 +468,12 @@ def cross_entropy_loss(logits, labels): chosen_logps = all_logps[:len_chosen] rejected_logps = all_logps[len_chosen:] - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + if not self.is_encoder_decoder: + chosen_logits = all_logits[:len_chosen, :-1] + rejected_logits = all_logits[len_chosen:, :-1] + else: + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] return ( chosen_logps, From 7a781b7adf00f515d0d77552c7324fb7261baf51 Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Thu, 19 Dec 2024 13:18:23 -0800 Subject: [PATCH 4/7] Revert "fix chosen_nll_loss in chunked losses (#486)" (#489) This reverts commit 61eefe9a4429459351979dc7fe1de746fd7ca86f. ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/cpo_loss.py | 11 +-- src/liger_kernel/chunked_loss/dpo_loss.py | 8 +- .../chunked_loss/fused_linear_preference.py | 92 +++++++------------ src/liger_kernel/chunked_loss/orpo_loss.py | 10 +- test/utils.py | 17 +--- 5 files changed, 37 insertions(+), 101 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 1d771753e..2b8052e25 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -47,7 +47,6 @@ def forward( alpha=1.0, compute_nll_loss=True, compiled=True, - is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx, @@ -61,13 +60,12 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, - is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None + return *grads, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): @@ -82,16 +80,11 @@ def __init__( alpha: float = 1.0, compute_nll_loss: bool = True, compiled: bool = True, - is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. - alpha (float): Weight for the NLL loss. - compute_nll_loss (bool): Whether to compute NLL loss. - compiled (bool): Whether to compile the loss function. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index @@ -99,7 +92,6 @@ def __init__( self.alpha = alpha self.compute_nll_loss = compute_nll_loss self.compiled = compiled - self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCPOFunction.apply( @@ -112,5 +104,4 @@ def forward(self, lin_weight, _input, target, bias=None): self.alpha, self.compute_nll_loss, self.compiled, - self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 082036eb5..5f1b17cf5 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -67,7 +67,6 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=True, - is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -84,13 +83,12 @@ def forward( ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, - is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -105,7 +103,6 @@ def __init__( compute_nll_loss: bool = True, compiled: bool = True, use_ref_model: bool = False, - is_encoder_decoder: bool = False, ): """ Args: @@ -114,7 +111,6 @@ def __init__( compute_nll_loss (bool): Whether to compute the NLL loss. compiled (bool): Whether to use the torch compiled kernel. use_ref_model (bool): Whether to use a reference model for the DPO loss. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index @@ -122,7 +118,6 @@ def __init__( self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.use_ref_model = use_ref_model - self.is_encoder_decoder = is_encoder_decoder def forward( self, @@ -147,5 +142,4 @@ def forward( self.compute_nll_loss, self.compiled, self.use_ref_model, - self.is_encoder_decoder, ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 1ede7aca8..fff0791ec 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -26,7 +26,6 @@ def forward( ignore_index=-100, alpha=1.0, beta=0.1, - is_encoder_decoder=False, compute_nll_loss=True, compiled=True, use_ref_model=False, @@ -57,7 +56,6 @@ def forward( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. compiled (bool): Whether to use torch compile for chunk accumulation. use_ref_model (bool): Whether to use a reference model for the alignment loss. @@ -96,7 +94,6 @@ def forward( use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, - is_encoder_decoder=is_encoder_decoder, **loss_kwargs, ) @@ -285,48 +282,33 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, - is_encoder_decoder=False, ): - # Calculate logits and log probabilities + len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() if bias is not None: - logits_chunk += bias + logits_chunk = logits_chunk + bias log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) - # Split chunk into chosen and rejected portions - len_chosen_chunk = target_chunk.shape[0] // 2 - - # Handle sequence shifting for non-encoder-decoder models - if not is_encoder_decoder: - logits_chunk = logits_chunk[:, :-1] - log_probs_chunk = log_probs_chunk[:, :-1] - target_chunk = target_chunk[:, 1:] - - # Calculate NLL loss for chosen sequences chosen_nll_loss = 0.0 if compute_nll_loss: - chosen_probs = log_probs_chunk[:len_chosen_chunk] - chosen_targets = target_chunk[:len_chosen_chunk] chosen_nll_loss = F.nll_loss( - chosen_probs.reshape(-1, chosen_probs.shape[-1]), - chosen_targets.reshape(-1), + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), reduction="sum", ignore_index=ignore_index, ) - # Calculate per-token log probabilities loss_mask = target_chunk != ignore_index label_chunk = torch.where(loss_mask, target_chunk, 0) + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( -1 ) average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - # Split results for chosen and rejected - chosen_logps, rejected_logps = ( - average_log_prob[:len_chosen_chunk], - average_log_prob[len_chosen_chunk:], - ) + chosen_logps = average_log_prob[:len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk:] + chosen_logits = logits_chunk[:len_chosen_chunk] rejected_logits = logits_chunk[len_chosen_chunk:] @@ -349,7 +331,6 @@ def _compute_loss( ignore_index=-100, alpha=1.0, beta=0.1, - is_encoder_decoder=False, compute_nll_loss=True, use_ref_model=False, ref_input_chunk=None, @@ -369,7 +350,6 @@ def _compute_loss( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. compute_nll_loss (bool): Whether to compute NLL loss. use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). @@ -389,43 +369,33 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, - is_encoder_decoder=is_encoder_decoder, ) - if not is_encoder_decoder: - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2, 1:] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * (input_chunk.shape[1] - 1) * weight.shape[0] - ) - else: - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) if use_ref_model: with torch.no_grad(): - (ref_chosen_logps, ref_rejected_logps, _, _, _) = ( - LigerFusedLinearPreferenceBase.chunk_forward( - ref_input_chunk, - ref_weight, - target_chunk, - ref_bias, - ignore_index=ignore_index, - compute_nll_loss=False, # We don't need NLL loss for the reference model - is_encoder_decoder=is_encoder_decoder, # assume the ref model is the same family - ) + ( + ref_chosen_logps, + ref_rejected_logps, + ref_chosen_logits, + ref_rejected_logits, + ref_chosen_nll_loss, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + ref_input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, # We don't need NLL loss for the reference model ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index 7dae8057e..c860d4bd9 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -57,7 +57,6 @@ def forward( beta=0.1, compute_nll_loss=True, compiled=True, - is_encoder_decoder=False, ): return LigerFusedLinearPreferenceBase.forward( ctx=ctx, @@ -70,13 +69,12 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, - is_encoder_decoder=is_encoder_decoder, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None class LigerFusedLinearORPOLoss(torch.nn.Module): @@ -90,22 +88,17 @@ def __init__( beta: float = 0.1, compute_nll_loss: bool = True, compiled: bool = True, - is_encoder_decoder: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. - compute_nll_loss (bool): Whether to compute NLL loss. - compiled (bool): Whether to compile the loss function. - is_encoder_decoder (bool): Whether the model is an encoder-decoder model. """ super().__init__() self.ignore_index = ignore_index self.beta = beta self.compute_nll_loss = compute_nll_loss self.compiled = compiled - self.is_encoder_decoder = is_encoder_decoder def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearORPOFunction.apply( @@ -117,5 +110,4 @@ def forward(self, lin_weight, _input, target, bias=None): self.beta, self.compute_nll_loss, self.compiled, - self.is_encoder_decoder, ) diff --git a/test/utils.py b/test/utils.py index fc114d163..3d3799ad0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -350,13 +350,11 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, - is_encoder_decoder: bool = False, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model - self.is_encoder_decoder = is_encoder_decoder @abstractmethod def alignment_loss(self): @@ -374,6 +372,7 @@ def get_batch_logps( logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length) average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + is_encoder_decoder: Whether the model is an encoder-decoder model. Returns: A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. """ @@ -382,9 +381,6 @@ def get_batch_logps( "Logits (batch and sequence length dim) and labels must have the same shape." ) - if not self.is_encoder_decoder: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() loss_mask = labels != self.ignore_index # dummy token; we'll ignore the losses on these tokens later @@ -444,9 +440,6 @@ def concatenated_forward( def cross_entropy_loss(logits, labels): # Flatten the tokens loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) - if not self.is_encoder_decoder: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() logits = logits.view(-1, logits.shape[-1]) labels = labels.view(-1) # Enable model parallelism @@ -468,12 +461,8 @@ def cross_entropy_loss(logits, labels): chosen_logps = all_logps[:len_chosen] rejected_logps = all_logps[len_chosen:] - if not self.is_encoder_decoder: - chosen_logits = all_logits[:len_chosen, :-1] - rejected_logits = all_logits[len_chosen:, :-1] - else: - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] return ( chosen_logps, From 3205342a6a7209c55ca3a4bd97e986961fdc792e Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Thu, 19 Dec 2024 16:49:14 -0800 Subject: [PATCH 5/7] fix dpo tests: reduce tolerance and change default compute_nll_loss false (#490) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/dpo_loss.py | 4 +- test/chunked_loss/test_dpo_loss.py | 69 ++++++++++++++++++++--- test/utils.py | 10 +++- 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 5f1b17cf5..cf07e186e 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -64,7 +64,7 @@ def forward( ref_bias=None, ignore_index=-100, beta=0.1, - compute_nll_loss=True, + compute_nll_loss=False, compiled=True, use_ref_model=True, ): @@ -100,7 +100,7 @@ def __init__( self, ignore_index: int = -100, beta: float = 0.1, - compute_nll_loss: bool = True, + compute_nll_loss: bool = False, compiled: bool = True, use_ref_model: bool = False, ): diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 0ac8faeb8..b73a69a57 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -23,10 +23,17 @@ class HFDPOLoss(HFAlignmentLoss): """ def __init__( - self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, ): super().__init__( - beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, ) def alignment_loss( @@ -61,6 +68,7 @@ def __init__( dtype: torch.dtype, bias: bool = False, ref_bias: bool = False, + compute_nll_loss: bool = False, ignore_index: int = -100, beta: float = 0.1, ): @@ -72,7 +80,10 @@ def __init__( in_features=H, out_features=V, bias=ref_bias, dtype=dtype ) self.dpo_loss = HFDPOLoss( - ignore_index=ignore_index, beta=beta, use_ref_model=True + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, ).get_batch_loss_metrics def forward(self, x, ref_x, y): @@ -95,6 +106,7 @@ def __init__( dtype: torch.dtype, bias: bool = False, ref_bias: bool = False, + compute_nll_loss: bool = False, ignore_index: int = -100, beta: float = 0.1, ): @@ -106,7 +118,10 @@ def __init__( in_features=H, out_features=V, bias=ref_bias, dtype=dtype ) self.dpo_loss = LigerFusedLinearDPOLoss( - ignore_index=ignore_index, beta=beta, use_ref_model=True + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, ) def forward(self, x, ref_x, y): @@ -132,14 +147,27 @@ def forward(self, x, ref_x, y): "scalar, dtype, atol, rtol", [ (1.0, torch.bfloat16, 5e-2, 5e-1), - (1.0, torch.float32, 2e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), ], ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("compute_nll_loss", [True, False]) @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ref_bias, + compute_nll_loss, + ignore_index, + beta, ): B = 2 * B # dpo loss requires B to be even @@ -149,6 +177,7 @@ def test_correctness( dtype=dtype, bias=bias, ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, ) @@ -158,6 +187,7 @@ def test_correctness( dtype=dtype, bias=bias, ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, ) @@ -251,7 +281,10 @@ def test_correctness( ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("ref_bias", [True, False]) -def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): +@pytest.mark.parametrize("compute_nll_loss", [True, False]) +def test_correctness_functional( + B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss +): B = 2 * B _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar @@ -290,10 +323,28 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply( - input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1 + input1, + weight1, + target, + bias1, + ref_input, + ref_weight1, + ref_bias1, + -100, + 0.1, + compute_nll_loss, ) loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo( - input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2 + input2, + weight2, + target, + bias2, + ref_input, + ref_weight2, + ref_bias2, + -100, + 0.1, + compute_nll_loss, ) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index 3d3799ad0..48fcf3601 100644 --- a/test/utils.py +++ b/test/utils.py @@ -350,11 +350,13 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, + compute_nll_loss: bool = True, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model + self.compute_nll_loss = compute_nll_loss @abstractmethod def alignment_loss(self): @@ -448,9 +450,11 @@ def cross_entropy_loss(logits, labels): return loss labels = target - chosen_nll_loss = cross_entropy_loss( - all_logits[:len_chosen], labels[:len_chosen] - ) + chosen_nll_loss = torch.tensor(0.0, device=all_logits.device) + if self.compute_nll_loss: + chosen_nll_loss = cross_entropy_loss( + all_logits[:len_chosen], labels[:len_chosen] + ) all_logps = self.get_batch_logps( all_logits, From 79e2b02a4a4ffafe111c3f8ede7df4fb56db890e Mon Sep 17 00:00:00 2001 From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:08:18 -0800 Subject: [PATCH 6/7] CPO & SimPO add label_smoothing (#493) ## Summary Add label_smoothing support for CPO and SimPO so that they align with the huggingface [interface](https://github.com/huggingface/trl/blob/b668048fe1931c57796ad5ae3f10852337ce7565/trl/trainer/cpo_trainer.py#L645C1-L658C14). ## Testing Done - [x] Something wrong with the unit test. I'll have to fix it - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Mecoli1219 --- src/liger_kernel/chunked_loss/cpo_loss.py | 18 ++++++++++++--- src/liger_kernel/chunked_loss/simpo_loss.py | 21 ++++++++++++++--- test/chunked_loss/test_cpo_loss.py | 25 +++++++++++++++++++-- test/chunked_loss/test_simpo_loss.py | 24 ++++++++++++++++++-- 4 files changed, 78 insertions(+), 10 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 2b8052e25..987f0cdcf 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -9,7 +9,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase): @staticmethod - def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): + def preference_loss_fn( + chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0 + ): """ Paper: https://arxiv.org/pdf/2401.08417 @@ -30,9 +32,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). full_target (torch.Tensor): Non chunked full target tensor beta (float): Weight for the CPO loss + label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0. """ logits = beta * (chosen_logps - rejected_logps) - loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) + loss = ( + F.logsigmoid(logits) * (1 - label_smoothing) + + F.logsigmoid(-logits) * label_smoothing + ).sum() / (full_target.shape[0] // 2) + return loss @staticmethod @@ -45,6 +52,7 @@ def forward( ignore_index=-100, beta=0.1, alpha=1.0, + label_smoothing=0.0, compute_nll_loss=True, compiled=True, ): @@ -58,6 +66,7 @@ def forward( ignore_index=ignore_index, alpha=alpha, beta=beta, + label_smoothing=label_smoothing, compute_nll_loss=compute_nll_loss, compiled=compiled, ) @@ -65,7 +74,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None + return *grads, None, None, None, None, None, None class LigerFusedLinearCPOLoss(torch.nn.Module): @@ -78,6 +87,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, compute_nll_loss: bool = True, compiled: bool = True, ): @@ -90,6 +100,7 @@ def __init__( self.ignore_index = ignore_index self.beta = beta self.alpha = alpha + self.label_smoothing = label_smoothing self.compute_nll_loss = compute_nll_loss self.compiled = compiled @@ -102,6 +113,7 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.beta, self.alpha, + self.label_smoothing, self.compute_nll_loss, self.compiled, ) diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 7efa0603d..2dc9f1a6b 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -10,7 +10,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase): @staticmethod def preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5 + chosen_logps, + rejected_logps, + full_target, + beta=0.1, + gamma=0.5, + label_smoothing=0.0, ): """ Paper: https://arxiv.org/pdf/2405.14734 @@ -33,9 +38,14 @@ def preference_loss_fn( full_target: Non chunked full target tensor beta (float): beta weight gamma (float): gemma margin term + label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0. """ logits = beta * (chosen_logps - rejected_logps) - gamma - loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) + loss = ( + F.logsigmoid(logits) * (1 - label_smoothing) + + F.logsigmoid(-logits) * label_smoothing + ).sum() / (full_target.shape[0] // 2) + return loss @staticmethod @@ -48,6 +58,7 @@ def forward( ignore_index=-100, beta=0.1, alpha=1.0, + label_smoothing=0.0, compute_nll_loss=False, compiled=True, gamma=0.5, @@ -63,6 +74,7 @@ def forward( ignore_index=ignore_index, alpha=alpha, beta=beta, + label_smoothing=label_smoothing, compiled=compiled, gamma=gamma, ) @@ -70,7 +82,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None class LigerFusedLinearSimPOLoss(torch.nn.Module): @@ -83,6 +95,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, compute_nll_loss: bool = True, compiled: bool = True, gamma: float = 0.5, @@ -96,6 +109,7 @@ def __init__( self.ignore_index = ignore_index self.beta = beta self.alpha = alpha + self.label_smoothing = label_smoothing self.compute_nll_loss = compute_nll_loss self.compiled = compiled self.gamma = gamma @@ -109,6 +123,7 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.beta, self.alpha, + self.label_smoothing, self.compute_nll_loss, self.compiled, self.gamma, diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index f0fef7734..a0c4050e5 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -86,6 +86,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, loss_type: str = "sigmoid", simpo_gamma: float = 0.5, ): @@ -97,6 +98,7 @@ def __init__( ignore_index=ignore_index, beta=beta, loss_type=loss_type, + label_smoothing=label_smoothing, simpo_gamma=simpo_gamma, ).get_batch_loss_metrics @@ -114,13 +116,17 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, ): super().__init__() self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.cpo_loss = LigerFusedLinearCPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + label_smoothing=label_smoothing, ) def forward(self, x, y): @@ -145,8 +151,21 @@ def forward(self, x, y): @pytest.mark.parametrize( "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)] ) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + alpha, + label_smoothing, ): B = 2 * B # cpo loss requires B to be even @@ -157,6 +176,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, ) liger_lm_head_cpo = LigerLMHeadCPO( H=H, @@ -165,6 +185,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, ) torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py index 3d0937c27..eede598fe 100644 --- a/test/chunked_loss/test_simpo_loss.py +++ b/test/chunked_loss/test_simpo_loss.py @@ -25,6 +25,7 @@ def __init__( ignore_index: int = -100, beta: float = 0.1, alpha: float = 1.0, + label_smoothing: float = 0.0, gamma: float = 0.5, ): super().__init__() @@ -32,7 +33,11 @@ def __init__( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.simpo_loss = LigerFusedLinearSimPOLoss( - ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + gamma=gamma, + label_smoothing=label_smoothing, ) def forward(self, x, y): @@ -57,8 +62,21 @@ def forward(self, x, y): @pytest.mark.parametrize( "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)] ) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) def test_correctness( - B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + gamma, + label_smoothing, ): B = 2 * B # SimPO loss requires B to be even @@ -70,6 +88,7 @@ def test_correctness( ignore_index=ignore_index, beta=beta, loss_type="simpo", + label_smoothing=label_smoothing, simpo_gamma=gamma, ) liger_lm_head_simpo = LigerLMHeadSimPO( @@ -79,6 +98,7 @@ def test_correctness( bias=bias, ignore_index=ignore_index, beta=beta, + label_smoothing=label_smoothing, gamma=gamma, ) From 15a2f58f06b1972d9d23ad898608398df7a421b0 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Sat, 21 Dec 2024 07:17:38 +0800 Subject: [PATCH 7/7] Fix Preference Loss and Refactor for Readability (#484) ## Summary Thanks to @winglian and @shivam15s noticed and fixed this https://github.com/linkedin/Liger-Kernel/pull/481. This PR suggests negating the preference loss terms to align with the formulas in the docstrings, while maintaining the base preference structure as `nll_loss + preference_loss`. This would make our loss computations more consistent since both terms would represent losses to be minimized. [UPDATE: It seems like being addressed now in [here](https://github.com/linkedin/Liger-Kernel/commit/3205342a6a7209c55ca3a4bd97e986961fdc792e#diff-3048cb37b97e27515852c200994f3257b8ae33a465421d05184713377c0895b1R150)] This PR also tightened the tolerance in case of encountering a similar issue. ## Testing Done - Hardware Type: - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu Co-authored-by: Wing Lian Co-authored-by: Shivam Sahni --- src/liger_kernel/chunked_loss/cpo_loss.py | 4 ++-- src/liger_kernel/chunked_loss/fused_linear_preference.py | 2 +- src/liger_kernel/chunked_loss/orpo_loss.py | 2 +- src/liger_kernel/chunked_loss/simpo_loss.py | 4 ++-- test/chunked_loss/test_cpo_loss.py | 8 ++++---- test/chunked_loss/test_orpo_loss.py | 2 +- test/utils.py | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 987f0cdcf..dd84a4dbf 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -36,8 +36,8 @@ def preference_loss_fn( """ logits = beta * (chosen_logps - rejected_logps) loss = ( - F.logsigmoid(logits) * (1 - label_smoothing) - + F.logsigmoid(-logits) * label_smoothing + - F.logsigmoid(logits) * (1 - label_smoothing) + - F.logsigmoid(-logits) * label_smoothing ).sum() / (full_target.shape[0] // 2) return loss diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index fff0791ec..4eb939a79 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -408,7 +408,7 @@ def _compute_loss( else: preference_loss, aux_outputs = preference_loss_outputs, [] - loss = alpha * chosen_nll_loss - preference_loss + loss = alpha * chosen_nll_loss + preference_loss return_vars = ( chosen_logps, rejected_logps, diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index c860d4bd9..d615212c5 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -36,7 +36,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): - torch.log1p(-torch.exp(rejected_logps)) ) ratio = F.logsigmoid(log_odds) - loss = beta * ratio.sum() / (full_target.shape[0] // 2) + loss = -beta * ratio.sum() / (full_target.shape[0] // 2) chosen_rewards = beta * chosen_logps rejected_rewards = beta * rejected_logps diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 2dc9f1a6b..5d5867252 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -42,8 +42,8 @@ def preference_loss_fn( """ logits = beta * (chosen_logps - rejected_logps) - gamma loss = ( - F.logsigmoid(logits) * (1 - label_smoothing) - + F.logsigmoid(-logits) * label_smoothing + - F.logsigmoid(logits) * (1 - label_smoothing) + - F.logsigmoid(-logits) * label_smoothing ).sum() / (full_target.shape[0] // 2) return loss diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index a0c4050e5..4090db795 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -60,14 +60,14 @@ def alignment_loss( if self.loss_type == "sigmoid": # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. losses = ( - F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) elif self.loss_type == "simpo": logits = logits - (self.simpo_gamma / self.beta) losses = ( - F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) else: raise ValueError( diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 9f5d81b18..112d4f05c 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -57,7 +57,7 @@ def alignment_loss( - torch.log1p(-torch.exp(policy_rejected_logps)) ) ratio = F.logsigmoid(log_odds) - losses = self.beta * ratio + losses = -self.beta * ratio chosen_rewards = self.beta * policy_chosen_logps rejected_rewards = self.beta * policy_rejected_logps diff --git a/test/utils.py b/test/utils.py index 48fcf3601..3d08c4ae3 100644 --- a/test/utils.py +++ b/test/utils.py @@ -515,7 +515,7 @@ def get_batch_loss_metrics( else: losses, aggregated_aux_outputs = alignment_loss_outputs, [] # full loss - loss = policy_nll_loss * self.alpha - losses.mean() + loss = policy_nll_loss * self.alpha + losses.mean() return_vars = ( policy_chosen_logps, policy_rejected_logps,