Skip to content

Commit

Permalink
add weakref to sync engine (#373)
Browse files Browse the repository at this point in the history
* add weakref

* fmt

* poetry update lock package

* add error handling

* adjust permalink
  • Loading branch information
michaelfeil authored Sep 24, 2024
1 parent a70525a commit a00374c
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 14 deletions.
7 changes: 2 additions & 5 deletions libs/embed_package/embed/_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,9 @@ def audio_embed(
"""Embed audios with a model.
>>> import requests, io
>>> import soundfile as sf
>>> url = "https://bigsoundbank.com/UPLOAD/wav/2380.wav"
>>> raw_bytes = requests.get(url, stream=True).content
>>> data, samplerate = sf.read(io.BytesIO(raw_bytes))
>>> url = "https://github.com/michaelfeil/infinity/raw/refs/heads/main/libs/infinity_emb/tests/data/audio/COMTran_Aerospacebeep1(ID2380)_BSB.wav"
>>> ei = BatchedInference(model_id="laion/larger_clap_general", engine="torch")
>>> audio_embed_result = ei.audio_embed(model_id="laion/larger_clap_general", audios=[data])
>>> audio_embed_result = ei.audio_embed(model_id="laion/larger_clap_general", audios=[url])
>>> type(audio_embed_result)
<class 'concurrent.futures._base.Future'>
>>> audio_embed_result.result()[0][0].shape
Expand Down
26 changes: 25 additions & 1 deletion libs/embed_package/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions libs/embed_package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ priority = "explicit"

[tool.poetry.dependencies]
python = ">=3.9,<4"
infinity_emb = {path = "../infinity_emb", extras = ["optimum","vision","torch"]}
# infinity_emb = {version = "0.0.57", extras = ["optimum","vision","torch"]}
infinity_emb = {path = "../infinity_emb", extras = ["optimum","vision","torch","audio"]}
# infinity_emb = {version = "0.0.57", extras = ["optimum","vision","torch","audio"]}
[tool.poetry.group.test.dependencies]
pytest = "^7.0.0"
coverage = {extras = ["toml"], version = "^7.3.2"}
Expand Down
9 changes: 9 additions & 0 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ async def _collect_from_model(
shutdown: ShutdownReadOnly, result_queue: Queue, tp: ThreadPoolExecutor
):
"""background thread for reading exits only if shutdown.is_set()"""
schedule_errors = 0
try:
while not shutdown.is_set():
try:
Expand All @@ -353,6 +354,14 @@ async def _collect_from_model(
except queue.Empty:
# in case of timeout start again
continue
except Exception as e:
# exception handing without loop forever.
time.sleep(1)
schedule_errors += 1
if schedule_errors > 10:
logger.error("too many schedule errors")
raise e
continue
results, batch = post_batch
for i, item in enumerate(batch):
await item.complete(results[i])
Expand Down
19 changes: 16 additions & 3 deletions libs/infinity_emb/infinity_emb/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import threading
import time
import weakref
from concurrent.futures import Future
from functools import partial
Expand Down Expand Up @@ -56,7 +57,7 @@ async def block_until_engine_stop():
logger.info("Started Background Event Loop")
start_event.set_result(None) # signal that the event loop has started
while not self.__stop_signal.is_set():
await asyncio.sleep(0.5)
await asyncio.sleep(0.1)

self.__loop.run_until_complete(block_until_engine_stop())
self.__loop.close()
Expand Down Expand Up @@ -145,8 +146,17 @@ def __init__(self, _engine_args_array: list[EngineArgs]):
self.async_engine_array = AsyncEngineArray.from_args(_engine_args_array)
self.async_run(self.async_engine_array.astart).result()

# finalizer
finalize_fn = partial(self.async_run, self.async_engine_array.astop)
# finalizer to stop the engine
engine_ref = weakref.ref(self.async_engine_array)
async_run_ref = weakref.ref(self.async_run)

def finalize_fn():
engine = engine_ref()
run_ref = async_run_ref()
if engine is not None:
run_ref(engine.astop).result()
time.sleep(1.5) # wait for maximum of 1.5 seconds

weakref.finalize(self.async_engine_array, finalize_fn)

@classmethod
Expand Down Expand Up @@ -207,3 +217,6 @@ def audio_embed(self, *, model: str, audios: list[npt.NDArray]):
return self.async_run(
self.async_engine_array.audio_embed, model=model, audios=audios
)

def __del__(self):
self.stop()
2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/transformer/vision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def resolve_audios(
audio_single = resolve_audio(audio)
resolved_audios.append(audio_single)
except Exception as e:
raise AudioCorruption(f"Failed to resolve image: {e}")
raise AudioCorruption(f"Failed to resolve audio: {e}")
if not (
all(
resolved_audios[0].sampling_rate == audio.sampling_rate
Expand Down
2 changes: 1 addition & 1 deletion libs/infinity_emb/tests/unit_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ async def test_clap_like_model():
engine = AsyncEmbeddingEngine.from_args(
EngineArgs(model_name_or_path=model_name, dtype="float32")
)
url = "https://github.com/wirthual/infinity/raw/b849258a5d60ba79f1c600cbca9c4ea77349876d/libs/infinity_emb/tests/data/audio/COMTran_Aerospacebeep1(ID2380)_BSB.wav"
url = "https://github.com/michaelfeil/infinity/raw/refs/heads/main/libs/infinity_emb/tests/data/audio/COMTran_Aerospacebeep1(ID2380)_BSB.wav"
bytes_url = requests.get(url).content

inputs = ["a sound of a cat", "a sound of a cat"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_clap_like_model():
model = ClapLikeModel(
engine_args=EngineArgs(model_name_or_path=model_name, dtype="float16")
)
url = "https://github.com/wirthual/infinity/raw/b849258a5d60ba79f1c600cbca9c4ea77349876d/libs/infinity_emb/tests/data/audio/COMTran_Aerospacebeep1(ID2380)_BSB.wav"
url = "https://github.com/michaelfeil/infinity/raw/refs/heads/main/libs/infinity_emb/tests/data/audio/COMTran_Aerospacebeep1(ID2380)_BSB.wav"
raw_bytes = requests.get(url, stream=True).content
data, samplerate = sf.read(io.BytesIO(raw_bytes))

Expand Down

0 comments on commit a00374c

Please sign in to comment.