From e2c6379f8bd3c36b16990af4dffc06e2bf087b4b Mon Sep 17 00:00:00 2001
From: Jan Cuhel
Date: Mon, 6 May 2024 08:16:27 +0200
Subject: [PATCH] Add SigLIP support

---
 sentence_transformers/SentenceTransformer.py | 11 ++-
 sentence_transformers/models/SigLIPModel.py  | 99 ++++++++++++++++++++
 sentence_transformers/models/__init__.py     |  2 +
 3 files changed, 111 insertions(+), 1 deletion(-)
 create mode 100644 sentence_transformers/models/SigLIPModel.py

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 06c37a5f0..690129f01 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -1198,8 +1198,17 @@ def _load_auto_model(
         local_files_only: bool = False,
     ):
         """
-        Creates a simple Transformer + Mean Pooling model and returns the modules
+        Creates a simple Transformer + Mean Pooling model or a SigLIP model and returns the modules
         """
+        if "siglip" in model_name_or_path.lower():
+            from sentence_transformers.models import SigLIPModel
+            logger.warning(
+                "No sentence-transformers model found with name {}. Creating a SigLIPModel.".format(
+                    model_name_or_path
+                )
+            )
+            return [SigLIPModel(model_name_or_path)]
+
         logger.warning(
             "No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(
                 model_name_or_path
diff --git a/sentence_transformers/models/SigLIPModel.py b/sentence_transformers/models/SigLIPModel.py
new file mode 100644
index 000000000..e1414c985
--- /dev/null
+++ b/sentence_transformers/models/SigLIPModel.py
@@ -0,0 +1,99 @@
+from typing import Union
+from torch import nn
+import transformers
+import torch
+from PIL import Image
+
+
+class SigLIPModel(nn.Module):
+    """Wrap a Hugging Face SigLIP checkpoint as a sentence-transformers module.
+
+    Texts and PIL images may be mixed in one batch; ``tokenize`` splits them by
+    modality and ``forward`` returns their embeddings in the original order.
+    """
+
+    def __init__(self, model_name: str = "google/siglip-so400m-patch14-384", processor_name=None):
+        super().__init__()
+
+        # Fall back to the model checkpoint when no separate processor is given.
+        if processor_name is None:
+            processor_name = model_name
+
+        self.model = transformers.AutoModel.from_pretrained(model_name)
+        self.processor = transformers.AutoProcessor.from_pretrained(processor_name)
+
+    def __repr__(self):
+        return "SigLIPModel()"
+
+    def forward(self, features):
+        """Embed the tokenized batch and store the result in ``features["sentence_embedding"]``."""
+        image_embeds = []
+        text_embeds = []
+
+        if "pixel_values" in features:
+            image_embeds = self.model.get_image_features(features["pixel_values"])
+
+        if "input_ids" in features:
+            text_embeds = self.model.get_text_features(
+                input_ids=features.get("input_ids"),
+                attention_mask=features.get("attention_mask", None),
+                position_ids=features.get("position_ids", None),
+                output_attentions=features.get("output_attentions", None),
+                output_hidden_states=features.get("output_hidden_states", None),
+            )
+
+        # Images and texts were embedded as two separate batches; interleave
+        # them again so the output order matches the input order.
+        image_features = iter(image_embeds)
+        text_features = iter(text_embeds)
+        sentence_embedding = [
+            next(image_features) if input_type == 0 else next(text_features)
+            for input_type in features["image_text_info"]
+        ]
+
+        features["sentence_embedding"] = torch.stack(sentence_embedding).float()
+
+        return features
+
+    def tokenize(self, texts, padding: Union[str, bool] = "max_length"):
+        """Split a mixed batch into tokenized texts and preprocessed images.
+
+        ``image_text_info`` records the modality of every input (0 = image,
+        1 = text) so that ``forward`` can restore the original order.
+        """
+        images = []
+        texts_values = []
+        image_text_info = []
+
+        for data in texts:
+            if isinstance(data, Image.Image):  # An Image
+                images.append(data)
+                image_text_info.append(0)
+            else:  # A text
+                texts_values.append(data)
+                image_text_info.append(1)
+
+        encoding = {}
+        if texts_values:
+            # Truncate to the tokenizer's maximum length so that over-long
+            # texts are cut instead of being passed on at full length.
+            encoding = self.processor.tokenizer(
+                texts_values, return_tensors="pt", padding=padding, truncation=True
+            )
+
+        if images:
+            image_features = self.processor.image_processor(images, return_tensors="pt")
+            encoding["pixel_values"] = image_features.pixel_values
+
+        encoding["image_text_info"] = image_text_info
+        return encoding
+
+    def save(self, output_path: str):
+        """Save the model and the processor to ``output_path``."""
+        self.model.save_pretrained(output_path)
+        self.processor.save_pretrained(output_path)
+
+    @staticmethod
+    def load(input_path: str):
+        """Create a SigLIPModel from the checkpoint stored at ``input_path``."""
+        return SigLIPModel(model_name=input_path)
diff --git a/sentence_transformers/models/__init__.py b/sentence_transformers/models/__init__.py
index a0a518ba4..909678bd1 100644
--- a/sentence_transformers/models/__init__.py
+++ b/sentence_transformers/models/__init__.py
@@ -12,6 +12,7 @@
 from .WordEmbeddings import WordEmbeddings
 from .WordWeights import WordWeights
 from .CLIPModel import CLIPModel
+from .SigLIPModel import SigLIPModel
 
 __all__ = [
     "Transformer",
@@ -28,4 +29,5 @@
     "WordEmbeddings",
     "WordWeights",
     "CLIPModel",
+    "SigLIPModel",
 ]