From 10dff7a32ee68d2f8a206f21a382c299a5212227 Mon Sep 17 00:00:00 2001
From: Florian Maas
Date: Tue, 9 Jul 2024 14:45:36 +0200
Subject: [PATCH] add typehints

---
 sentence_transformers/SentenceTransformer.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 412dd253f..cda70ce8a 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -604,7 +604,10 @@ def encode(
         return all_embeddings
 
     @staticmethod
-    def _pad_features(features):
+    def _pad_features(features: Dict[str, torch.Tensor]) -> None:
+        """
+        Pads the input features to the next power of 2 for compatibility with certain hardware accelerators.
+        """
         curr_tokenize_len = features["input_ids"].shape
         additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
         features["input_ids"] = torch.cat(
@@ -631,7 +634,7 @@ def _pad_features(features):
         )
 
     @staticmethod
-    def _process_token_embeddings(out_features):
+    def _process_token_embeddings(out_features: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
         embeddings = []
         for token_emb, attention in zip(out_features["token_embeddings"], out_features["attention_mask"]):
             last_mask_id = len(attention) - 1
@@ -641,7 +644,7 @@ def _process_token_embeddings(out_features):
         return embeddings
 
     @staticmethod
-    def _process_all_outputs(out_features):
+    def _process_all_outputs(out_features: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
         embeddings = []
         for sent_idx in range(len(out_features["sentence_embedding"])):
             row = {name: out_features[name][sent_idx] for name in out_features}
@@ -649,7 +652,9 @@ def _process_all_outputs(out_features):
         return embeddings
 
     @staticmethod
-    def _process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy):
+    def _process_sentence_embeddings(
+        out_features: Dict[str, torch.Tensor], normalize_embeddings: bool, convert_to_numpy: bool
+    ) -> Union[List[np.ndarray], List[torch.Tensor]]:
         embeddings = out_features["sentence_embedding"]
         embeddings = embeddings.detach()
         if normalize_embeddings:
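
Note (illustrative sketch, not part of the patch): the docstring added to
_pad_features describes rounding the sequence length up to the next power
of 2, which helps accelerators that compile a kernel per input shape avoid
recompiling for every distinct length. A minimal standalone version of that
computation, assuming zero-padding on the right and made-up example shapes
(the patch truncates the actual torch.cat arguments), could look like:

    import math

    import torch

    # Hypothetical batch: 2 sequences, each tokenized to length 6.
    features = {
        "input_ids": torch.ones((2, 6), dtype=torch.long),
        "attention_mask": torch.ones((2, 6), dtype=torch.long),
    }

    # Round the sequence length up to the next power of 2:
    # ceil(log2(6)) = 3, so the target length is 2**3 = 8
    # and we pad by 8 - 6 = 2 positions.
    seq_len = features["input_ids"].shape[1]
    additional_pad_len = 2 ** math.ceil(math.log2(seq_len)) - seq_len

    # Assumed zero-padding for both tensors; zeros also keep the
    # attention_mask marking the padded positions as ignored.
    for name in ("input_ids", "attention_mask"):
        padding = torch.zeros(
            (features[name].shape[0], additional_pad_len), dtype=torch.long
        )
        features[name] = torch.cat((features[name], padding), dim=-1)

    print(features["input_ids"].shape)  # torch.Size([2, 8])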