Skip to content

Commit

Permalink
improve
Browse files Browse the repository at this point in the history
  • Loading branch information
fpgmaas committed Jul 4, 2024
1 parent 9deb872 commit 90d2daf
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,15 @@ def encode(
out_features["sentence_embedding"], self.truncate_dim
)

all_embeddings = self._process_embeddings(out_features, output_value)
all_embeddings: list = []
if output_value == "token_embeddings":
all_embeddings.extend(self._process_token_embeddings(out_features))
elif output_value == "sentence_embeddings":
all_embeddings.extend(
self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
)
elif not output_value:
all_embeddings.extend(self._process_all_outputs(out_features))

all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

Expand Down Expand Up @@ -636,23 +644,15 @@ def _process_all_outputs(out_features):
embeddings.append(row)
return embeddings

def _process_sentence_embeddings(self, out_features):
def _process_sentence_embeddings(self, out_features, normalize_embeddings, convert_to_numpy):
embeddings = out_features["sentence_embedding"]
embeddings = embeddings.detach()
if self.normalize_embeddings:
if normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
if self.convert_to_numpy:
if convert_to_numpy:
embeddings = embeddings.cpu()
return embeddings

def _process_embeddings(self, out_features, output_value):
if output_value == "token_embeddings":
return self._process_token_embeddings(out_features)
elif output_value is None:
return self._process_all_outputs(out_features)
else:
return self._process_sentence_embeddings(out_features)

@property
def similarity_fn_name(self) -> Optional[str]:
"""Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
Expand Down

0 comments on commit 90d2daf

Please sign in to comment.