Skip to content

Commit

Permalink
Merge branch 'main' into log-completion
Browse files Browse the repository at this point in the history
  • Loading branch information
lewtun authored Feb 7, 2025
2 parents 0c271b2 + 84d73fd commit 818f864
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 3 deletions.
55 changes: 55 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from trl.trainer.utils import (
DataCollatorForChatML,
batch_generation,
compute_token_accuracy,
decode_and_strip_padding,
flush_left,
generate_model_card,
Expand Down Expand Up @@ -451,3 +452,57 @@ def test_no_tensors(self):
expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

self.assertTrue(torch.equal(new_mask, expected_mask))


class TestComputeTokenAccuracy(unittest.TestCase):
def test_basic_accuracy(self):
# Test basic accuracy computation
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) # Shape: [2, 2, 2]
labels = torch.tensor([[1, 0], [1, 0]]) # Shape: [2, 2]
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.75) # 3 correct out of 4 tokens

def test_with_ignore_index(self):
# Test accuracy computation with ignored tokens
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])
labels = torch.tensor([[1, -100], [1, 0]]) # -100 is ignored
accuracy = compute_token_accuracy(logits, labels, ignore_index=-100)
self.assertAlmostEqual(accuracy, 2 / 3) # 2 correct out of 3 non-ignored tokens

def test_all_ignored(self):
# Test case where all tokens are ignored
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[-100, -100]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # No valid tokens to compute accuracy

def test_perfect_accuracy(self):
# Test case with 100% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[1, 0]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 1.0) # All predictions correct

def test_zero_accuracy(self):
# Test case with 0% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[0, 1]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # All predictions wrong

def test_batch_accuracy(self):
# Test accuracy computation across multiple batches
logits = torch.tensor(
[
[[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], # Batch 1
[[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], # Batch 2
]
)
labels = torch.tensor(
[
[1, 0, 1], # Batch 1
[1, 0, -100], # Batch 2 (last token ignored)
]
)
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.8)
4 changes: 2 additions & 2 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"XPOTrainer",
],
"trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"],
}

try:
Expand Down Expand Up @@ -200,7 +200,7 @@
XPOTrainer,
)
from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config

try:
if not is_diffusers_available():
Expand Down
3 changes: 2 additions & 1 deletion trl/scripts/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def main(script_args, training_args, model_args):
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"disable_dropout_in_model",
"empty_cache",
"peft_module_casting_to_bf16",
"compute_token_accuracy",
],
"xpo_config": ["XPOConfig"],
"xpo_trainer": ["XPOTrainer"],
Expand Down Expand Up @@ -144,6 +145,7 @@
DataCollatorForCompletionOnlyLM,
RunningMoments,
compute_accuracy,
compute_token_accuracy,
disable_dropout_in_model,
empty_cache,
peft_module_casting_to_bf16,
Expand Down
46 changes: 46 additions & 0 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
import inspect
import os
import warnings
from collections import defaultdict
from typing import Callable, Optional, Union

import datasets
import torch
import torch.nn as nn
import transformers
from accelerate.state import PartialState
from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
from datasets.builder import DatasetGenerationError
from packaging import version
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Expand All @@ -48,6 +51,7 @@
from .utils import (
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
compute_token_accuracy,
generate_model_card,
get_comet_experiment_url,
peft_module_casting_to_bf16,
Expand Down Expand Up @@ -304,6 +308,9 @@ def make_inputs_require_grad(module, input, output):
UserWarning,
)

# Initialize the metrics
self._metrics = defaultdict(list)

super().__init__(
model=model,
args=args,
Expand Down Expand Up @@ -546,3 +553,42 @@ def create_model_card(
)

model_card.save(os.path.join(self.args.output_dir, "README.md"))

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
Compute training loss and additionally compute token accuracies
"""
(loss, outputs) = super().compute_loss(
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
)

# Compute token accuracy if we have labels
if "labels" in inputs:
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = inputs["labels"][..., 1:].contiguous()

# Gather logits and labels from all GPUs first
shift_logits = self.accelerator.gather_for_metrics(shift_logits)
shift_labels = self.accelerator.gather_for_metrics(shift_labels)

# Then compute accuracy on the gathered tensors
if self.accelerator.is_main_process:
accuracy = compute_token_accuracy(shift_logits, shift_labels)
self._metrics["mean_token_accuracy"].append(accuracy)

return (loss, outputs) if return_outputs else loss

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics

# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if next(iter(logs.keys())).startswith("eval_"):
metrics = {f"eval_{key}": val for key, val in metrics.items()}

logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
self._metrics.clear()
21 changes: 21 additions & 0 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,3 +1647,24 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
return mask
else:
return mask, *tensors


def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
"""
Compute the mean token accuracy.
"""
# Get predictions
predictions = logits.argmax(dim=-1)

# Create mask for non-padding tokens (assuming pad_token_id is ignore_index)
mask = labels != ignore_index

# Calculate accuracy only on non-padding tokens
correct_predictions = (predictions == labels) & mask
total_tokens = mask.sum()
correct_tokens = correct_predictions.sum()

# Calculate accuracy
accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0

return accuracy

0 comments on commit 818f864

Please sign in to comment.