From c88dd4ea091058f5d29503c1c8ad732fb2362d82 Mon Sep 17 00:00:00 2001 From: Florian Maas Date: Thu, 4 Jul 2024 13:27:31 +0200 Subject: [PATCH] small fix --- sentence_transformers/SentenceTransformer.py | 23 +++++++++++--------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 94dbb1f5f..54538aac4 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -537,6 +537,7 @@ def encode( device = self.device self.to(device) + all_embeddings = [] length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) sentences_sorted = [sentences[idx] for idx in length_sorted_idx] @@ -562,15 +563,16 @@ def encode( out_features["sentence_embedding"], self.truncate_dim ) - all_embeddings: list = [] - if output_value == "token_embeddings": - all_embeddings.extend(self._process_token_embeddings(out_features)) - elif output_value == "sentence_embeddings": - all_embeddings.extend( - self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy) - ) - elif not output_value: - all_embeddings.extend(self._process_all_outputs(out_features)) + if output_value == "token_embeddings": + all_embeddings.extend(self._process_token_embeddings(out_features)) + elif output_value == "sentence_embedding": + all_embeddings.extend( + self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy) + ) + elif not output_value: + all_embeddings.extend(self._process_all_outputs(out_features)) + else: + raise ValueError(f"Got unexpected value for 'output_value' : {output_value}") all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] @@ -644,7 +646,8 @@ def _process_all_outputs(out_features): embeddings.append(row) return embeddings - def _process_sentence_embeddings(self, out_features, normalize_embeddings, convert_to_numpy): + @staticmethod + def _process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy): embeddings = out_features["sentence_embedding"] embeddings = embeddings.detach() if normalize_embeddings: