From 9ddf2d5b47ef66f88906d31eb875282355c4a87d Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Fri, 18 Oct 2024 10:59:01 +0200
Subject: [PATCH] If AutoModel is wrapped with PEFT for prompt learning, then extend the attention mask

---
 sentence_transformers/models/Transformer.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index 7592278bf..c1eef20c7 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -11,6 +11,7 @@
 import torch
 from torch import nn
 from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
+from transformers.utils import is_peft_available
 
 logger = logging.getLogger(__name__)
 
@@ -350,7 +351,23 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torc
         output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
         output_tokens = output_states[0]
 
-        features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
+        # If the AutoModel is wrapped with a PeftModelForFeatureExtraction, then it may have added virtual tokens
+        # We need to extend the attention mask to include these virtual tokens, or the pooling will fail
+        if is_peft_available():
+            from peft import PeftModelForFeatureExtraction
+
+            if (
+                isinstance(self.auto_model, PeftModelForFeatureExtraction)
+                and self.auto_model.active_peft_config.is_prompt_learning
+            ):
+                batch_size = output_tokens.size(0)
+                attention_mask = features["attention_mask"]
+                prefix_attention_mask = torch.ones(
+                    batch_size, self.auto_model.active_peft_config.num_virtual_tokens, device=attention_mask.device
+                )
+                features["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        features["token_embeddings"] = output_tokens
 
         if self.auto_model.config.output_hidden_states:
             all_layer_idx = 2
@@ -358,7 +375,7 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torc
                 all_layer_idx = 1
 
             hidden_states = output_states[all_layer_idx]
-            features.update({"all_layer_embeddings": hidden_states})
+            features["all_layer_embeddings"] = hidden_states
 
         return features
 
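
Note for reviewers: the snippet below is a minimal sketch of the scenario this patch addresses, not part of the change itself. The base model name, num_virtual_tokens=16, and the manual get_peft_model() wrapping are illustrative assumptions; any prompt-learning PEFT adapter around the Transformer module's auto_model takes the same code path, since the check is on active_peft_config.is_prompt_learning.

    from peft import PromptTuningConfig, TaskType, get_peft_model
    from sentence_transformers import SentenceTransformer

    # Illustrative base model; any SentenceTransformer whose first module is Transformer works.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # Wrap the underlying Hugging Face model with a prompt-learning adapter.
    # For TaskType.FEATURE_EXTRACTION, get_peft_model returns a PeftModelForFeatureExtraction.
    peft_config = PromptTuningConfig(task_type=TaskType.FEATURE_EXTRACTION, num_virtual_tokens=16)
    model[0].auto_model = get_peft_model(model[0].auto_model, peft_config)

    # The prompt encoder prepends 16 virtual tokens to every sequence, so the returned
    # token embeddings are 16 positions longer than the tokenizer's attention mask.
    # Without the prefix mask added in this patch, pooling fails on the shape mismatch.
    embeddings = model.encode(["PEFT prompt learning with Sentence Transformers"])
    print(embeddings.shape)

Extending the mask with ones mirrors what PEFT does internally for prompt learning, so the virtual tokens are weighted in pooling the same way they participate in attention.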