Skip to content

Commit

Permalink
Feature: Add Support for Data URI (#385)
Browse files Browse the repository at this point in the history
* add support for data URI

* Remove unnecessary async

* Update utils.py

* commit additional changes

* fix merge

* remove unused regex code

---------

Co-authored-by: davleop <>
Co-authored-by: Michael Feil <[email protected]>
Co-authored-by: michaelfeil <[email protected]>
  • Loading branch information
3 people authored Oct 1, 2024
1 parent 5881a74 commit bb39cbd
Show file tree
Hide file tree
Showing 13 changed files with 472 additions and 90 deletions.
2 changes: 1 addition & 1 deletion docs/assets/openapi.json

Large diffs are not rendered by default.

79 changes: 37 additions & 42 deletions docs/docs/cli_v2.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ async def classify(
return scores, usage

async def image_embed(
self, *, images: List[Union[str, "ImageClassType"]]
self, *, images: List[Union[str, "ImageClassType", bytes]]
) -> tuple[list[EmbeddingReturnType], int]:
"""embed multiple images
Expand Down
239 changes: 239 additions & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/data_uri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import mimetypes
import re
import sys
import textwrap
from base64 import b64decode as decode64
from base64 import b64encode as encode64
from dataclasses import dataclass
from typing import Any, Dict, MutableMapping, Optional, Tuple, TypeVar, Union

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from urllib.parse import quote, unquote

# NOTE(review): T appears unused in this module — candidate for removal.
T = TypeVar("T")

# Generic "type/subtype" MIME pattern, used in the parsing grammar below.
MIMETYPE_REGEX = r"[\w]+\/[\w\-\+\.]+"
# Restricted pattern used for *construction* (DataURI.make): only
# audio/* and image/* are accepted — presumably intentional, since this
# API only embeds audio and image payloads; confirm against callers.
MIMETYPE_REGEX_AUDIO_IMAGE = r"(audio|image)\/[\w\-\+\.]+"
_MIMETYPE_RE = re.compile("^{}$".format(MIMETYPE_REGEX_AUDIO_IMAGE))

# charset parameter token, e.g. "utf-8".
CHARSET_REGEX = r"[\w\-\+\.]+"
_CHARSET_RE = re.compile("^{}$".format(CHARSET_REGEX))

# Full data-URI grammar (RFC 2397 style):
#   data:[<mimetype>][;name=<name>][;charset=<charset>][;base64],<data>
# Named groups are consumed by DataURI._parse.
DATA_URI_REGEX = (
    r"data:"
    + r"(?P<mimetype>{})?".format(MIMETYPE_REGEX)
    + r"(?:\;name\=(?P<name>[\w\.\-%!*'~\(\)]+))?"
    + r"(?:\;charset\=(?P<charset>{}))?".format(CHARSET_REGEX)
    + r"(?P<base64>\;base64)?"
    + r",(?P<data>.*)"
)
# re.DOTALL: the payload may itself contain newlines.
_DATA_URI_RE = re.compile(r"^{}$".format(DATA_URI_REGEX), re.DOTALL)


class InvalidMimeType(ValueError):
    """Raised when a mimetype fails the audio/image pattern check."""

    pass


class InvalidCharset(ValueError):
    """Raised for an invalid or missing charset parameter."""

    pass


class InvalidDataURI(ValueError):
    """Raised when a string does not match the data-URI grammar."""

    pass


@dataclass
class DataURIHolder:
    """Plain container for the parsed components of a data URI."""

    # MIME type declared in the URI (e.g. "image/png"), or None if absent.
    mimetype: Optional[str]
    # charset parameter, or None if absent.
    charset: Optional[str]
    # True when the payload was base64-encoded in the URI.
    base64: bool
    # Decoded payload.
    data: Union[str, bytes]


class DataURI(str):
    """A data URI (``data:[<mimetype>][;params][;base64],<payload>``) as a ``str`` subclass.

    The string value *is* the complete URI; parsed components are exposed
    through the ``mimetype``/``name``/``charset``/``is_base64``/``data``
    properties, all backed by :meth:`_parse`.  Constructing an instance
    from an invalid string raises :class:`InvalidDataURI` (a
    ``ValueError``) immediately.  Pydantic v1 and v2 hooks are provided
    so the type can be used directly as a model field.
    """

    @classmethod
    def make(
        cls,
        mimetype: Optional[str],
        charset: Optional[str],
        base64: Optional[bool],
        data: Union[str, bytes],
    ) -> Self:
        """Assemble a data URI from its components.

        Args:
            mimetype: MIME type to embed, or None to omit.  Validated
                against ``_MIMETYPE_RE``, which only admits ``audio/*``
                and ``image/*`` even though parsing accepts any
                ``type/subtype`` — presumably intentional; confirm.
            charset: optional charset parameter, validated against
                ``_CHARSET_RE``.
            base64: when truthy the payload is base64-encoded, otherwise
                it is percent-encoded.
            data: payload as ``str`` or ``bytes``.

        Raises:
            InvalidMimeType: if ``mimetype`` fails validation.
            InvalidCharset: if ``charset`` fails validation.
        """
        parts = ["data:"]
        if mimetype is not None:
            if not _MIMETYPE_RE.match(mimetype):
                raise InvalidMimeType("Invalid mimetype: %r" % mimetype)
            parts.append(mimetype)
        if charset is not None:
            if not _CHARSET_RE.match(charset):
                raise InvalidCharset("Invalid charset: %r" % charset)
            parts.extend([";charset=", charset])
        if base64:
            parts.append(";base64")
            # utf-8 is the fallback encoding when no charset is declared.
            _charset = charset or "utf-8"
            if isinstance(data, bytes):
                _data = data
            else:
                # str payloads are first encoded with the declared charset.
                _data = bytes(data, _charset)
            # NOTE(review): base64 output is pure ASCII; decoding it with a
            # non-ASCII-compatible charset (e.g. utf-16) would fail here —
            # confirm only ASCII-compatible charsets are expected.
            encoded_data = encode64(_data).decode(_charset).strip()
        else:
            # Non-base64 payloads are percent-encoded.
            encoded_data = quote(data)
        parts.extend([",", encoded_data])
        return cls("".join(parts))

    @classmethod
    def from_file(
        cls,
        filename: str,
        charset: Optional[str] = None,
        base64: Optional[bool] = True,
        mimetype: Optional[str] = None,
    ) -> Self:
        """Read *filename* and wrap its raw bytes in a data URI.

        The mimetype is guessed from the file extension unless given
        explicitly; the file is always read in binary mode.
        """
        if mimetype is None:
            mimetype, _ = mimetypes.guess_type(filename, strict=False)
        with open(filename, "rb") as fp:
            data = fp.read()

        return cls.make(mimetype, charset, base64, data)

    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
        uri = super(DataURI, cls).__new__(cls, *args, **kwargs)
        # Parse eagerly so that an invalid URI raises InvalidDataURI at
        # construction time rather than on first property access.
        uri._parse  # Trigger any ValueErrors on instantiation.
        return uri

    def __repr__(self) -> str:
        """Abbreviated repr — long URIs are truncated to ~80 characters."""
        truncated = str(self)
        if len(truncated) > 80:
            truncated = truncated[:79] + "…"
        return "DataURI(%s)" % (truncated,)

    def wrap(self, width: int = 76) -> str:
        """Return the URI hard-wrapped to *width* columns (MIME-style)."""
        return "\n".join(textwrap.wrap(self, width, break_on_hyphens=False))

    @property
    def mimetype(self) -> Optional[str]:
        """MIME type declared in the URI, or None if absent."""
        return self._parse[0]

    @property
    def name(self) -> Optional[str]:
        """Percent-decoded ``name`` parameter, or None if absent."""
        name = self._parse[1]
        if name is not None:
            return unquote(name)
        return name

    @property
    def charset(self) -> Optional[str]:
        """``charset`` parameter, or None if absent."""
        return self._parse[2]

    @property
    def is_base64(self) -> bool:
        """True when the payload is base64-encoded."""
        return self._parse[3]

    @property
    def data(self) -> bytes:
        """Decoded payload as raw bytes."""
        return self._parse[4]

    def convert_to_data_uri_holder(self) -> DataURIHolder:
        """Bundle the parsed components into a :class:`DataURIHolder`."""
        return DataURIHolder(
            mimetype=self.mimetype,
            charset=self.charset,
            base64=self.is_base64,
            data=self.data,
        )

    @property
    def text(self) -> str:
        """Payload decoded as text using the declared charset.

        Raises:
            InvalidCharset: if the URI declares no charset.
        """
        if self.charset is None:
            raise InvalidCharset("DataURI has no encoding set.")

        return self.data.decode(self.charset)

    @property
    def is_valid(self) -> bool:
        """True when the string matches the data-URI grammar."""
        match = _DATA_URI_RE.match(self)
        if not match:
            return False
        return True

    @property
    def _parse(
        self,
    ) -> Tuple[Optional[str], Optional[str], Optional[str], bool, bytes]:
        """Parse the URI into ``(mimetype, name, charset, is_base64, data)``.

        Raises:
            InvalidDataURI: if the string does not match the grammar.
        """
        match = _DATA_URI_RE.match(self)
        if match is None:
            raise InvalidDataURI("Not a valid data URI: %r" % self)
        mimetype = match.group("mimetype") or None
        name = match.group("name") or None
        charset = match.group("charset") or None
        # utf-8 is the fallback when the URI declares no charset.
        _charset = charset or "utf-8"

        if match.group("base64"):
            _data = bytes(match.group("data"), _charset)
            data = decode64(_data)
        else:
            # Non-base64 payloads are percent-encoded text.
            data = bytes(unquote(match.group("data")), _charset)

        return mimetype, name, charset, bool(match.group("base64")), data

    # Pydantic methods
    @classmethod
    def __get_validators__(cls):
        # Pydantic v1 hook: one or more validators may be yielded, called
        # in order; each validator receives the value returned from the
        # previous one.
        yield cls.validate

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
        # Pydantic v2 hook: validate the input as str, then run
        # cls.validate on the result.  Imported lazily so the module
        # stays importable without pydantic-core (pydantic v1).
        from pydantic_core import core_schema

        return core_schema.no_info_after_validator_function(
            cls.validate, core_schema.str_schema()
        )

    @classmethod
    def validate(
        cls,
        value: str,
        values: Optional[MutableMapping[str, Any]] = None,
        config: Any = None,
        field: Any = None,
        **kwargs: Any,
    ) -> Self:
        """Validator shared by the pydantic v1 and v2 hooks.

        The extra parameters (``values``/``config``/``field``) are unused —
        presumably kept to match the pydantic-v1 validator call signature;
        confirm before removing.
        """
        if not isinstance(value, str):
            raise TypeError("string required")

        # cls(value) already parses eagerly (see __new__) and raises
        # InvalidDataURI — a ValueError — for malformed input.
        m = cls(value)
        if not m.is_valid:
            raise ValueError("invalid data-uri format")
        return m

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: MutableMapping[str, Any], handler: Any
    ) -> Any:
        """Pydantic v2 JSON-schema hook: advertise the pattern and an example."""
        json_schema = handler(schema)
        json_schema.update(
            pattern=DATA_URI_REGEX,
            examples=[
                "data:text/plain;charset=utf-8;base64,"
                "VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wZWQgb3ZlciB0aGUgbGF6eSBkb2cu"
            ],
        )
        return json_schema

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        # Pydantic v1 JSON-schema hook: mutate the received dict in place;
        # the returned value is ignored.
        field_schema.update(
            pattern=DATA_URI_REGEX,
            examples=[
                "data:text/plain;charset=utf-8;base64,VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wZWQgb3ZlciB0aGUgbGF6eSBkb2cu"
            ],
        )
22 changes: 22 additions & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pydantic_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pydantic import AnyUrl, HttpUrl, StringConstraints

# Public API of this pydantic-v2 compatibility module; pymodels.py falls
# back to pydantic-v1 equivalents when importing this fails.
__all__ = [
    "INPUT_STRING",
    "ITEMS_LIMIT",
    "ITEMS_LIMIT_SMALL",
    "AnyUrl",
    "HttpUrl",
]

# Note: adding an artificial limit — this might reveal splitting
# issues on the client side, and is not a hard limit on the server side.
INPUT_STRING = StringConstraints(max_length=8192 * 15, strip_whitespace=True)
# Batch-size bounds for list-valued request fields (conlist kwargs).
ITEMS_LIMIT = {
    "min_length": 1,
    "max_length": 2048,
}
# Tighter bound for expensive inputs (e.g. image/audio embedding requests).
ITEMS_LIMIT_SMALL = {
    "min_length": 1,
    "max_length": 32,
}
33 changes: 16 additions & 17 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,13 @@
from pydantic import BaseModel, Field, conlist

try:
from pydantic import AnyUrl, HttpUrl, StringConstraints

# Note: adding artificial limit, this might reveal splitting
# issues on the client side
# and is not a hard limit on the server side.
INPUT_STRING = StringConstraints(max_length=8192 * 15, strip_whitespace=True)
ITEMS_LIMIT = {
"min_length": 1,
"max_length": 2048,
}
ITEMS_LIMIT_SMALL = {
"min_length": 1,
"max_length": 32,
}
from .data_uri import DataURI
from .pydantic_v2 import (
INPUT_STRING,
ITEMS_LIMIT,
ITEMS_LIMIT_SMALL,
HttpUrl,
)
except ImportError:
from pydantic import constr

Expand All @@ -49,12 +42,18 @@
"min_items": 1,
"max_items": 32,
}
HttpUrl, AnyUrl = str, str # type: ignore
HttpUrl = str # type: ignore
DataURI = str # type: ignore
DataURIorURL = Union[Annotated[DataURI, str], HttpUrl]

else:

class BaseModel: # type: ignore[no-redef]
pass

class DataURI: # type: ignore
pass

def Field(*args, **kwargs): # type: ignore
pass

Expand Down Expand Up @@ -83,10 +82,10 @@ class OpenAIEmbeddingInput(BaseModel):
class ImageEmbeddingInput(BaseModel):
input: Union[ # type: ignore
conlist( # type: ignore
Annotated[AnyUrl, HttpUrl],
DataURIorURL,
**ITEMS_LIMIT_SMALL,
),
Annotated[AnyUrl, HttpUrl],
DataURIorURL,
]
model: str = "default/not-specified"
encoding_format: EmbeddingEncodingFormat = EmbeddingEncodingFormat.float
Expand Down
2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async def classify(
async def image_embed(
self,
*,
images: List[Union[str, "ImageClassType"]],
images: List[Union[str, "ImageClassType", bytes]],
) -> tuple[list[EmbeddingReturnType], int]:
"""Schedule a images and sentences to be embedded. Awaits until embedded.
Expand Down
Loading

0 comments on commit bb39cbd

Please sign in to comment.