diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py
index baeb2c66..f1dc6307 100644
--- a/aana/api/request_handler.py
+++ b/aana/api/request_handler.py
@@ -1,13 +1,9 @@
-import json
-import time
 from typing import Any
-from uuid import UUID, uuid4
+from uuid import UUID
 
 import orjson
-import ray
 from fastapi import APIRouter
 from fastapi.openapi.utils import get_openapi
-from fastapi.responses import StreamingResponse
 from ray import serve
 
 from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema
@@ -16,17 +12,14 @@
 from aana.api.exception_handler import custom_exception_handler
 from aana.api.responses import AanaJSONResponse
 from aana.api.security import AdminAccessDependency
-from aana.api.task import router as task_router
-from aana.api.webhook import (
+from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
+from aana.routers.openai import router as openai_router
+from aana.routers.task import router as task_router
+from aana.routers.webhook import (
     WebhookEventType,
     trigger_task_webhooks,
 )
-from aana.api.webhook import router as webhook_router
-from aana.configs.settings import settings as aana_settings
-from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
-from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
-from aana.core.models.sampling import SamplingParams
-from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
+from aana.routers.webhook import router as webhook_router
 from aana.storage.models.task import Status as TaskStatus
 from aana.storage.repository.task import TaskRepository
 from aana.storage.session import get_session
@@ -59,9 +52,11 @@ def __init__(
         self.deployments = deployments
 
         # Include the default routers
-        app.include_router(webhook_router)
-        app.include_router(task_router)
-        # Include the custom routers
+        app.include_router(webhook_router)  # For webhook management
+        app.include_router(task_router)  # For task management
+        app.include_router(openai_router)  # For OpenAI-compatible API
+
+        # Include the custom routers (from Aana Apps)
         if routers is not None:
             for router in routers:
                 app.include_router(router)
@@ -91,27 +86,6 @@ def custom_openapi(self) -> dict[str, Any]:
         app.openapi_schema = openapi_schema
         return app.openapi_schema
 
-    @app.get("/api/ready")
-    async def is_ready(self):
-        """The endpoint for checking if the application is ready.
-
-        Real reason for this endpoint is to make automatic endpoint generation work.
-        If RequestHandler doesn't have any endpoints defined manually,
-        then the automatic endpoint generation doesn't work.
-        #TODO: Find a better solution for this.
-
-        Returns:
-            AanaJSONResponse: The response containing the ready status.
-        """
-        return AanaJSONResponse(content={"ready": self.ready})
-
-    async def check_health(self):
-        """Check the health of the application."""
-        # Heartbeat for the running tasks
-        with get_session() as session:
-            task_repo = TaskRepository(session)
-            task_repo.heartbeat(self.running_tasks)
-
     async def execute_task(self, task_id: str | UUID) -> Any:
         """Execute a task.
@@ -165,92 +139,28 @@ async def execute_task(self, task_id: str | UUID) -> Any:
         finally:
             self.running_tasks.remove(task_id)
 
-    @app.post(
-        "/chat/completions",
-        response_model=ChatCompletion,
-        include_in_schema=aana_settings.openai_endpoint_enabled,
-    )
-    async def chat_completions(self, request: ChatCompletionRequest):
-        """Handle chat completions requests for OpenAI compatible API."""
-        if not aana_settings.openai_endpoint_enabled:
-            return AanaJSONResponse(
-                content={
-                    "error": {
-                        "message": "The OpenAI-compatible endpoint is not enabled."
-                    }
-                },
-                status_code=404,
-            )
-
-        async def _async_chat_completions(
-            handle: AanaDeploymentHandle,
-            dialog: ChatDialog,
-            sampling_params: SamplingParams,
-        ):
-            async for response in handle.chat_stream(
-                dialog=dialog, sampling_params=sampling_params
-            ):
-                chunk = {
-                    "id": f"chatcmpl-{uuid4().hex}",
-                    "object": "chat.completion.chunk",
-                    "model": request.model,
-                    "created": int(time.time()),
-                    "choices": [
-                        {
-                            "index": 0,
-                            "delta": {"content": response["text"], "role": "assistant"},
-                        }
-                    ],
-                }
-                yield f"data: {json.dumps(chunk)}\n\n"
-            yield "data: [DONE]\n\n"
-
-        # Check if the deployment exists
-        try:
-            handle = await AanaDeploymentHandle.create(request.model)
-        except ray.serve.exceptions.RayServeException:
-            return AanaJSONResponse(
-                content={
-                    "error": {"message": f"The model `{request.model}` does not exist."}
-                },
-                status_code=404,
-            )
-
-        # Check if the deployment is a chat model
-        if not hasattr(handle, "chat") or not hasattr(handle, "chat_stream"):
-            return AanaJSONResponse(
-                content={
-                    "error": {"message": f"The model `{request.model}` does not exist."}
-                },
-                status_code=404,
-            )
+    @app.get("/api/ready", tags=["system"])
+    async def is_ready(self):
+        """The endpoint for checking if the application is ready.
-        dialog = ChatDialog(
-            messages=request.messages,
-        )
+
+        Real reason for this endpoint is to make automatic endpoint generation work.
+        If RequestHandler doesn't have any endpoints defined manually,
+        then the automatic endpoint generation doesn't work.
+        #TODO: Find a better solution for this.
-        sampling_params = SamplingParams(
-            temperature=request.temperature,
-            max_tokens=request.max_tokens,
-            top_p=request.top_p,
-        )
+
+        Returns:
+            AanaJSONResponse: The response containing the ready status.
+ """ + return AanaJSONResponse(content={"ready": self.ready}) - if request.stream: - return StreamingResponse( - _async_chat_completions(handle, dialog, sampling_params), - media_type="application/x-ndjson", - ) - else: - response = await handle.chat(dialog=dialog, sampling_params=sampling_params) - return { - "id": f"chatcmpl-{uuid4().hex}", - "object": "chat.completion", - "model": request.model, - "created": int(time.time()), - "choices": [{"index": 0, "message": response["message"]}], - } + async def check_health(self): + """Check the health of the application.""" + # Heartbeat for the running tasks + with get_session() as session: + task_repo = TaskRepository(session) + task_repo.heartbeat(self.running_tasks) - @app.get("/api/status", response_model=SDKStatusResponse) + @app.get("/api/status", response_model=SDKStatusResponse, tags=["system"]) async def status(self, is_admin: AdminAccessDependency) -> SDKStatusResponse: """The endpoint for checking the status of the application.""" app_names = [ diff --git a/aana/routers/__init__.py b/aana/routers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/routers/openai.py b/aana/routers/openai.py new file mode 100644 index 00000000..94f97c7a --- /dev/null +++ b/aana/routers/openai.py @@ -0,0 +1,100 @@ +import json +import time +from uuid import uuid4 + +import ray +from fastapi import APIRouter +from fastapi.responses import StreamingResponse + +from aana.api.responses import AanaJSONResponse +from aana.configs.settings import settings as aana_settings +from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog +from aana.core.models.sampling import SamplingParams +from aana.deployments.aana_deployment_handle import AanaDeploymentHandle + +router = APIRouter( + tags=["openai-api"], include_in_schema=aana_settings.openai_endpoint_enabled +) + + +@router.post( + "/chat/completions", + response_model=ChatCompletion, +) +async def chat_completions(request: ChatCompletionRequest): + """Handle chat completions requests for OpenAI compatible API.""" + if not aana_settings.openai_endpoint_enabled: + return AanaJSONResponse( + content={ + "error": {"message": "The OpenAI-compatible endpoint is not enabled."} + }, + status_code=404, + ) + + async def _async_chat_completions( + handle: AanaDeploymentHandle, + dialog: ChatDialog, + sampling_params: SamplingParams, + ): + async for response in handle.chat_stream( + dialog=dialog, sampling_params=sampling_params + ): + chunk = { + "id": f"chatcmpl-{uuid4().hex}", + "object": "chat.completion.chunk", + "model": request.model, + "created": int(time.time()), + "choices": [ + { + "index": 0, + "delta": {"content": response["text"], "role": "assistant"}, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + + # Check if the deployment exists + try: + handle = await AanaDeploymentHandle.create(request.model) + except ray.serve.exceptions.RayServeException: + return AanaJSONResponse( + content={ + "error": {"message": f"The model `{request.model}` does not exist."} + }, + status_code=404, + ) + + # Check if the deployment is a chat model + if not hasattr(handle, "chat") or not hasattr(handle, "chat_stream"): + return AanaJSONResponse( + content={ + "error": {"message": f"The model `{request.model}` does not exist."} + }, + status_code=404, + ) + + dialog = ChatDialog( + messages=request.messages, + ) + + sampling_params = SamplingParams( + temperature=request.temperature, + max_tokens=request.max_tokens, + top_p=request.top_p, + ) + 
+    if request.stream:
+        return StreamingResponse(
+            _async_chat_completions(handle, dialog, sampling_params),
+            media_type="application/x-ndjson",
+        )
+    else:
+        response = await handle.chat(dialog=dialog, sampling_params=sampling_params)
+        return {
+            "id": f"chatcmpl-{uuid4().hex}",
+            "object": "chat.completion",
+            "model": request.model,
+            "created": int(time.time()),
+            "choices": [{"index": 0, "message": response["message"]}],
+        }
diff --git a/aana/routers/task.py b/aana/routers/task.py
new file mode 100644
index 00000000..8b7ccf0d
--- /dev/null
+++ b/aana/routers/task.py
@@ -0,0 +1,157 @@
+import logging
+from typing import Annotated
+
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel, Field
+
+from aana.api.security import UserIdDependency
+from aana.configs.settings import settings as aana_settings
+from aana.core.models.task import TaskInfo
+from aana.storage.models.task import Status as TaskStatus
+from aana.storage.repository.task import TaskRepository
+from aana.storage.session import GetDbDependency
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(tags=["tasks"], include_in_schema=aana_settings.task_queue.enabled)
+
+# Request models
+
+
+class TaskListRequest(BaseModel):
+    """Request to list tasks."""
+
+    status: TaskStatus | None = Field(
+        None, description="Filter tasks by status. If None, all tasks are returned."
+    )
+    limit: int = Field(100, description="The maximum number of tasks to return.")
+    offset: int = Field(
+        0, description="The number of tasks to skip before starting to return tasks."
+    )
+
+
+# Response models
+
+
+class TaskList(BaseModel):
+    """Response for a list of tasks."""
+
+    tasks: list[TaskInfo] = Field(..., description="The list of tasks.")
+
+
+# Endpoints
+
+
+@router.get(
+    "/tasks/{task_id}",
+    summary="Get Task Status",
+    description="Get the task status by task ID.",
+)
+async def get_task(
+    task_id: str, db: GetDbDependency, user_id: UserIdDependency
+) -> TaskInfo:
+    """Get the task with the given ID."""
+    task_repo = TaskRepository(db)
+    task = task_repo.read(task_id, check=False)
+    if not task or task.user_id != user_id:
+        raise HTTPException(
+            status_code=404,
+            detail="Task not found",
+        )
+    return TaskInfo.from_entity(task)
+
+
+@router.get(
+    "/tasks",
+    summary="List Tasks",
+    description="List all tasks.",
+)
+async def list_tasks(
+    db: GetDbDependency,
+    user_id: UserIdDependency,
+    status: Annotated[
+        TaskStatus | None,
+        Field(description="Filter tasks by status. If None, all tasks are returned."),
+    ] = None,
+    page: Annotated[int, Field(description="The page number.")] = 1,
+    per_page: Annotated[int, Field(description="The number of tasks per page.")] = 100,
+) -> TaskList:
+    """List all tasks."""
+    task_repo = TaskRepository(db)
+    tasks = task_repo.get_tasks(
+        user_id=user_id, status=status, limit=per_page, offset=(page - 1) * per_page
+    )
+    return TaskList(tasks=[TaskInfo.from_entity(task) for task in tasks])
+
+
+@router.delete(
+    "/tasks/{task_id}",
+    summary="Delete Task",
+    description="Delete the task by task ID.",
+)
+async def delete_task(
+    task_id: str, db: GetDbDependency, user_id: UserIdDependency
+) -> TaskInfo:
+    """Delete the task with the given ID."""
+    task_repo = TaskRepository(db)
+    task = task_repo.read(task_id, check=False)
+    if not task or task.user_id != user_id:
+        raise HTTPException(
+            status_code=404,
+            detail="Task not found",
+        )
+    task = task_repo.delete(task_id)
+    if task is None:
+        raise HTTPException(
+            status_code=404,
+            detail="Task not found",
+        )
+    return TaskInfo.from_entity(task)
+
+
+# Legacy endpoints (to be removed in the future)
+
+
+@router.get(
+    "/tasks/get/{task_id}",
+    summary="Get Task Status (Legacy)",
+    description="Get the task status by task ID (Legacy endpoint).",
+    deprecated=True,
+)
+async def get_task_legacy(
+    task_id: str, db: GetDbDependency, user_id: UserIdDependency
+) -> TaskInfo:
+    """Get the task with the given ID (Legacy endpoint)."""
+    task_repo = TaskRepository(db)
+    task = task_repo.read(task_id)
+    if not task or task.user_id != user_id:
+        raise HTTPException(status_code=404, detail="Task not found")
+    return TaskInfo.from_entity(task)
+
+
+class TaskId(BaseModel):
+    """Task ID (Legacy).
+
+    Attributes:
+        task_id (str): The task ID.
+    """
+
+    task_id: str = Field(..., description="The task ID.")
+
+
+@router.get(
+    "/tasks/delete/{task_id}",
+    summary="Delete Task (Legacy)",
+    description="Delete the task by task ID (Legacy endpoint).",
+    deprecated=True,
+)
+async def delete_task_legacy(
+    task_id: str, db: GetDbDependency, user_id: UserIdDependency
+) -> TaskId:
+    """Delete the task with the given ID (Legacy endpoint)."""
+    task_repo = TaskRepository(db)
+    task = task_repo.read(task_id)
+    if not task or task.user_id != user_id:
+        raise HTTPException(status_code=404, detail="Task not found")
+    task = task_repo.delete(task_id)
+    return TaskId(task_id=str(task.id))
diff --git a/aana/api/webhook.py b/aana/routers/webhook.py
similarity index 100%
rename from aana/api/webhook.py
rename to aana/routers/webhook.py
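Note (not part of the patch): a minimal usage sketch against the routers introduced above. It assumes an Aana SDK app is already running at http://localhost:8000 with a chat-capable deployment registered under the name "llm_deployment"; both the base URL and the deployment name are illustrative placeholders. Depending on how authentication is configured, the task endpoints may also require an API key header.

    import httpx

    BASE_URL = "http://localhost:8000"  # assumed address of a running Aana app

    # Non-streaming request to the OpenAI-compatible endpoint from aana/routers/openai.py.
    # "llm_deployment" is a placeholder; the deployment name acts as the model name.
    chat = httpx.post(
        f"{BASE_URL}/chat/completions",
        json={
            "model": "llm_deployment",
            "messages": [{"role": "user", "content": "Hello!"}],
            "temperature": 0.7,
        },
        timeout=60,
    )
    print(chat.json())

    # List tasks through the relocated task router (aana/routers/task.py), first page, 10 per page.
    tasks = httpx.get(f"{BASE_URL}/tasks", params={"page": 1, "per_page": 10})
    print(tasks.json())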