diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index dfcc297c8..a6df287e1 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -363,6 +363,7 @@ def encode(
         convert_to_tensor: bool = False,
         device: str = None,
         normalize_embeddings: bool = False,
+        kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[List[Tensor], ndarray, Tensor]:
         """
         Computes sentence embeddings.
@@ -485,11 +486,17 @@
         if self.device.type == "hpu":
             if "input_ids" in features:
                 curr_tokenize_len = features["input_ids"].shape
-                additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
+                if curr_tokenize_len[1] > 4096:
+                    additional_pad_len = math.ceil(curr_tokenize_len[1] / 128) * 128 - curr_tokenize_len[1]
+
+                    extra_features.update(kwargs["hpu_kwargs"])
+                else:
+                    additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
+
                 features["input_ids"] = torch.cat(
                     (
                         features["input_ids"],
-                        torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                        torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
                     ),
                     -1,
                 )
diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index d5a670869..84080a512 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -4,6 +4,7 @@
 from torch import nn
 from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
 
+from sentence_transformers.util import get_device_name
 
 
 class Transformer(nn.Module):
@@ -114,6 +115,23 @@ def forward(self, features):
         if "token_type_ids" in features:
             trans_features["token_type_ids"] = features["token_type_ids"]
 
+        device = get_device_name()
+        curr_tokenize_len = features["input_ids"].shape
+        if (
+            device == "hpu"
+            and curr_tokenize_len[1] > 4096
+            and "attn_softmax_bf16" in features
+            and "reuse_cache" in features
+            and "use_flash_attention" in features
+            and "flash_attention_recompute" in features
+            and "flash_attention_causal_mask" in features
+        ):
+            trans_features["attn_softmax_bf16"] = features["attn_softmax_bf16"]
+            trans_features["reuse_cache"] = features["reuse_cache"]
+            trans_features["use_flash_attention"] = features["use_flash_attention"]
+            trans_features["flash_attention_recompute"] = features["flash_attention_recompute"]
+            trans_features["flash_attention_causal_mask"] = features["flash_attention_causal_mask"]
+
         output_states = self.auto_model(**trans_features, return_dict=False)
         output_tokens = output_states[0]
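
For reference, a minimal usage sketch of the new `kwargs` argument on Gaudi (HPU) hardware. The model name and the flag values below are illustrative assumptions; only the `hpu_kwargs` key and the five attention flags forwarded in `Transformer.forward` come from the diff above, and they only take effect for tokenized sequences longer than 4096 tokens.

# Hypothetical usage sketch, not part of the diff: exercising the new
# `kwargs` parameter of SentenceTransformer.encode on an HPU device.
from sentence_transformers import SentenceTransformer

# Placeholder model name; a long-context model is needed for the >4096-token path.
model = SentenceTransformer("your-long-context-model", device="hpu")

# Keys mirror the features forwarded in Transformer.forward; the values shown
# are assumptions and depend on the Gaudi / optimum-habana stack in use.
hpu_kwargs = {
    "attn_softmax_bf16": True,
    "reuse_cache": False,
    "use_flash_attention": True,
    "flash_attention_recompute": True,
    "flash_attention_causal_mask": False,
}

embeddings = model.encode(
    ["A very long input document ..."],
    kwargs={"hpu_kwargs": hpu_kwargs},  # consumed only on HPU when the padded length exceeds 4096
)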