Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SigLIP support #2629

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,8 +1198,17 @@ def _load_auto_model(
local_files_only: bool = False,
):
"""
Creates a simple Transformer + Mean Pooling model and returns the modules
Creates a simple Transformer + Mean Pooling model or a SigLIP model and returns the modules
"""
if "siglip" in model_name_or_path:
from sentence_transformers.models import SigLIPModel
logger.warning(
"No sentence-transformers model found with name {}. Creating a SigLIPModel.".format(
model_name_or_path
)
)
return [SigLIPModel(model_name_or_path)]

logger.warning(
"No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(
model_name_or_path
Expand Down
81 changes: 81 additions & 0 deletions sentence_transformers/models/SigLIPModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Union
from torch import nn
import transformers
import torch
from PIL import Image


class SigLIPModel(nn.Module):
    """Wraps a Hugging Face SigLIP checkpoint so that images and texts can be
    encoded into a shared embedding space.

    Mirrors the existing CLIPModel module: ``tokenize`` accepts a mixed list of
    PIL images and strings, and ``forward`` re-interleaves the image and text
    embeddings back into the original input order under
    ``features["sentence_embedding"]``.
    """

    def __init__(self, model_name: str = "google/siglip-so400m-patch14-384", processor_name=None):
        """Load the SigLIP model and its processor.

        :param model_name: Hugging Face model id or local path of the SigLIP checkpoint.
        :param processor_name: Optional separate processor checkpoint; defaults to ``model_name``.
        """
        super().__init__()

        # Fall back to the model checkpoint when no separate processor is given.
        if processor_name is None:
            processor_name = model_name

        self.model = transformers.AutoModel.from_pretrained(model_name)
        self.processor = transformers.AutoProcessor.from_pretrained(processor_name)

    def __repr__(self):
        return "SigLIPModel()"

    def forward(self, features):
        """Compute embeddings for the batch produced by :meth:`tokenize`.

        Expects optional ``pixel_values`` (images) and/or ``input_ids`` (texts)
        plus ``image_text_info`` marking each original input as image (0) or
        text (1). Adds ``sentence_embedding`` to ``features`` and returns it.
        """
        image_embeds = []
        text_embeds = []

        if "pixel_values" in features:
            # Keyword argument guards against positional-signature changes in transformers.
            image_embeds = self.model.get_image_features(pixel_values=features["pixel_values"])

        if "input_ids" in features:
            text_embeds = self.model.get_text_features(
                input_ids=features.get("input_ids"),
                attention_mask=features.get("attention_mask", None),
                position_ids=features.get("position_ids", None),
                output_attentions=features.get("output_attentions", None),
                output_hidden_states=features.get("output_hidden_states", None),
            )

        # Re-interleave image and text embeddings into the original input order.
        sentence_embedding = []
        image_features = iter(image_embeds)
        text_features = iter(text_embeds)

        for input_type in features["image_text_info"]:
            if input_type == 0:
                sentence_embedding.append(next(image_features))
            else:
                sentence_embedding.append(next(text_features))

        features["sentence_embedding"] = torch.stack(sentence_embedding).float()

        return features

    def tokenize(self, texts, padding: Union[str, bool] = "max_length"):
        """Preprocess a mixed list of PIL images and strings.

        :param texts: Iterable whose items are either ``PIL.Image.Image`` or ``str``.
        :param padding: Padding strategy forwarded to the tokenizer.
        :return: Encoding dict with optional ``pixel_values`` / text tensors and
            ``image_text_info`` recording the per-item modality (0=image, 1=text).
        """
        images = []
        texts_values = []
        image_text_info = []  # 0 = image, 1 = text, in original input order

        for data in texts:
            if isinstance(data, Image.Image):  # An image
                images.append(data)
                image_text_info.append(0)
            else:  # A text
                texts_values.append(data)
                image_text_info.append(1)

        encoding = {}
        if len(texts_values):
            encoding = self.processor.tokenizer(texts_values, return_tensors="pt", padding=padding)

        if len(images):
            image_features = self.processor.image_processor(images, return_tensors="pt")
            encoding["pixel_values"] = image_features.pixel_values

        encoding["image_text_info"] = image_text_info
        return encoding

    def save(self, output_path: str):
        """Persist both model weights and processor config to ``output_path``."""
        self.model.save_pretrained(output_path)
        self.processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str):
        """Reload a saved SigLIPModel from ``input_path``."""
        return SigLIPModel(model_name=input_path)
2 changes: 2 additions & 0 deletions sentence_transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .WordEmbeddings import WordEmbeddings
from .WordWeights import WordWeights
from .CLIPModel import CLIPModel
from .SigLIPModel import SigLIPModel

__all__ = [
"Transformer",
Expand All @@ -28,4 +29,5 @@
"WordEmbeddings",
"WordWeights",
"CLIPModel",
"SigLIPModel"
]