From bb0afc245984efff6ce539e7c6b6c40535d9b670 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 20 Nov 2024 11:24:41 +0100 Subject: [PATCH] remove redunant call to eval and train (#2372) --- trl/trainer/callbacks.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 5ec62afd49..068f3ebfe0 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -53,7 +53,7 @@ def _generate_completions( batch_size: int = 1, ) -> List[str]: """ - Generates completions for a list of pre-formatted prompts. + Generates completions for a list of pre-formatted prompts from the given model. Args: prompts (List[str]): A list of input prompts for which completions are to be generated. @@ -68,7 +68,6 @@ def _generate_completions( """ completions = [] with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - unwrapped_model.eval() for idx in range(0, len(prompts), batch_size): batch = prompts[idx : idx + batch_size] tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) @@ -81,7 +80,6 @@ def _generate_completions( generation = generation[len(prompt) :] completion = tokenizer.decode(generation, skip_special_tokens=True) completions.append(completion) - unwrapped_model.train() return completions