Skip to content

Commit

Permalink
async requests enabled (#376)
Browse files Browse the repository at this point in the history
* async requests enabled

* aiohttp checks

* fmt
  • Loading branch information
michaelfeil authored Sep 24, 2024
1 parent ce4e27c commit 06fd1f4
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 27 deletions.
3 changes: 2 additions & 1 deletion libs/infinity_emb/infinity_emb/_optional_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def _raise_error(self) -> None:
CHECK_SENTENCE_TRANSFORMERS = OptionalImports("sentence_transformers", "torch")
CHECK_TRANSFORMERS = OptionalImports("transformers", "torch")
CHECK_TORCH = OptionalImports("torch.nn", "torch")
CHECK_REQUESTS = OptionalImports("requests", "server")
# CHECK_REQUESTS = OptionalImports("requests", "server")
CHECK_AIOHTTP = OptionalImports("aiohttp", "server")
CHECK_PIL = OptionalImports("PIL", "vision")
CHECK_SOUNDFILE = OptionalImports("soundfile", "audio")
CHECK_PYDANTIC = OptionalImports("pydantic", "server")
Expand Down
5 changes: 2 additions & 3 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ async def image_embed(
f"options are {self.model_worker.capabilities}."
)

items = await asyncio.to_thread(resolve_images, images)
items = await resolve_images(images)
embeddings, usage = await self._schedule(items)
return embeddings, usage

Expand Down Expand Up @@ -262,8 +262,7 @@ async def audio_embed(
f"options are {self.model_worker.capabilities}."
)

items = await asyncio.to_thread(
resolve_audios,
items = await resolve_audios(
audios,
getattr(self.model_worker._model, "sampling_rate", -42),
)
Expand Down
67 changes: 44 additions & 23 deletions libs/infinity_emb/infinity_emb/transformer/vision/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2023-now michaelfeil

import asyncio
import io
from typing import List, Union

from infinity_emb._optional_imports import CHECK_PIL, CHECK_REQUESTS, CHECK_SOUNDFILE
from infinity_emb._optional_imports import CHECK_AIOHTTP, CHECK_PIL, CHECK_SOUNDFILE
from infinity_emb.primitives import (
AudioCorruption,
AudioSingle,
Expand All @@ -13,11 +14,12 @@
ImageSingle,
)

if CHECK_AIOHTTP.is_available:
import aiohttp

if CHECK_PIL.is_available:
from PIL import Image # type: ignore

if CHECK_REQUESTS.is_available:
import requests # type: ignore
if CHECK_SOUNDFILE.is_available:
import soundfile as sf # type: ignore

Expand All @@ -27,17 +29,20 @@ def resolve_from_img_obj(img_obj: "ImageClassType") -> ImageSingle:
return ImageSingle(image=img_obj)


def resolve_from_img_url(img_url: str) -> ImageSingle:
async def resolve_from_img_url(
img_url: str, session: "aiohttp.ClientSession"
) -> ImageSingle:
"""Resolve an image from an URL."""
try:
downloaded_img = requests.get(img_url, stream=True).raw
# requests.get(img_url, stream=True).raw
downloaded_img = await (await session.get(img_url)).read()
except Exception as e:
raise ImageCorruption(
f"error opening an image in your request image from url: {e}"
)

try:
img = Image.open(downloaded_img)
img = Image.open(io.BytesIO(downloaded_img))
if img.size[0] < 3 or img.size[1] < 3:
# https://upload.wikimedia.org/wikipedia/commons/c/ca/1x1.png
raise ImageCorruption(
Expand All @@ -50,45 +55,57 @@ def resolve_from_img_url(img_url: str) -> ImageSingle:
)


async def resolve_image(
    img: Union[str, "ImageClassType"], session: "aiohttp.ClientSession"
) -> ImageSingle:
    """Convert one image input into an ImageSingle.

    Args:
        img: an already-loaded PIL image, or a string that is treated as a URL.
        session: client session handed to the URL resolver when `img` is a string.

    Returns:
        The resolved ImageSingle.

    Raises:
        ValueError: if `img` is neither a PIL image nor a string.
    """
    # Already-decoded images are wrapped directly, no download involved.
    if isinstance(img, Image.Image):
        return resolve_from_img_obj(img)

    # Anything that is not a string at this point cannot be resolved.
    if not isinstance(img, str):
        raise ValueError(
            f"Invalid image type: {img} is neither str nor ImageClassType object"
        )

    return await resolve_from_img_url(img, session=session)


async def resolve_images(
    images: List[Union[str, "ImageClassType"]]
) -> List[ImageSingle]:
    """Resolve images from URLs or ImageClassType objects concurrently.

    Every input is resolved through one shared aiohttp client session and the
    resulting coroutines are awaited together with ``asyncio.gather``, so the
    returned list is in the same order as ``images``.

    Args:
        images: PIL image objects and/or URL strings.

    Returns:
        One ImageSingle per input, in input order.

    Raises:
        ImageCorruption: if any input cannot be resolved. The message names
            the specific input that failed; the underlying exception is
            attached as ``__cause__``.
    """
    # TODO: improve parallel requests, safety, error handling
    CHECK_AIOHTTP.mark_required()
    CHECK_PIL.mark_required()

    # Nothing to do: avoid opening a client session for an empty request.
    if not images:
        return []

    async def _resolve_one(
        img: Union[str, "ImageClassType"], session: "aiohttp.ClientSession"
    ) -> ImageSingle:
        # Wrap per item so the error points at the offending input instead of
        # dumping the whole request into the message.
        try:
            return await resolve_image(img, session)
        except Exception as e:
            raise ImageCorruption(
                f"Failed to resolve image: {img}.\nError msg: {str(e)}"
            ) from e

    try:
        async with aiohttp.ClientSession(trust_env=True) as session:
            resolved_imgs = await asyncio.gather(
                *[_resolve_one(img, session) for img in images]
            )
    except ImageCorruption:
        # Already carries the per-image context added above.
        raise
    except Exception as e:
        # Failures outside a single image (e.g. session setup/teardown) keep
        # the ImageCorruption type callers already handle.
        raise ImageCorruption(
            f"Failed to resolve images.\nError msg: {str(e)}"
        ) from e

    return list(resolved_imgs)


def resolve_audio(audio: Union[str, bytes], allowed_sampling_rate: int) -> AudioSingle:
async def resolve_audio(
audio: Union[str, bytes],
allowed_sampling_rate: int,
session: "aiohttp.ClientSession",
) -> AudioSingle:
if isinstance(audio, bytes):
try:
audio_bytes = io.BytesIO(audio)
except Exception as e:
raise AudioCorruption(f"Error opening audio: {e}")
else:
try:
downloaded = requests.get(audio, stream=True).content
downloaded = await (await session.get(audio)).read()
# downloaded = requests.get(audio, stream=True).content
audio_bytes = io.BytesIO(downloaded)
except Exception as e:
raise AudioCorruption(f"Error downloading audio.\nError msg: {str(e)}")
Expand All @@ -104,18 +121,22 @@ def resolve_audio(audio: Union[str, bytes], allowed_sampling_rate: int) -> Audio
raise AudioCorruption(f"Error opening audio: {e}.\nError msg: {str(e)}")


def resolve_audios(
async def resolve_audios(
audio_urls: list[Union[str, bytes]], allowed_sampling_rate: int
) -> list[AudioSingle]:
"""Resolve audios from URLs."""
CHECK_REQUESTS.mark_required()
CHECK_AIOHTTP.mark_required()
CHECK_SOUNDFILE.mark_required()

resolved_audios: list[AudioSingle] = []
for audio in audio_urls:
async with aiohttp.ClientSession(trust_env=True) as session:
try:
audio_single = resolve_audio(audio, allowed_sampling_rate)
resolved_audios.append(audio_single)
resolved_audios = await asyncio.gather(
*[
resolve_audio(audio, allowed_sampling_rate, session)
for audio in audio_urls
]
)
except Exception as e:
raise AudioCorruption(f"Failed to resolve audio: {e}")

Expand Down

0 comments on commit 06fd1f4

Please sign in to comment.