Skip to content

Commit

Permalink
Merge pull request #282 from michaelfeil/sync-engine-refactor
Browse files Browse the repository at this point in the history
refactor sync engine
  • Loading branch information
michaelfeil authored Jun 23, 2024
2 parents 6e6d89c + 778c0e8 commit 9bf5a5f
Showing 1 changed file with 95 additions and 74 deletions.
169 changes: 95 additions & 74 deletions libs/infinity_emb/infinity_emb/sync_engine.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import threading
import time
from concurrent.futures import Future
from typing import Iterator
from functools import partial
from typing import TYPE_CHECKING, Awaitable, Callable, Iterator, TypeVar

from infinity_emb.engine import AsyncEmbeddingEngine, AsyncEngineArray, EngineArgs
from infinity_emb.log_handler import logger
from infinity_emb.primitives import ClassifyReturnType, EmbeddingDtype, ReRankReturnType

if TYPE_CHECKING:
from infinity_emb import AsyncEmbeddingEngine


def add_start_docstrings(*docstr):
Expand All @@ -17,110 +19,129 @@ def docstring_decorator(fn):
return docstring_decorator


def threaded_asyncio_executor():
def decorator(fn):
funcname = fn.__name__ # e.g. `embed`
T = TypeVar("T")

def wrapper(self: "SyncEngineArray", **kwargs) -> "Future":
future: Future = Future()

assert self.is_running, "SyncEngineArray is not running"
class AsyncLifeMixin:
def __init__(self) -> None:
self.__lock = threading.Lock()
self.__stop_signal = threading.Event()
self.__loop: asyncio.AbstractEventLoop = None # type: ignore
# init
self.__is_closed: Future = Future()
self.__is_closed.set_result(None)
self.async_start_loop()

def execute():
async_function = getattr(self.async_engine_array, funcname)
try:
# async_function is e.g. `self.async_engine_array.embed`
# get async future object
result = asyncio.run_coroutine_threadsafe(
async_function(**kwargs), self._loop
)
# block until the result is available
future.set_result(result.result())
except Exception as e:
future.set_exception(e)
def __async_lifetime(self, start_event: Future):
"""private function, takes care of starting, stopping event loop"""
self.__loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.__loop)

threading.Thread(target=execute).start()
return future # return the future object immediately
async def block_until_engine_stop():
logger.info("Started Background Event Loop")
start_event.set_result(None) # signal that the event loop has started
while not self.__stop_signal.is_set():
await asyncio.sleep(0.2)

wrapper.__doc__ = fn.__doc__
return wrapper
self.__loop.run_until_complete(block_until_engine_stop())
self.__loop.close()
self.__is_closed.set_result(None)
logger.info("Closed Background Event Loop")

return decorator
def is_async_loop_running(self):
return (
(not self.__stop_signal.is_set())
and self.__loop is not None
and self.__loop.is_running()
)

def async_start_loop(self):
self.async_close_loop()
with self.__lock:
start_event: Future = Future()
self.__stop_signal.clear()
self.__is_closed: Future = Future()
threading.Thread(
target=partial(self.__async_lifetime, start_event=start_event),
daemon=True,
).start()
start_event.result()

def async_close_loop(self):
"""closes the event loop. This is a blocking call"""
with self.__lock:
self.__stop_signal.set()
self.__is_closed.result()

def async_run(
self,
async_function: Callable[..., Awaitable[T]],
*funcion_args,
**function_kwargs
) -> Future[T]:
"""run an async function in the background event loop.
Args:
async_function: the async function to run
funcion_args: args to pass to the async function
function_kwargs: kwargs to pass to the async function
Returns:
concurrent.futures.Future returning the result of async_function.
"""
if not self.is_async_loop_running():
raise RuntimeError("Event loop is not running")
future = asyncio.run_coroutine_threadsafe(
async_function(*funcion_args, **function_kwargs), self.__loop
)
return future


@add_start_docstrings(AsyncEngineArray.__doc__)
class SyncEngineArray:
class SyncEngineArray(AsyncLifeMixin):
def __init__(self, engine_args: list[EngineArgs]):
self._start_event = threading.Event()
self._stop_event = threading.Event()
super().__init__()
self.async_engine_array = AsyncEngineArray.from_args(engine_args)
threading.Thread(target=self._lifetime).start()
self._start_event.wait() # wait until the event loop has started
self.async_run(self.async_engine_array.astart).result()

@classmethod
def from_args(cls, engine_args: list[EngineArgs]) -> "SyncEngineArray":
return cls(engine_args)

@property
def is_running(self):
return (
not self._stop_event.is_set()
and self._loop.is_running()
and self.async_engine_array.is_running
)
return self.async_engine_array.is_running

def __iter__(self) -> Iterator["AsyncEmbeddingEngine"]:
return iter(self.async_engine_array)

def stop(self):
"""blocks until the engine is stopped"""
self._stop_event.set()
while self._loop.is_running():
time.sleep(0.05)

def _lifetime(self):
"""takes care of starting, stopping (engine and event loop)"""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

async def block_until_engine_stop():
logger.info("Started SyncEngineArray Background Event Loop")
self._start_event.set() # signal that the event loop has started
try:
await self.async_engine_array.astart()
while not self._stop_event.is_set():
await asyncio.sleep(0.2)
finally:
await self.async_engine_array.astop()
# additional delay to ensure that the engine is stopped
await asyncio.sleep(2.0)

self._loop.run_until_complete(block_until_engine_stop())
self._loop.close()
logger.info("Closed SyncEngineArray Background Event Loop")
self.async_run(self.async_engine_array.astop).result()
self.async_close_loop()

@add_start_docstrings(AsyncEngineArray.embed.__doc__)
@threaded_asyncio_executor()
def embed(self, *, model: str, sentences: list[str]) -> Future[EmbeddingDtype]:
def embed(self, *, model: str, sentences: list[str]):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(
self.async_engine_array.embed, model=model, sentences=sentences
)

@add_start_docstrings(AsyncEngineArray.rerank.__doc__)
@threaded_asyncio_executor()
def rerank(
self, *, model: str, query: str, docs: list[str]
) -> Future[ReRankReturnType]:
def rerank(self, *, model: str, query: str, docs: list[str]):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(
self.async_engine_array.rerank, model=model, query=query, docs=docs
)

@add_start_docstrings(AsyncEngineArray.classify.__doc__)
@threaded_asyncio_executor()
def classify(self, *, model: str, text: str) -> Future[ClassifyReturnType]:
def classify(self, *, model: str, text: str):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(self.async_engine_array.classify, model=model, text=text)

@add_start_docstrings(AsyncEngineArray.image_embed.__doc__)
@threaded_asyncio_executor()
def image_embed(self, *, model: str, images: list[str]) -> Future[EmbeddingDtype]:
def image_embed(self, *, model: str, images: list[str]):
"""sync interface of AsyncEngineArray"""
return None # type: ignore
return self.async_run(
self.async_engine_array.image_embed, model=model, images=images
)

0 comments on commit 9bf5a5f

Please sign in to comment.