From 9aeaf1436f4a6957dd1e3f93c9741eecb223a66e Mon Sep 17 00:00:00 2001
From: Tobias Lohse
Date: Tue, 1 Oct 2024 11:26:53 -0500
Subject: [PATCH] Add truncation to CLIP model

---
 sentence_transformers/models/CLIPModel.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/sentence_transformers/models/CLIPModel.py b/sentence_transformers/models/CLIPModel.py
index 06408b3e7..6d1d446eb 100644
--- a/sentence_transformers/models/CLIPModel.py
+++ b/sentence_transformers/models/CLIPModel.py
@@ -9,7 +9,7 @@ class CLIPModel(nn.Module):
     save_in_root: bool = True
 
-    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", processor_name=None) -> None:
+    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", processor_name=None, max_seq_length: int | None = None) -> None:
         super().__init__()
 
         if processor_name is None:
             processor_name = model_name
@@ -17,6 +17,7 @@ def __init__(self, model_name: str = "openai/clip-vit-base-patch32", processor_n
 
         self.model = transformers.CLIPModel.from_pretrained(model_name)
         self.processor = transformers.CLIPProcessor.from_pretrained(processor_name)
+        self.max_seq_length = max_seq_length or self.processor.tokenizer.model_max_length
 
     def __repr__(self) -> str:
         return "CLIPModel()"
@@ -68,7 +69,13 @@ def tokenize(self, texts, padding: str | bool = True) -> dict[str, torch.Tensor]
 
         encoding = {}
         if len(texts_values):
-            encoding = self.processor.tokenizer(texts_values, return_tensors="pt", padding=padding)
+            encoding = self.processor.tokenizer(
+                texts_values,
+                truncation="longest_first",
+                return_tensors="pt",
+                padding=padding,
+                max_length=self.max_seq_length,
+            )
 
         if len(images):
             image_features = self.processor.image_processor(images, return_tensors="pt")
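
A minimal usage sketch (not part of the patch) showing the new parameter in action. The SentenceTransformer wiring around the CLIPModel module follows standard sentence-transformers usage, and the long caption string is made up for illustration; only the max_seq_length argument and the truncation behavior come from the diff above.

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import CLIPModel

    # Cap text inputs at CLIP's 77-token position-embedding limit; with this
    # patch, longer inputs are truncated ("longest_first") instead of failing
    # inside the text encoder. Omitting max_seq_length falls back to the
    # tokenizer's model_max_length.
    clip = CLIPModel("openai/clip-vit-base-patch32", max_seq_length=77)
    model = SentenceTransformer(modules=[clip])

    # A caption this long would previously exceed the 77-token limit.
    long_caption = "a photo of a dog playing in a park " * 20
    embeddings = model.encode([long_caption])
    print(embeddings.shape)  # (1, 512) for clip-vit-base-patch32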