Skip to content

Commit

Permalink
add typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
fpgmaas committed Jul 9, 2024
1 parent ff20b4f commit 10dff7a
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,10 @@ def encode(
return all_embeddings

@staticmethod
def _pad_features(features):
def _pad_features(features: Dict[str, torch.Tensor]) -> None:
"""
Pads the input features to the next power of 2 for compatibility with certain hardware accelerators.
"""
curr_tokenize_len = features["input_ids"].shape
additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
features["input_ids"] = torch.cat(
Expand All @@ -631,7 +634,7 @@ def _pad_features(features):
)

@staticmethod
def _process_token_embeddings(out_features):
def _process_token_embeddings(out_features: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
embeddings = []
for token_emb, attention in zip(out_features["token_embeddings"], out_features["attention_mask"]):
last_mask_id = len(attention) - 1
Expand All @@ -641,15 +644,17 @@ def _process_token_embeddings(out_features):
return embeddings

@staticmethod
def _process_all_outputs(out_features):
def _process_all_outputs(out_features: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {name: out_features[name][sent_idx] for name in out_features}
embeddings.append(row)
return embeddings

@staticmethod
def _process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy):
def _process_sentence_embeddings(
out_features: Dict[str, torch.Tensor], normalize_embeddings: bool, convert_to_numpy: bool
) -> Union[List[np.ndarray], List[torch.Tensor]]:
embeddings = out_features["sentence_embedding"]
embeddings = embeddings.detach()
if normalize_embeddings:
Expand Down

0 comments on commit 10dff7a

Please sign in to comment.