Skip to content

Commit

Permalink
- using httpx to call APIs now
Browse files Browse the repository at this point in the history
  • Loading branch information
sammym1982 committed Feb 27, 2024
1 parent 90486d7 commit b86bd3d
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions sammo/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import logging
import os
import pathlib

import aiohttp
import httpx
import orjson
from beartype import beartype
from beartype.typing import Literal
Expand Down Expand Up @@ -279,9 +278,23 @@ async def _call_backend(self, request: dict) -> dict:
return (await self._client.embeddings.create(**request, model=self._model_id)).model_dump()


class DeepInfraEmbedding(OpenAIEmbedding):
class DeepInfraEmbedding(BaseRunner):
BASE_URL = r"https://api.deepinfra.com/v1/inference/"

def _post_init(self):
self._client = httpx.AsyncClient()

def __del__(self):
# somewhat hacky way to close the client
try:
loop = asyncio.get_event_loop()
except RuntimeError:
return
if loop.is_running():
loop.create_task(self._client.aclose())
else:
loop.run_until_complete(self._client.aclose())

async def generate_embedding(self, text: str | list[str], priority: int = 0) -> LLMResult:
if isinstance(text, list) and len(text) > 2048:
raise ValueError("Batch size must be below 2048.")
Expand All @@ -292,17 +305,12 @@ async def generate_embedding(self, text: str | list[str], priority: int = 0) ->
return await self._execute_request(request, fingerprint, priority)

async def _call_backend(self, request: dict) -> dict:
# let parent class handle timeouts
timeout = aiohttp.ClientTimeout(total=None)
async with aiohttp.ClientSession(timeout=timeout) as session:
# todo: share session across multiple requests
async with session.post(
self.BASE_URL + self._model_id,
headers=dict(Authorization=f"Bearer {self._api_config['api_key']}"),
raise_for_status=True,
data=orjson.dumps(request),
) as response:
return await response.json()
response = await self._client.post(
self.BASE_URL + self._model_id,
headers=dict(Authorization=f"Bearer {self._api_config['api_key']}"),
data=orjson.dumps(request),
)
return response.raise_for_status().json()

def _to_llm_result(self, request: dict, json_data: dict, fingerprint: str | bytes):
return LLMResult(json_data["embeddings"], costs=Costs(json_data["input_tokens"]))

0 comments on commit b86bd3d

Please sign in to comment.