Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance API Key Validation and Admin Access Control #224

Merged
merged 8 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from inspect import isasyncgenfunction
from typing import Annotated, Any, get_origin

Expand All @@ -18,6 +19,7 @@
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import check_admin_permissions
from aana.configs.settings import settings as aana_settings
from aana.core.models.api_service import ApiKey
from aana.core.models.exception import ExceptionResponseModel
Expand All @@ -37,6 +39,21 @@ def get_default_values(func):
}


class DeferOption(str, Enum):
"""Enum for defer option.

Attributes:
ALWAYS (str): Always defer. Endpoints with this option will always be defer execution to the task queue.
NEVER (str): Never defer. Endpoints with this option will never be defer execution to the task queue.
OPTIONAL (str): Optionally defer. Endpoints with this option can be defer execution to the task queue if
the defer query parameter is set to True.
"""

ALWAYS = "always"
NEVER = "never"
OPTIONAL = "optional"


@dataclass
class Endpoint:
"""Class used to represent an endpoint.
Expand All @@ -45,12 +62,16 @@ class Endpoint:
name (str): Name of the endpoint.
path (str): Path of the endpoint (e.g. "/video/transcribe").
summary (str): Description of the endpoint that will be shown in the API documentation.
admin_required (bool): Flag indicating if the endpoint requires admin access.
defer_option (DeferOption): Defer option for the endpoint (always, never, optional).
event_handlers (list[EventHandler] | None): The list of event handlers to register for the endpoint.
"""

name: str
path: str
summary: str
admin_required: bool = False
defer_option: DeferOption = DeferOption.OPTIONAL
initialized: bool = False
event_handlers: list[EventHandler] | None = None

Expand Down Expand Up @@ -323,9 +344,18 @@ async def route_func(
defer: bool = Query(
description="Defer execution of the endpoint to the task queue.",
default=False,
include_in_schema=aana_settings.task_queue.enabled,
include_in_schema=aana_settings.task_queue.enabled
and self.defer_option == DeferOption.OPTIONAL,
),
):
if aana_settings.api_service.enabled and self.admin_required:
check_admin_permissions(request)

if self.defer_option == DeferOption.ALWAYS:
defer = True
elif self.defer_option == DeferOption.NEVER:
defer = False

form_data = await request.form()

# Parse files from the form data
Expand Down
15 changes: 11 additions & 4 deletions aana/api/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from fastapi import FastAPI, Request
from pydantic import ValidationError

Expand All @@ -10,6 +9,7 @@
from aana.exceptions.api_service import (
ApiKeyNotFound,
ApiKeyNotProvided,
ApiKeyValidationFailed,
InactiveSubscription,
)
from aana.storage.models.api_key import ApiKeyEntity
Expand All @@ -24,16 +24,23 @@
@app.middleware("http")
async def api_key_check(request: Request, call_next):
"""Middleware to check the API key and subscription status."""
excluded_paths = ["/openapi.json", "/docs", "/redoc"]
if request.url.path in excluded_paths:
return await call_next(request)

if aana_settings.api_service.enabled:
api_key = request.headers.get("x-api-key")

if not api_key:
raise ApiKeyNotProvided()

with get_session() as session:
api_key_info = (
session.query(ApiKeyEntity).filter_by(api_key=api_key).first()
)
try:
api_key_info = (
session.query(ApiKeyEntity).filter_by(api_key=api_key).first()
)
except Exception as e:
raise ApiKeyValidationFailed() from e

