From bf9eb803ce2dda26a8ef903c33d80cd1fcb55a3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arthur=20C=C3=A2mara?= Date: Thu, 26 Sep 2024 15:06:12 +0000 Subject: [PATCH] add prompt to test dataset --- sentence_transformers/trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 3bc09e65c..329217262 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -781,6 +781,9 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader dataloader_params["batch_sampler"] = batch_sampler elif isinstance(eval_dataset, Dataset): + if self.prompt is not None: + eval_dataset = self.add_prompts_to_dataset(eval_dataset) + batch_sampler = self.get_batch_sampler( eval_dataset, batch_size=self.args.eval_batch_size, @@ -789,8 +792,6 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader generator=generator, ) dataloader_params["batch_sampler"] = batch_sampler - if self.prompt is not None: - eval_dataset = self.add_prompts_to_dataset(eval_dataset, self.prompt) else: raise ValueError( @@ -848,6 +849,8 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: raise ValueError( "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset." ) + if self.prompt is not None: + test_dataset[dataset_name] = self.add_prompts_to_dataset(dataset, dataset_name) if isinstance(self.loss, dict): test_dataset = self.add_dataset_name_column(test_dataset) batch_samplers = [ @@ -873,6 +876,8 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: elif isinstance(test_dataset, Dataset): self.validate_column_names(test_dataset) + if self.prompt is not None: + test_dataset = self.add_prompts_to_dataset(test_dataset) batch_sampler = self.get_batch_sampler( test_dataset, batch_size=self.args.eval_batch_size,