Skip to content

Commit

Permalink
feat: httpx timeout error handling (#139)
Browse files Browse the repository at this point in the history
Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
leoguillaume and leoguillaumegouv authored Jan 9, 2025
1 parent ee37cf1 commit 588e385
Show file tree
Hide file tree
Showing 14 changed files with 360 additions and 120 deletions.
4 changes: 1 addition & 3 deletions app/clients/_modelclients.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ def create_embeddings(self, *args, **kwargs):


class ModelClient(OpenAI):
DEFAULT_TIMEOUT = 120

def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE], *args, **kwargs) -> None:
"""
ModelClient class extends AsyncOpenAI class to support custom methods.
Expand Down Expand Up @@ -162,7 +160,7 @@ def create(self, prompt: str, input: list[str], model: str) -> List[Rerank]:

return data

self.rerank = RerankClient(model=self.id, base_url=self.base_url, api_key=self.api_key, timeout=self.DEFAULT_TIMEOUT)
self.rerank = RerankClient(model=self.id, base_url=self.base_url, api_key=self.api_key, timeout=DEFAULT_TIMEOUT)


class ModelClients(dict):
Expand Down
37 changes: 16 additions & 21 deletions app/endpoints/audio.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import json
from typing import List, Literal

from fastapi import APIRouter, File, Form, HTTPException, Request, Security, UploadFile
from fastapi import APIRouter, File, Form, Request, Security, UploadFile
from fastapi.responses import PlainTextResponse
import httpx

from app.schemas.audio import AudioTranscription
from app.utils.exceptions import ModelNotFoundException
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.route import forward_request
from app.utils.security import User, check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT, AUDIO_MODEL_TYPE
from app.utils.variables import AUDIO_MODEL_TYPE, DEFAULT_TIMEOUT

router = APIRouter()

Expand Down Expand Up @@ -152,7 +151,7 @@ async def audio_transcriptions(
client = clients.models[model]

if client.type != AUDIO_MODEL_TYPE:
raise ModelNotFoundException()
raise WrongModelTypeException()

# @TODO: Implement prompt
# @TODO: Implement timestamp_granularities
Expand All @@ -163,20 +162,16 @@ async def audio_transcriptions(
url = f"{client.base_url}audio/transcriptions"
headers = {"Authorization": f"Bearer {client.api_key}"}

try:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
response = await async_client.post(
url=url,
headers=headers,
files={"file": (file.filename, file_content, file.content_type)},
data={"language": language, "response_format": response_format, "temperature": temperature},
)
response.raise_for_status()
if response_format == "text":
return PlainTextResponse(content=response.text)
response = await forward_request(
url=url,
method="POST",
headers=headers,
timeout=DEFAULT_TIMEOUT,
files={"file": (file.filename, file_content, file.content_type)},
data={"language": language, "response_format": response_format, "temperature": temperature},
)

data = response.json()
return AudioTranscription(**data)
if response_format == "text":
return PlainTextResponse(content=response.text)

except Exception as e:
raise HTTPException(status_code=e.response.status_code, detail=json.loads(s=e.response.text)["message"])
return AudioTranscription(**response.json())
69 changes: 27 additions & 42 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import json
from typing import List, Tuple, Union

from fastapi import APIRouter, HTTPException, Request, Security
from fastapi import APIRouter, Request, Security
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
import httpx

from app.helpers import ClientsManager, InternetManager, SearchManager
from app.helpers import ClientsManager, InternetManager, SearchManager, StreamingResponseWithStatusCode
from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest
from app.schemas.search import Search
from app.schemas.security import User
from app.schemas.settings import Settings
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.route import forward_request, forward_stream
from app.utils.security import check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE
Expand All @@ -32,10 +30,12 @@ async def chat_completions(
client = clients.models[body.model]
if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()

body.model = client.id # replace alias by model id
url = f"{client.base_url}chat/completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

    # retrieval-augmented generation (RAG)
def retrieval_augmentation_generation(
body: ChatCompletionRequest, clients: ClientsManager, settings: Settings
) -> Tuple[ChatCompletionRequest, List[Search]]:
Expand Down Expand Up @@ -75,42 +75,27 @@ def retrieval_augmentation_generation(

# not stream case
if not body["stream"]:
try:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=body)
response.raise_for_status()
data = response.json()
data["search_results"] = searches

return ChatCompletion(**data)
except Exception as e:
raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"])
response = await forward_request(
url=url,
method="POST",
headers=headers,
json=body,
timeout=DEFAULT_TIMEOUT,
additional_data_value=searches,
additional_data_key="search_results",
)
return ChatCompletion(**response.json())

# stream case
async def forward_stream(url: str, headers: dict, request: dict):
try:
error = None
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
async with async_client.stream(method="POST", url=url, headers=headers, json=request) as response:
if response.status_code >= 400:
error = await response.aread().decode()
response.raise_for_status()

i = 0
async for chunk in response.aiter_raw():
if i == 0:
chunks = chunk.decode(encoding="utf-8").split(sep="\n\n")
chunk = json.loads(chunks[0].lstrip("data: "))
chunk["search_results"] = searches
chunks[0] = f"data: {json.dumps(chunk)}"
chunk = "\n\n".join(chunks).encode(encoding="utf-8")
i = 1
yield chunk

# @TODO: raise the error instead of forwarding it (raise model )
except Exception as e:
error = error if error else {"error": {"type": e.__class__.__name__, "message": str(e), "code": 500}}
yield f"data: {json.dumps(error)}\n\n".encode(encoding="utf-8")
yield b"data: [DONE]\n\n"

return StreamingResponse(content=forward_stream(url=url, headers=headers, request=body), media_type="text/event-stream")
return StreamingResponseWithStatusCode(
content=forward_stream(
url=url,
method="POST",
headers=headers,
json=body,
timeout=DEFAULT_TIMEOUT,
additional_data_value=searches,
additional_data_key="search_results",
),
media_type="text/event-stream",
)
19 changes: 5 additions & 14 deletions app/endpoints/completions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json

from fastapi import APIRouter, HTTPException, Request, Security
import httpx
from fastapi import APIRouter, Request, Security

from app.schemas.completions import CompletionRequest, Completions
from app.schemas.security import User
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.route import forward_request
from app.utils.security import check_api_key, check_rate_limit
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE
Expand All @@ -24,17 +22,10 @@ async def completions(request: Request, body: CompletionRequest, user: User = Se
client = clients.models[body.model]
if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()

body.model = client.id # replace alias by model id
url = f"{client.base_url}completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

try:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
response.raise_for_status()

data = response.json()
return Completions(**data)

except Exception as e:
raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"])
response = await forward_request(url=url, method="POST", headers=headers, json=body.model_dump(), timeout=DEFAULT_TIMEOUT)
return Completions(**response.json())
20 changes: 7 additions & 13 deletions app/endpoints/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from fastapi import APIRouter, Request, Security, HTTPException
import httpx
import json
from fastapi import APIRouter, Request, Security

from app.schemas.embeddings import Embeddings, EmbeddingsRequest
from app.schemas.security import User
from app.utils.settings import settings
from app.utils.exceptions import WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.route import forward_request
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import EMBEDDINGS_MODEL_TYPE, DEFAULT_TIMEOUT
from app.utils.settings import settings
from app.utils.variables import DEFAULT_TIMEOUT, EMBEDDINGS_MODEL_TYPE

router = APIRouter()

Expand All @@ -24,15 +23,10 @@ async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Sec
client = clients.models[body.model]
if client.type != EMBEDDINGS_MODEL_TYPE:
raise WrongModelTypeException()

body.model = client.id # replace alias by model id
url = f"{client.base_url}embeddings"
headers = {"Authorization": f"Bearer {client.api_key}"}

try:
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
response.raise_for_status()
data = response.json()
return Embeddings(**data)
except Exception as e:
raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"])
response = await forward_request(url=url, method="POST", headers=headers, json=body.model_dump(), timeout=DEFAULT_TIMEOUT)
return Embeddings(**response.json())
11 changes: 10 additions & 1 deletion app/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,14 @@
from ._languagemodelreranker import LanguageModelReranker
from ._metricsmiddleware import MetricsMiddleware
from ._searchmanager import SearchManager
from ._streamingresponsewithstatuscode import StreamingResponseWithStatusCode

__all__ = ["ClientsManager", "FileUploader", "LanguageModelReranker", "InternetManager", "MetricsMiddleware", "SearchManager"]
__all__ = [
"ClientsManager",
"FileUploader",
"LanguageModelReranker",
"InternetManager",
"MetricsMiddleware",
"SearchManager",
"StreamingResponseWithStatusCode",
]
19 changes: 19 additions & 0 deletions app/helpers/_metricsmiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from starlette.middleware.base import BaseHTTPMiddleware

from app.clients import AuthenticationClient
from app.utils.logging import logger
from app.utils.logging import client_ip


class MetricsMiddleware(BaseHTTPMiddleware):
Expand All @@ -16,10 +18,27 @@ class MetricsMiddleware(BaseHTTPMiddleware):
labelnames=["user", "endpoint", "model"],
)

    async def __call__(self, scope, receive, send):
        """ASGI entrypoint wrapper around BaseHTTPMiddleware.

        Suppresses the ``RuntimeError("No response returned.")`` that Starlette's
        BaseHTTPMiddleware raises when the client disconnects before the endpoint
        produces a response; every other error propagates unchanged.
        """
        try:
            await super().__call__(scope, receive, send)
        except RuntimeError as exc:
            # ignore the error when the request is disconnected by the client
            if str(exc) == "No response returned.":
                # NOTE(review): assumes scope["route"] is present and exposes
                # .methods/.path — confirm this holds for every matched route.
                logger.info(
                    f'"{list(scope["route"].methods)[0]} {scope["route"].path} HTTP/{scope["http_version"]}" request disconnected by the client'
                )
                request = Request(scope, receive=receive)
                if await request.is_disconnected():
                    # Client is gone: swallow the error and end quietly.
                    return
            # Any other RuntimeError (or a still-connected client) is a real
            # failure: re-raise for the server to handle.
            raise

async def dispatch(self, request: Request, call_next) -> Response:
endpoint = request.url.path
content_type = request.headers.get("Content-Type", "")

client_addr = request.client.host
client_ip.set(client_addr)

if endpoint.startswith("/v1"):
authorization = request.headers.get("Authorization")
model = None
Expand Down
62 changes: 62 additions & 0 deletions app/helpers/_streamingresponsewithstatuscode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import json
from typing import AsyncIterator

from fastapi.responses import StreamingResponse
from starlette.types import Send


class StreamingResponseWithStatusCode(StreamingResponse):
    """
    Variation of StreamingResponse that can dynamically decide the HTTP status code,
    based on the return value of the content iterator (parameter `content`).
    Expects the content to yield either just str content as per the original `StreamingResponse`
    or else tuples of (`content`: `str`, `status_code`: `int`).
    """

    # Iterator installed by StreamingResponse.__init__; yields str/bytes chunks
    # or (content, status_code) tuples as described in the class docstring.
    body_iterator: AsyncIterator[str | bytes]
    # True once "http.response.start" has been sent — the ASGI spec forbids
    # sending it twice, so the error path below checks this flag first.
    response_started: bool = False

    async def stream_response(self, send: Send) -> None:
        """Stream the body over ASGI, taking the HTTP status from the first chunk."""
        more_body = True
        try:
            # The first yielded item fixes the status code for the whole response.
            first_chunk = await self.body_iterator.__anext__()
            if isinstance(first_chunk, tuple):
                first_chunk_content, self.status_code = first_chunk
            else:
                # Plain chunk with no explicit status: default to 200.
                first_chunk_content, self.status_code = first_chunk, 200

            if isinstance(first_chunk_content, str):
                first_chunk_content = first_chunk_content.encode(self.charset)

            # ASGI: exactly one "http.response.start" before any body messages.
            await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})

            self.response_started = True
            await send({"type": "http.response.body", "body": first_chunk_content, "more_body": more_body})

            async for chunk in self.body_iterator:
                if isinstance(chunk, tuple):
                    content, status_code = chunk
                    if status_code // 100 != 2:
                        # an error occurred mid-stream: the status line is already
                        # on the wire, so forward the error payload and terminate.
                        if not isinstance(content, bytes):
                            content = content.encode(self.charset)
                        more_body = False
                        await send({"type": "http.response.body", "body": content, "more_body": more_body})
                        return
                else:
                    content = chunk

                if isinstance(content, str):
                    content = content.encode(self.charset)
                more_body = True
                await send({"type": "http.response.body", "body": content, "more_body": more_body})

        except Exception:
            # NOTE(review): this also catches StopAsyncIteration from an empty
            # iterator, so an empty stream is reported as a 500 SSE error event —
            # confirm that is the intended behavior.
            more_body = False
            error_resp = {"error": {"message": "Internal Server Error"}}
            error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
            if not self.response_started:
                await send({"type": "http.response.start", "status": 500, "headers": self.raw_headers})
            await send({"type": "http.response.body", "body": error_event, "more_body": more_body})
        if more_body:
            # Normal completion: close the stream with a final empty body message.
            await send({"type": "http.response.body", "body": b"", "more_body": False})
8 changes: 8 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from app.helpers import MetricsMiddleware
from app.schemas.security import User
from app.utils.lifespan import lifespan
from app.utils.logging import logger
from app.utils.security import check_admin_api_key, check_api_key
from app.utils.settings import settings

Expand All @@ -20,6 +21,13 @@
redoc_url="/documentation",
)


@app.get("/")
async def root():
    """Handle GET / : log the access and return a static greeting payload."""
    logger.info("Accès à la route principale")
    payload = {"message": "Hello World"}
    return payload


# Prometheus metrics
# @TODO: env_var_name="ENABLE_METRICS"
app.instrumentator = Instrumentator().instrument(app=app)
Expand Down
Loading

0 comments on commit 588e385

Please sign in to comment.