if not api_key_info:
raise ApiKeyNotFound(key=api_key)
Expand Down
18 changes: 16 additions & 2 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import AdminRequired
from aana.configs.settings import settings as aana_settings
from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
Expand Down Expand Up @@ -201,9 +202,22 @@ async def delete_task_endpoint(
task = task_repo.delete(task_id)
return TaskId(task_id=str(task.id))

@app.post("/chat/completions", response_model=ChatCompletion)
@app.post(
"/chat/completions",
response_model=ChatCompletion,
include_in_schema=aana_settings.openai_endpoint_enabled,
)
async def chat_completions(self, request: ChatCompletionRequest):
"""Handle chat completions requests for OpenAI compatible API."""
if not aana_settings.openai_endpoint_enabled:
return AanaJSONResponse(
content={
"error": {
"message": "The OpenAI-compatible endpoint is not enabled."
}
},
status_code=404,
)

async def _async_chat_completions(
handle: AanaDeploymentHandle,
Expand Down Expand Up @@ -274,7 +288,7 @@ async def _async_chat_completions(
}

@app.get("/api/status", response_model=SDKStatusResponse)
async def status(self) -> SDKStatusResponse:
async def status(self, is_admin: AdminRequired) -> SDKStatusResponse:
"""The endpoint for checking the status of the application."""
app_names = [
self.app_name,
Expand Down
35 changes: 35 additions & 0 deletions aana/api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Annotated

from fastapi import Depends, Request

from aana.configs.settings import settings as aana_settings
from aana.exceptions.api_service import AdminOnlyAccess


def check_admin_permissions(request: Request):
"""Check if the user is an admin.

Args:
request (Request): The request object

Raises:
AdminOnlyAccess: If the user is not an admin
"""
if aana_settings.api_service.enabled:
api_key_info = request.state.api_key_info
is_admin = api_key_info.get("is_admin", False)
if not is_admin:
raise AdminOnlyAccess()


class AdminCheck:
"""Dependency to check if the user is an admin."""

async def __call__(self, request: Request) -> bool:
"""Check if the user is an admin."""
check_admin_permissions(request)
return True


AdminRequired = Annotated[bool, Depends(AdminCheck())]
""" Annotation to check if the user is an admin. If not, it will raise an exception. """
3 changes: 3 additions & 0 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Settings(BaseSettings):
audio_dir (Path): The temporary audio directory.
model_dir (Path): The temporary model directory.
num_workers (int): The number of web workers.
openai_endpoint_enabled (bool): Flag indicating if the OpenAI-compatible endpoint is enabled. Enabled by default.
task_queue (TaskQueueSettings): The task queue settings.
db_config (DbSettings): The database configuration.
test (TestSettings): The test settings.
Expand All @@ -79,6 +80,8 @@ class Settings(BaseSettings):

num_workers: int = 2

openai_endpoint_enabled: bool = True

task_queue: TaskQueueSettings = TaskQueueSettings()

db_config: DbSettings = DbSettings()
Expand Down
26 changes: 26 additions & 0 deletions aana/exceptions/api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,29 @@ def __init__(self, key: str):
def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.key,))


class AdminOnlyAccess(BaseException):
"""Exception raised when the user does not have enough permissions."""

def __init__(self):
"""Initialize the exception."""
self.message = "Admin only access"
super().__init__(message=self.message)

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, ())


class ApiKeyValidationFailed(BaseException):
"""Exception raised when the API key validation fails."""

def __init__(self):
"""Initialize the exception."""
self.message = "API key validation failed"
super().__init__(message=self.message)

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, ())
8 changes: 7 additions & 1 deletion aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ray.serve.schema import ApplicationStatusOverview
from rich import print as rprint

from aana.api.api_generation import Endpoint
from aana.api.api_generation import DeferOption, Endpoint
from aana.api.event_handlers.event_handler import EventHandler
from aana.api.request_handler import RequestHandler
from aana.configs.settings import settings as aana_settings
Expand Down Expand Up @@ -266,6 +266,8 @@ def register_endpoint(
path: str,
summary: str,
endpoint_cls: type[Endpoint],
admin_required: bool = False,
defer_option: DeferOption = DeferOption.OPTIONAL,
event_handlers: list[EventHandler] | None = None,
):
"""Register an endpoint.
Expand All @@ -275,12 +277,16 @@ def register_endpoint(
path (str): The path of the endpoint.
summary (str): The summary of the endpoint.
endpoint_cls (Type[Endpoint]): The class of the endpoint.
admin_required (bool, optional): If True, the endpoint requires admin access. Defaults to False.
defer_option (DeferOption): Defer option for the endpoint (always, never, optional).
event_handlers (list[EventHandler], optional): The event handlers to register for the endpoint.
"""
endpoint = endpoint_cls(
name=name,
path=path,
summary=summary,
admin_required=admin_required,
defer_option=defer_option,
event_handlers=event_handlers,
)
self.endpoints[name] = endpoint
Expand Down
5 changes: 5 additions & 0 deletions aana/storage/models/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class ApiKeyEntity(ApiServiceBase):
user_id: Mapped[str] = mapped_column(
nullable=False, comment="ID of the user who owns this API key"
)
is_admin: Mapped[bool] = mapped_column(
nullable=False, default=False, comment="Whether the user is an admin"
)
subscription_id: Mapped[str] = mapped_column(
nullable=False, comment="ID of the associated subscription"
)
Expand All @@ -34,7 +37,9 @@ def __repr__(self) -> str:
"""String representation of the API key."""
return (
f"<APIKeyEntity(id={self.id}, "
f"api_key={self.api_key}, "
f"user_id={self.user_id}, "
f"is_admin={self.is_admin}, "
f"subscription_id={self.subscription_id}, "
f"is_subscription_active={self.is_subscription_active})>"
)
Expand Down
3 changes: 3 additions & 0 deletions docs/pages/openai_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to i

Chat Completions API is available at the `/chat/completions` endpoint.

!!! Tip
The endpoint is enabled by default but can be disabled by setting the environment variable: `OPENAI_ENDPOINT_ENABLED=False`.

It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for OpenAI API.

```python
Expand Down