FEAT: Add support for CLIP model (#2637)
Signed-off-by: Song Wei <[email protected]>
Second222None authored Dec 12, 2024
1 parent 6b0bf6f commit 609b825
Showing 6 changed files with 109 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/python.yaml
@@ -189,6 +189,8 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip install -U silero-vad
${{ env.SELF_HOST_PYTHON }} -m pip install -U pydantic
${{ env.SELF_HOST_PYTHON }} -m pip install -U diffusers
${{ env.SELF_HOST_PYTHON }} -m pip install -U onnx
${{ env.SELF_HOST_PYTHON }} -m pip install -U onnxconverter_common
${{ env.SELF_HOST_PYTHON }} -m pip install -U torchdiffeq
${{ env.SELF_HOST_PYTHON }} -m pip install -U "x_transformers>=1.31.14"
${{ env.SELF_HOST_PYTHON }} -m pip install -U pypinyin
6 changes: 3 additions & 3 deletions xinference/api/restful_api.py
@@ -94,9 +94,9 @@ class Config:

class CreateEmbeddingRequest(BaseModel):
model: str
input: Union[str, List[str], List[int], List[List[int]]] = Field(
description="The input to embed."
)
input: Union[
str, List[str], List[int], List[List[int]], Dict[str, str], List[Dict[str, str]]
] = Field(description="The input to embed.")
user: Optional[str] = None

class Config:
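The widened `input` union above accepts plain strings, token-id lists, and now dict entries carrying either a "text" or an "image" key, so a single batch can mix modalities. A minimal request sketch (not part of this commit; the OpenAI-compatible /v1/embeddings route, the default local port, and the model UID are assumptions):

# Hypothetical request against a locally running server; the route, port,
# and model identifier below are assumptions, not taken from this commit.
import requests

payload = {
    "model": "jina-clip-v2",  # assumed to be the launched model's UID
    "input": [
        {"text": "a photo of a beach"},
        {"image": "https://i.ibb.co/r5w8hG8/beach2.jpg"},  # URL or data:image/...;base64,<...>
    ],
}
resp = requests.post("http://127.0.0.1:9997/v1/embeddings", json=payload)
print(len(resp.json()["data"]))  # one embedding per input item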
50 changes: 47 additions & 3 deletions xinference/model/embedding/core.py
@@ -226,7 +226,9 @@ def to(self, *args, **kwargs):
trust_remote_code=True,
)

def _fix_langchain_openai_inputs(self, sentences: Union[str, List[str]]):
def _fix_langchain_openai_inputs(
self, sentences: Union[str, List[str], Dict[str, str], List[Dict[str, str]]]
):
# Check if sentences is a two-dimensional list of integers
if (
isinstance(sentences, list)
@@ -260,7 +262,11 @@ def _fix_langchain_openai_inputs(self, sentences: Union[str, List[str]]):
sentences = lines_decoded
return sentences

def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
def create_embedding(
self,
sentences: Union[str, List[str]],
**kwargs,
):
sentences = self._fix_langchain_openai_inputs(sentences)

from sentence_transformers import SentenceTransformer
@@ -539,7 +545,11 @@ def encode(
features.update(extra_features)
# when batching, the attention mask 1 means there is a token
# thus we just sum up it to get the total number of tokens
all_token_nums += features["attention_mask"].sum().item()
if "clip" in self._model_spec.model_name.lower():
all_token_nums += features["input_ids"].numel()
all_token_nums += features["pixel_values"].numel()
else:
all_token_nums += features["attention_mask"].sum().item()

with torch.no_grad():
out_features = model.forward(features, **kwargs)
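For CLIP-style models the hunk above switches the token accounting: instead of summing the attention mask, it counts every input id and every pixel value via numel(), presumably because a mixed text/image batch has no single attention mask covering both modalities. A small illustration with made-up shapes:

# Illustrative only; the tensor shapes are invented for the example.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])      # 2 sequences padded to length 4
input_ids = torch.zeros_like(attention_mask)       # 2 x 4 token ids
pixel_values = torch.zeros(1, 3, 224, 224)         # one RGB image tensor

# Text-only path: count real (non-padding) tokens.
print(attention_mask.sum().item())                 # 5

# CLIP path: count every id plus every pixel value.
print(input_ids.numel() + pixel_values.numel())    # 8 + 150528 = 150536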
@@ -615,6 +625,40 @@ def encode(
all_embeddings, all_token_nums = _encode_bgem3(
self._model, sentences, convert_to_numpy=False, **kwargs
)
elif "clip" in self._model_spec.model_name.lower():
import base64
import re
from io import BytesIO

from PIL import Image

def base64_to_image(base64_str: str) -> Image.Image:
# base64_data = re.sub("^data:image/.+;base64,", "", base64_str)
base64_data = base64_str.split(",", 1)[1]
byte_data = base64.b64decode(base64_data)
image_data = BytesIO(byte_data)
img = Image.open(image_data)
return img

objs: list[Union[str, Image.Image]] = []
for item in sentences:
if isinstance(item, dict):
if item.get("text") is not None:
objs.append(item["text"])
elif item.get("image") is not None:
if re.match(r"^data:image/.+;base64,", item["image"]):
image = base64_to_image(item["image"])
objs.append(image)
else:
objs.append(item["image"])
else:
logger.error("Please check the input data.")
all_embeddings, all_token_nums = encode(
self._model,
objs,
convert_to_numpy=False,
**self._kwargs,
)
else:
all_embeddings, all_token_nums = encode(
self._model,
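The base64_to_image helper above expects a data:image/...;base64, URI and strips the prefix by splitting on the first comma, while plain URLs or file paths are passed through unchanged. An illustrative round trip (not part of the commit) showing the format it assumes:

# Build a data URI for a tiny in-memory image, then decode it the same way
# base64_to_image does: drop everything before the first comma.
import base64
from io import BytesIO

from PIL import Image

img = Image.new("RGB", (8, 8), color="blue")
buf = BytesIO()
img.save(buf, format="PNG")
data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

decoded = Image.open(BytesIO(base64.b64decode(data_uri.split(",", 1)[1])))
print(decoded.size)  # (8, 8)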
7 changes: 7 additions & 0 deletions xinference/model/embedding/model_spec.json
@@ -245,5 +245,12 @@
"max_tokens": 8192,
"language": ["zh", "en"],
"model_id": "jinaai/jina-embeddings-v3"
},
{
"model_name": "jina-clip-v2",
"dimensions": 1024,
"max_tokens": 8192,
"language": ["89 languages supported"],
"model_id": "jinaai/jina-clip-v2"
}
]
8 changes: 8 additions & 0 deletions xinference/model/embedding/model_spec_modelscope.json
@@ -248,5 +248,13 @@
"language": ["zh", "en"],
"model_id": "jinaai/jina-embeddings-v3",
"model_hub": "modelscope"
},
{
"model_name": "jina-clip-v2",
"dimensions": 1024,
"max_tokens": 8192,
"language": ["89 languages supported"],
"model_id": "jinaai/jina-clip-v2",
"model_hub": "modelscope"
}
]
42 changes: 42 additions & 0 deletions xinference/model/embedding/tests/test_integrated_embedding.py
@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
from io import BytesIO

import numpy as np
import requests
from PIL import Image

from ....client import Client


@@ -36,3 +43,38 @@ def test_sparse_embedding(setup):
words = model.convert_ids_to_tokens(token_ids)
assert len(words) == len(token_ids)
assert isinstance(words[0], str)


def test_clip_embedding(setup):
endpoint, _ = setup
client = Client(endpoint)

model_uid = client.launch_model(
model_name="jina-clip-v2", model_type="embedding", torch_dtype="float16"
)
assert len(client.list_models()) == 1

model = client.get_model(model_uid)

def image_to_base64(image: Image.Image, fmt="png") -> str:
output_buffer = BytesIO()
image.save(output_buffer, format=fmt)
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
return f"data:image/{fmt};base64," + base64_str

image_str = "https://i.ibb.co/r5w8hG8/beach2.jpg"
image_str_base64 = image_to_base64(
Image.open(BytesIO(requests.get(image_str).content))
)
input = [
{"text": "This is a picture of diagram"},
{"image": image_str_base64},
{"text": "a dog"},
{"image": image_str},
{"text": "海滩上美丽的日落。"},
]
response = model.create_embedding(input)
for i in range(len(response["data"])):
embedding = np.array([item for item in response["data"][i]["embedding"]])
assert embedding.shape == (1024,)
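
The returned vectors are 1024-dimensional for every modality, so they can be compared directly. A possible follow-on check (illustrative only, not part of this commit) that continues from the response object above and ranks the text inputs against the first image embedding by cosine similarity:

# Assumes `response` from the test above is still in scope.
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Indices follow the `input` list above: 0, 2, 4 are texts; 1, 3 are images.
vectors = [np.array(d["embedding"]) for d in response["data"]]
beach_image = vectors[1]
for idx in (0, 2, 4):
    print(idx, round(cosine(vectors[idx], beach_image), 4))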
