Skip to content

Commit

Permalink
Replace always_defer with defer_option.
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Jan 31, 2025
1 parent 56f9ec2 commit 5cdd089
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
26 changes: 22 additions & 4 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from inspect import isasyncgenfunction
from typing import Annotated, Any, get_origin

Expand Down Expand Up @@ -38,6 +39,21 @@ def get_default_values(func):
}


class DeferOption(str, Enum):
"""Enum for defer option.
Attributes:
ALWAYS (str): Always defer. Endpoints with this option will always be defer execution to the task queue.
NEVER (str): Never defer. Endpoints with this option will never be defer execution to the task queue.
OPTIONAL (str): Optionally defer. Endpoints with this option can be defer execution to the task queue if
the defer query parameter is set to True.
"""

ALWAYS = "always"
NEVER = "never"
OPTIONAL = "optional"


@dataclass
class Endpoint:
"""Class used to represent an endpoint.
Expand All @@ -47,15 +63,15 @@ class Endpoint:
path (str): Path of the endpoint (e.g. "/video/transcribe").
summary (str): Description of the endpoint that will be shown in the API documentation.
admin_required (bool): Flag indicating if the endpoint requires admin access.
always_defer (bool): Flag indicating if the endpoint should always defer execution to the task queue.
defer_option (DeferOption): Defer option for the endpoint (always, never, optional).
event_handlers (list[EventHandler] | None): The list of event handlers to register for the endpoint.
"""

name: str
path: str
summary: str
admin_required: bool = False
always_defer: bool = False
defer_option: DeferOption = DeferOption.OPTIONAL
initialized: bool = False
event_handlers: list[EventHandler] | None = None

Expand Down Expand Up @@ -329,14 +345,16 @@ async def route_func(
description="Defer execution of the endpoint to the task queue.",
default=False,
include_in_schema=aana_settings.task_queue.enabled
and not self.always_defer,
and self.defer_option == DeferOption.OPTIONAL,
),
):
if aana_settings.api_service.enabled and self.admin_required:
check_admin_permissions(request)

if self.always_defer:
if self.defer_option == DeferOption.ALWAYS:
defer = True
elif self.defer_option == DeferOption.NEVER:
defer = False

form_data = await request.form()

Expand Down
8 changes: 4 additions & 4 deletions aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ray.serve.schema import ApplicationStatusOverview
from rich import print as rprint

from aana.api.api_generation import Endpoint
from aana.api.api_generation import DeferOption, Endpoint
from aana.api.event_handlers.event_handler import EventHandler
from aana.api.request_handler import RequestHandler
from aana.configs.settings import settings as aana_settings
Expand Down Expand Up @@ -267,7 +267,7 @@ def register_endpoint(
summary: str,
endpoint_cls: type[Endpoint],
admin_required: bool = False,
always_defer: bool = False,
defer_option: DeferOption = DeferOption.OPTIONAL,
event_handlers: list[EventHandler] | None = None,
):
"""Register an endpoint.
Expand All @@ -278,15 +278,15 @@ def register_endpoint(
summary (str): The summary of the endpoint.
endpoint_cls (Type[Endpoint]): The class of the endpoint.
admin_required (bool, optional): If True, the endpoint requires admin access. Defaults to False.
always_defer (bool, optional): If True, the endpoint will always defer execution to the task queue. Defaults to False.
defer_option (DeferOption): Defer option for the endpoint (always, never, optional).
event_handlers (list[EventHandler], optional): The event handlers to register for the endpoint.
"""
endpoint = endpoint_cls(
name=name,
path=path,
summary=summary,
admin_required=admin_required,
always_defer=always_defer,
defer_option=defer_option,
event_handlers=event_handlers,
)
self.endpoints[name] = endpoint
Expand Down

0 comments on commit 5cdd089

Please sign in to comment.