diff --git a/sentence_transformers/datasets/NoDuplicatesDataLoader.py b/sentence_transformers/datasets/NoDuplicatesDataLoader.py index d53dbb703..3d2f53cd0 100644 --- a/sentence_transformers/datasets/NoDuplicatesDataLoader.py +++ b/sentence_transformers/datasets/NoDuplicatesDataLoader.py @@ -24,14 +24,14 @@ def __iter__(self): valid_example = True for text in example.texts: - if text.strip().lower() in texts_in_batch: + if self._text_to_str(text).strip().lower() in texts_in_batch: valid_example = False break if valid_example: batch.append(example) for text in example.texts: - texts_in_batch.add(text.strip().lower()) + texts_in_batch.add(self._text_to_str(text).strip().lower()) self.data_pointer += 1 if self.data_pointer >= len(self.train_examples): @@ -41,4 +41,18 @@ def __iter__(self): yield self.collate_fn(batch) if self.collate_fn is not None else batch def __len__(self): - return math.floor(len(self.train_examples) / self.batch_size) \ No newline at end of file + return math.floor(len(self.train_examples) / self.batch_size) + + + def _text_to_str(self, text) -> str: + """ + In symmetric models each text is a plain string; in asymmetric models it is a dictionary. + Return the text for both cases: a dictionary yields its first value, anything else is returned unchanged. + """ + + # An asymmetric text is assumed to hold exactly one key/value pair, e.g. `{"query": "Some query"}`, so only the first value is used. + if isinstance(text, dict): + return list(text.values())[0] + else: + return text +