diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index eea804139..cda70ce8a 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -537,46 +537,25 @@ def encode(
             device = self.device

         self.to(device)
-
         all_embeddings = []
+
         length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

         for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
             sentences_batch = sentences_sorted[start_index : start_index + batch_size]
             features = self.tokenize(sentences_batch)
+
             if self.device.type == "hpu":
                 if "input_ids" in features:
-                    curr_tokenize_len = features["input_ids"].shape
-                    additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
-                    features["input_ids"] = torch.cat(
-                        (
-                            features["input_ids"],
-                            torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
-                        ),
-                        -1,
-                    )
-                    features["attention_mask"] = torch.cat(
-                        (
-                            features["attention_mask"],
-                            torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
-                        ),
-                        -1,
-                    )
-                    if "token_type_ids" in features:
-                        features["token_type_ids"] = torch.cat(
-                            (
-                                features["token_type_ids"],
-                                torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
-                            ),
-                            -1,
-                        )
+                    self._pad_features(features)

             features = batch_to_device(features, device)
             features.update(extra_features)

             with torch.no_grad():
                 out_features = self.forward(features)
+
                 if self.device.type == "hpu":
                     out_features = copy.deepcopy(out_features)
@@ -585,29 +564,17 @@ def encode(
                 )

                 if output_value == "token_embeddings":
-                    embeddings = []
-                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
-                        last_mask_id = len(attention) - 1
-                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-                            last_mask_id -= 1
-
-                        embeddings.append(token_emb[0 : last_mask_id + 1])
-                elif output_value is None:  # Return all outputs
-                    embeddings = []
-                    for sent_idx in range(len(out_features["sentence_embedding"])):
-                        row = {name: out_features[name][sent_idx] for name in out_features}
-                        embeddings.append(row)
-                else:  # Sentence embeddings
-                    embeddings = out_features[output_value]
-                    embeddings = embeddings.detach()
-                    if normalize_embeddings:
-                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-                    # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
-                    if convert_to_numpy:
-                        embeddings = embeddings.cpu()
-
-                all_embeddings.extend(embeddings)
+                    all_embeddings.extend(self._process_token_embeddings(out_features))
+                elif output_value == "sentence_embedding":
+                    all_embeddings.extend(
+                        self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
+                    )
+                elif not output_value:
+                    all_embeddings.extend(self._process_all_outputs(out_features))
+                else:
+                    raise ValueError(
+                        f"Got unexpected value for 'output_value': {output_value}. Valid values are 'token_embeddings', 'sentence_embedding' or None."
+                    )

         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

@@ -636,6 +603,66 @@

         return all_embeddings

+    @staticmethod
+    def _pad_features(features: Dict[str, torch.Tensor]) -> None:
+        """
+        Pads the input features along the sequence dimension to the next power of two, for compatibility with certain hardware accelerators (e.g. HPU).
+        """
+        curr_tokenize_len = features["input_ids"].shape
+        additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
+        features["input_ids"] = torch.cat(
+            (
+                features["input_ids"],
+                torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+            ),
+            -1,
+        )
+        features["attention_mask"] = torch.cat(
+            (
+                features["attention_mask"],
+                torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+            ),
+            -1,
+        )
+        if "token_type_ids" in features:
+            features["token_type_ids"] = torch.cat(
+                (
+                    features["token_type_ids"],
+                    torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                ),
+                -1,
+            )
+
+    @staticmethod
+    def _process_token_embeddings(out_features: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
+        embeddings = []
+        for token_emb, attention in zip(out_features["token_embeddings"], out_features["attention_mask"]):
+            last_mask_id = len(attention) - 1
+            while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                last_mask_id -= 1
+            embeddings.append(token_emb[0 : last_mask_id + 1])
+        return embeddings
+
+    @staticmethod
+    def _process_all_outputs(out_features: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
+        embeddings = []
+        for sent_idx in range(len(out_features["sentence_embedding"])):
+            row = {name: out_features[name][sent_idx] for name in out_features}
+            embeddings.append(row)
+        return embeddings
+
+    @staticmethod
+    def _process_sentence_embeddings(
+        out_features: Dict[str, torch.Tensor], normalize_embeddings: bool, convert_to_numpy: bool
+    ) -> Union[List[np.ndarray], List[torch.Tensor]]:
+        embeddings = out_features["sentence_embedding"]
+        embeddings = embeddings.detach()
+        if normalize_embeddings:
+            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        if convert_to_numpy:
+            embeddings = embeddings.cpu()
+        return embeddings
+
     @property
     def similarity_fn_name(self) -> Optional[str]:
         """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.