From b86bd3da4947bb397ede2b3d5b49defe0138ed2d Mon Sep 17 00:00:00 2001 From: Tobias Schnabel Date: Tue, 27 Feb 2024 11:53:06 -0800 Subject: [PATCH] - using httpx to call APIs now --- sammo/runners.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/sammo/runners.py b/sammo/runners.py index e117b82..0b84906 100644 --- a/sammo/runners.py +++ b/sammo/runners.py @@ -9,8 +9,7 @@ import logging import os import pathlib - -import aiohttp +import httpx import orjson from beartype import beartype from beartype.typing import Literal @@ -279,9 +278,23 @@ async def _call_backend(self, request: dict) -> dict: return (await self._client.embeddings.create(**request, model=self._model_id)).model_dump() -class DeepInfraEmbedding(OpenAIEmbedding): +class DeepInfraEmbedding(BaseRunner): BASE_URL = r"https://api.deepinfra.com/v1/inference/" + def _post_init(self): + self._client = httpx.AsyncClient() + + def __del__(self): + # somewhat hacky way to close the client + try: + loop = asyncio.get_event_loop() + except RuntimeError: + return + if loop.is_running(): + loop.create_task(self._client.aclose()) + else: + loop.run_until_complete(self._client.aclose()) + async def generate_embedding(self, text: str | list[str], priority: int = 0) -> LLMResult: if isinstance(text, list) and len(text) > 2048: raise ValueError("Batch size must be below 2048.") @@ -292,17 +305,12 @@ async def generate_embedding(self, text: str | list[str], priority: int = 0) -> return await self._execute_request(request, fingerprint, priority) async def _call_backend(self, request: dict) -> dict: - # let parent class handle timeouts - timeout = aiohttp.ClientTimeout(total=None) - async with aiohttp.ClientSession(timeout=timeout) as session: - # todo: share session across multiple requests - async with session.post( - self.BASE_URL + self._model_id, - headers=dict(Authorization=f"Bearer {self._api_config['api_key']}"), - raise_for_status=True, - data=orjson.dumps(request), - ) as response: - return await response.json() + response = await self._client.post( + self.BASE_URL + self._model_id, + headers=dict(Authorization=f"Bearer {self._api_config['api_key']}"), + data=orjson.dumps(request), + ) + return response.raise_for_status().json() def _to_llm_result(self, request: dict, json_data: dict, fingerprint: str | bytes): return LLMResult(json_data["embeddings"], costs=Costs(json_data["input_tokens"]))