diff --git a/test/utils.py b/test/utils.py index e8383d659..f7ec42f0f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -483,7 +483,7 @@ def get_batch_loss_metrics( ref_bias: torch.FloatTensor = None, average_log_prob: bool = True, ): - """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" + """Compute the loss metrics for the given batch of inputs for train or test.""" forward_output = self.concatenated_forward( _input, weight, target, bias, average_log_prob