diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 16ab05ab26..75b87c7708 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+import os
 import tempfile
 import unittest
 
@@ -20,7 +22,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
 from transformers.testing_utils import require_wandb
 
-from trl import BasePairwiseJudge, WinRateCallback
+from trl import BasePairwiseJudge, LogCompletionsCallback, WinRateCallback
 
 
 class HalfPairwiseJudge(BasePairwiseJudge):
@@ -35,14 +37,17 @@ def judge(self, prompts, completions, shuffle_order=True):
 class TrainerWithRefModel(Trainer):
     # This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
     # ref_model attribute
-    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, tokenizer):
+    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class):
         super().__init__(
-            model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer
+            model=model,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
         )
         self.ref_model = ref_model
 
 
-@require_wandb
 class WinRateCallbackTester(unittest.TestCase):
     def setUp(self):
         self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
@@ -52,6 +57,7 @@ def setUp(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
         dataset["train"] = dataset["train"].select(range(8))
         self.expected_winrates = [
+            {"eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
             {"eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
             {"eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
             {"eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
@@ -86,7 +92,7 @@ def test_basic(self):
             args=training_args,
             train_dataset=self.dataset["train"],
             eval_dataset=self.dataset["test"],
-            tokenizer=self.tokenizer,
+            processing_class=self.tokenizer,
         )
         win_rate_callback = WinRateCallback(
             judge=self.judge, trainer=trainer, generation_config=self.generation_config
@@ -112,7 +118,7 @@ def test_without_ref_model(self):
             args=training_args,
             train_dataset=self.dataset["train"],
             eval_dataset=self.dataset["test"],
-            tokenizer=self.tokenizer,
+            processing_class=self.tokenizer,
         )
         win_rate_callback = WinRateCallback(
             judge=self.judge, trainer=trainer, generation_config=self.generation_config
@@ -145,7 +151,7 @@ def test_lora(self):
             args=training_args,
             train_dataset=self.dataset["train"],
             eval_dataset=self.dataset["test"],
-            tokenizer=self.tokenizer,
+            processing_class=self.tokenizer,
         )
         win_rate_callback = WinRateCallback(
             judge=self.judge, trainer=trainer, generation_config=self.generation_config
@@ -154,3 +160,59 @@ def test_lora(self):
         trainer.train()
         winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
         self.assertListEqual(winrate_history, self.expected_winrates)
+
+
+@require_wandb
+class LogCompletionsCallbackTester(unittest.TestCase):
+    def setUp(self):
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
+        dataset["train"] = dataset["train"].select(range(8))
+
+        def tokenize_function(examples):
+            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
+            out["labels"] = out["input_ids"].copy()
+            return out
+
+        self.dataset = dataset.map(tokenize_function, batched=True)
+
+        self.generation_config = GenerationConfig(max_length=32)
+
+    def test_basic(self):
+        import wandb
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                eval_strategy="steps",
+                eval_steps=2,  # evaluate every 2 steps
+                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
+                per_device_eval_batch_size=2,
+                report_to="wandb",
+            )
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                eval_dataset=self.dataset["test"],
+                processing_class=self.tokenizer,
+            )
+            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
+            trainer.add_callback(completions_callback)
+            trainer.train()
+
+            # Get the current run
+            completions_path = wandb.run.summary.completions["path"]
+            json_path = os.path.join(wandb.run.dir, completions_path)
+            with open(json_path) as f:
+                completions = json.load(f)
+
+            # Check that the columns are correct
+            self.assertIn("step", completions["columns"])
+            self.assertIn("prompt", completions["columns"])
+            self.assertIn("completion", completions["columns"])
+
+            # Check that the prompt is in the log
+            self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])
diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index a4616ba30e..a920f41cc4 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -255,7 +255,7 @@ def __init__(
 
     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         # When the trainer is initialized, we generate completions for the reference model.
-        tokenizer = kwargs["tokenizer"]
+        tokenizer = kwargs["processing_class"]
         tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator
         # Use the reference model if available, otherwise use the initial model
@@ -307,7 +307,7 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
         # At every evaluation step, we generate completions for the model and compare them with the reference
         # completions that have been generated at the beginning of training. We then compute the win rate and log it to
         # the trainer.
-        tokenizer = kwargs["tokenizer"]
+        tokenizer = kwargs["processing_class"]
         tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator
         model = self.trainer.model_wrapped
@@ -401,7 +401,7 @@ def on_step_end(self, args, state, control, **kwargs):
         if state.global_step % freq != 0:
             return
 
-        tokenizer = kwargs["tokenizer"]
+        tokenizer = kwargs["processing_class"]
         tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator
         model = self.trainer.model_wrapped