diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 5ec62afd49..068f3ebfe0 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -53,7 +53,7 @@ def _generate_completions( batch_size: int = 1, ) -> List[str]: """ - Generates completions for a list of pre-formatted prompts. + Generates completions for a list of pre-formatted prompts from the given model. Args: prompts (List[str]): A list of input prompts for which completions are to be generated. @@ -68,7 +68,6 @@ def _generate_completions( """ completions = [] with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - unwrapped_model.eval() for idx in range(0, len(prompts), batch_size): batch = prompts[idx : idx + batch_size] tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) @@ -81,7 +80,6 @@ def _generate_completions( generation = generation[len(prompt) :] completion = tokenizer.decode(generation, skip_special_tokens=True) completions.append(completion) - unwrapped_model.train() return completions