Add OpenAI chat completions and task management endpoints
Aleksandr Movchan committed Feb 4, 2025
1 parent abdf9c3 commit 6dacd94
Showing 5 changed files with 286 additions and 119 deletions.
148 changes: 29 additions & 119 deletions aana/api/request_handler.py
@@ -1,13 +1,9 @@
import json
import time
from typing import Any
from uuid import UUID, uuid4
from uuid import UUID

import orjson
import ray
from fastapi import APIRouter
from fastapi.openapi.utils import get_openapi
from fastapi.responses import StreamingResponse
from ray import serve

from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema
@@ -16,17 +12,14 @@
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import AdminAccessDependency
from aana.api.task import router as task_router
from aana.api.webhook import (
from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
from aana.routers.openai import router as openai_router
from aana.routers.task import router as task_router
from aana.routers.webhook import (
WebhookEventType,
trigger_task_webhooks,
)
from aana.api.webhook import router as webhook_router
from aana.configs.settings import settings as aana_settings
from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
from aana.core.models.sampling import SamplingParams
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.routers.webhook import router as webhook_router
from aana.storage.models.task import Status as TaskStatus
from aana.storage.repository.task import TaskRepository
from aana.storage.session import get_session
@@ -59,9 +52,11 @@ def __init__(
        self.deployments = deployments

        # Include the default routers
        app.include_router(webhook_router)
        app.include_router(task_router)
        # Include the custom routers
        app.include_router(webhook_router)  # For webhook management
        app.include_router(task_router)  # For task management
        app.include_router(openai_router)  # For OpenAI-compatible API

        # Include the custom routers (from Aana Apps)
        if routers is not None:
            for router in routers:
                app.include_router(router)
@@ -91,27 +86,6 @@ def custom_openapi(self) -> dict[str, Any]:
        app.openapi_schema = openapi_schema
        return app.openapi_schema

@app.get("/api/ready")
async def is_ready(self):
"""The endpoint for checking if the application is ready.
Real reason for this endpoint is to make automatic endpoint generation work.
If RequestHandler doesn't have any endpoints defined manually,
then the automatic endpoint generation doesn't work.
#TODO: Find a better solution for this.
Returns:
AanaJSONResponse: The response containing the ready status.
"""
return AanaJSONResponse(content={"ready": self.ready})

async def check_health(self):
"""Check the health of the application."""
# Heartbeat for the running tasks
with get_session() as session:
task_repo = TaskRepository(session)
task_repo.heartbeat(self.running_tasks)

async def execute_task(self, task_id: str | UUID) -> Any:
"""Execute a task.
@@ -165,92 +139,28 @@ async def execute_task(self, task_id: str | UUID) -> Any:
        finally:
            self.running_tasks.remove(task_id)

    @app.post(
        "/chat/completions",
        response_model=ChatCompletion,
        include_in_schema=aana_settings.openai_endpoint_enabled,
    )
    async def chat_completions(self, request: ChatCompletionRequest):
        """Handle chat completions requests for OpenAI compatible API."""
        if not aana_settings.openai_endpoint_enabled:
            return AanaJSONResponse(
                content={
                    "error": {
                        "message": "The OpenAI-compatible endpoint is not enabled."
                    }
                },
                status_code=404,
            )

        async def _async_chat_completions(
            handle: AanaDeploymentHandle,
            dialog: ChatDialog,
            sampling_params: SamplingParams,
        ):
            async for response in handle.chat_stream(
                dialog=dialog, sampling_params=sampling_params
            ):
                chunk = {
                    "id": f"chatcmpl-{uuid4().hex}",
                    "object": "chat.completion.chunk",
                    "model": request.model,
                    "created": int(time.time()),
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": response["text"], "role": "assistant"},
                        }
                    ],
                }
                yield f"data: {json.dumps(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        # Check if the deployment exists
        try:
            handle = await AanaDeploymentHandle.create(request.model)
        except ray.serve.exceptions.RayServeException:
            return AanaJSONResponse(
                content={
                    "error": {"message": f"The model `{request.model}` does not exist."}
                },
                status_code=404,
            )

        # Check if the deployment is a chat model
        if not hasattr(handle, "chat") or not hasattr(handle, "chat_stream"):
            return AanaJSONResponse(
                content={
                    "error": {"message": f"The model `{request.model}` does not exist."}
                },
                status_code=404,
            )

        dialog = ChatDialog(
            messages=request.messages,
        )

        sampling_params = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
        )

        if request.stream:
            return StreamingResponse(
                _async_chat_completions(handle, dialog, sampling_params),
                media_type="application/x-ndjson",
            )
        else:
            response = await handle.chat(dialog=dialog, sampling_params=sampling_params)
            return {
                "id": f"chatcmpl-{uuid4().hex}",
                "object": "chat.completion",
                "model": request.model,
                "created": int(time.time()),
                "choices": [{"index": 0, "message": response["message"]}],
            }

    @app.get("/api/ready", tags=["system"])
    async def is_ready(self):
        """The endpoint for checking if the application is ready.

        Real reason for this endpoint is to make automatic endpoint generation work.
        If RequestHandler doesn't have any endpoints defined manually,
        then the automatic endpoint generation doesn't work.
        #TODO: Find a better solution for this.

        Returns:
            AanaJSONResponse: The response containing the ready status.
        """
        return AanaJSONResponse(content={"ready": self.ready})

    async def check_health(self):
        """Check the health of the application."""
        # Heartbeat for the running tasks
        with get_session() as session:
            task_repo = TaskRepository(session)
            task_repo.heartbeat(self.running_tasks)

@app.get("/api/status", response_model=SDKStatusResponse)
@app.get("/api/status", response_model=SDKStatusResponse, tags=["system"])
async def status(self, is_admin: AdminAccessDependency) -> SDKStatusResponse:
"""The endpoint for checking the status of the application."""
app_names = [
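For context, the routers parameter consumed by RequestHandler.__init__ above is a plain list of FastAPI routers, so an Aana app can plug in its own endpoints the same way the default webhook, task, and OpenAI routers are registered. A minimal sketch, where the route path, tag, and request model are hypothetical and not part of this commit:

from fastapi import APIRouter
from pydantic import BaseModel

# Hypothetical custom router that an Aana app could pass via `routers`.
my_router = APIRouter(tags=["my-app"])


class EchoRequest(BaseModel):
    text: str


@my_router.post("/echo")
async def echo(request: EchoRequest) -> dict:
    """Toy endpoint illustrating the plug-in point."""
    return {"echo": request.text}

# RequestHandler.__init__ then registers it alongside the defaults:
#     app.include_router(my_router)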
Empty file added aana/routers/__init__.py
100 changes: 100 additions & 0 deletions aana/routers/openai.py
@@ -0,0 +1,100 @@
import json
import time
from uuid import uuid4

import ray
from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from aana.api.responses import AanaJSONResponse
from aana.configs.settings import settings as aana_settings
from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
from aana.core.models.sampling import SamplingParams
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle

router = APIRouter(
    tags=["openai-api"], include_in_schema=aana_settings.openai_endpoint_enabled
)


@router.post(
    "/chat/completions",
    response_model=ChatCompletion,
)
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions requests for OpenAI compatible API."""
    if not aana_settings.openai_endpoint_enabled:
        return AanaJSONResponse(
            content={
                "error": {"message": "The OpenAI-compatible endpoint is not enabled."}
            },
            status_code=404,
        )

    async def _async_chat_completions(
        handle: AanaDeploymentHandle,
        dialog: ChatDialog,
        sampling_params: SamplingParams,
    ):
        async for response in handle.chat_stream(
            dialog=dialog, sampling_params=sampling_params
        ):
            chunk = {
                "id": f"chatcmpl-{uuid4().hex}",
                "object": "chat.completion.chunk",
                "model": request.model,
                "created": int(time.time()),
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": response["text"], "role": "assistant"},
                    }
                ],
            }
            yield f"data: {json.dumps(chunk)}\n\n"
        yield "data: [DONE]\n\n"

    # Check if the deployment exists
    try:
        handle = await AanaDeploymentHandle.create(request.model)
    except ray.serve.exceptions.RayServeException:
        return AanaJSONResponse(
            content={
                "error": {"message": f"The model `{request.model}` does not exist."}
            },
            status_code=404,
        )

    # Check if the deployment is a chat model
    if not hasattr(handle, "chat") or not hasattr(handle, "chat_stream"):
        return AanaJSONResponse(
            content={
                "error": {"message": f"The model `{request.model}` does not exist."}
            },
            status_code=404,
        )

    dialog = ChatDialog(
        messages=request.messages,
    )

    sampling_params = SamplingParams(
        temperature=request.temperature,
        max_tokens=request.max_tokens,
        top_p=request.top_p,
    )

    if request.stream:
        return StreamingResponse(
            _async_chat_completions(handle, dialog, sampling_params),
            media_type="application/x-ndjson",
        )
    else:
        response = await handle.chat(dialog=dialog, sampling_params=sampling_params)
        return {
            "id": f"chatcmpl-{uuid4().hex}",
            "object": "chat.completion",
            "model": request.model,
            "created": int(time.time()),
            "choices": [{"index": 0, "message": response["message"]}],
        }
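For reference, a minimal sketch of calling the new endpoint from a client. The server address and the deployment name (llm_deployment) are assumptions; use whatever deployment name your Aana app registers. Note the route lives at /chat/completions (no /v1 prefix), and streaming responses arrive as Server-Sent-Events-style "data:" lines terminated by "data: [DONE]", even though the response is declared as application/x-ndjson.

import json

import httpx

BASE_URL = "http://localhost:8000"  # assumed address of the running Aana app

# Non-streaming request
resp = httpx.post(
    f"{BASE_URL}/chat/completions",
    json={
        "model": "llm_deployment",  # hypothetical deployment name
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=None,
)
print(resp.json()["choices"][0]["message"])

# Streaming request: parse each `data: {...}` chunk, stop at `data: [DONE]`
with httpx.stream(
    "POST",
    f"{BASE_URL}/chat/completions",
    json={
        "model": "llm_deployment",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    timeout=None,
) as r:
    for line in r.iter_lines():
        if line.startswith("data: ") and line != "data: [DONE]":
            chunk = json.loads(line.removeprefix("data: "))
            print(chunk["choices"][0]["delta"]["content"], end="", flush=True)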