Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fpgmaas committed Jul 4, 2024
1 parent 90d2daf commit c88dd4e
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ def encode(
device = self.device

self.to(device)
all_embeddings = []

length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
Expand All @@ -562,15 +563,16 @@ def encode(
out_features["sentence_embedding"], self.truncate_dim
)

all_embeddings: list = []
if output_value == "token_embeddings":
all_embeddings.extend(self._process_token_embeddings(out_features))
elif output_value == "sentence_embeddings":
all_embeddings.extend(
self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
)
elif not output_value:
all_embeddings.extend(self._process_all_outputs(out_features))
if output_value == "token_embeddings":
all_embeddings.extend(self._process_token_embeddings(out_features))
elif output_value == "sentence_embedding":
all_embeddings.extend(
self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
)
elif not output_value:
all_embeddings.extend(self._process_all_outputs(out_features))
else:
raise ValueError(f"Got unexpected value for 'output_value' : {output_value}")

all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

Expand Down Expand Up @@ -644,7 +646,8 @@ def _process_all_outputs(out_features):
embeddings.append(row)
return embeddings

def _process_sentence_embeddings(self, out_features, normalize_embeddings, convert_to_numpy):
@staticmethod
def _process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy):
embeddings = out_features["sentence_embedding"]
embeddings = embeddings.detach()
if normalize_embeddings:
Expand Down

0 comments on commit c88dd4e

Please sign in to comment.