Skip to content

Commit

Permalink
Invert the creation of API routes
Browse files Browse the repository at this point in the history
- Such that decode_access_token can be overriden
  when serving behind proxied OIDC
- Removes injection of password into security obj
- Removes use of dependency_override which is
  intended for use in tests
  • Loading branch information
DiamondJoseph committed Feb 5, 2025
1 parent adb7bcb commit fd664f0
Show file tree
Hide file tree
Showing 7 changed files with 2,068 additions and 2,146 deletions.
3 changes: 1 addition & 2 deletions tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,7 @@ def from_app(
# Extract the API key from the app and set it.
from ..server.settings import get_settings

settings = app.dependency_overrides[get_settings]()
api_key = settings.single_user_api_key or None
api_key = get_settings().single_user_api_key or None
else:
# This is a multi-user server but no API key was passed,
# so we will leave it as None on the Context.
Expand Down
197 changes: 63 additions & 134 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import urllib.parse
import warnings
from contextlib import asynccontextmanager
from functools import cache, partial
from functools import partial
from pathlib import Path
from typing import List
from typing import Any

import anyio
import packaging.version
Expand All @@ -34,30 +34,32 @@
HTTP_500_INTERNAL_SERVER_ERROR,
)

from tiled.server.authentication import (
get_current_principal_from_api_key,
session_state_getter,
)
from tiled.server.protocols import Authenticator

from ..config import construct_build_app_kwargs
from ..media_type_registration import CompressionRegistry, SerializationRegistry
from ..media_type_registration import (
compression_registry as default_compression_registry,
)
from ..media_type_registration import (
deserialization_registry as default_deserialization_registry,
)
from ..query_registration import QueryRegistry
from ..query_registration import query_registry as default_query_registry
from ..utils import SHARE_TILED_PATH, Conflicts, SpecialUsers, UnsupportedQueryType
from ..validation_registration import validation_registry as default_validation_registry
from . import schemas
from .authentication import get_current_principal
from .compression import CompressionMiddleware
from .dependencies import (
get_query_registry,
get_root_tree,
get_serialization_registry,
get_validation_registry,
)
from .router import distinct, patch_route_signature, router, search
from .router import get_router
from .settings import get_settings
from .utils import (
API_KEY_COOKIE_NAME,
CSRF_COOKIE_NAME,
get_authenticators,
get_root_url,
move_api_key,
record_timing,
)

Expand Down Expand Up @@ -113,9 +115,10 @@ def build_app(
tree,
authentication=None,
server_settings=None,
query_registry=None,
serialization_registry=None,
compression_registry=None,
query_registry: QueryRegistry | None = None,
serialization_registry: SerializationRegistry | None = None,
compression_registry: CompressionRegistry | None = None,
deserialization_registry: SerializationRegistry | None = None,
validation_registry=None,
tasks=None,
scalable=False,
Expand All @@ -138,10 +141,13 @@ def build_app(
spec["provider"]: spec["authenticator"]
for spec in authentication.get("providers", [])
}
server_settings = server_settings or {}
query_registry = query_registry or get_query_registry()
server_settings = server_settings or get_settings()
query_registry = query_registry or default_query_registry
compression_registry = compression_registry or default_compression_registry
validation_registry = validation_registry or default_validation_registry
deserialization_registry = (
deserialization_registry or default_deserialization_registry
)
tasks = tasks or {}
tasks.setdefault("startup", [])
tasks.setdefault("background", [])
Expand Down Expand Up @@ -265,9 +271,7 @@ async def lookup_file(path, try_app=True):
@app.get("/", response_class=HTMLResponse)
async def index(
request: Request,
# This dependency is here because it runs the code that moves
# API key from the query parameter to a cookie (if it is valid).
principal=Security(get_current_principal, scopes=[]),
_: str | None = Security(move_api_key),
):
return templates.TemplateResponse(
request,
Expand Down Expand Up @@ -348,99 +352,54 @@ async def unhandled_exception_handler(
),
)

app.include_router(router, prefix="/api/v1")

# The Tree and Authenticator have the opportunity to add custom routes to
# the server here. (Just for example, a Tree of BlueskyRuns uses this
# hook to add a /documents route.) This has to be done before dependency_overrides
# are processed, so we cannot just inject this configuration via Depends.
# hook to add a /documents route.)
for custom_router in getattr(tree, "include_routers", []):
app.include_router(custom_router, prefix="/api/v1")

if authenticators:
# Delay this imports to avoid delaying startup with the SQL and cryptography
# imports if they are not needed.
from .authentication import build_authentication_router
from .authentication import (
build_authentication_router,
current_principal_getter,
)

# For the OpenAPI schema, inject a OAuth2PasswordBearer URL.
first_provider = authentication["providers"][0]["provider"]
authentication_router = build_authentication_router(
authenticators, first_provider
authenticators, first_provider, server_settings
)
# And add this authentication_router itself to the app.
app.include_router(authentication_router, prefix="/api/v1/auth")

# The /search route is defined after import time so that the user has the
# opporunity to register custom query types before startup.
app.get(
"/api/v1/search/{path:path}",
response_model=schemas.Response[
List[schemas.Resource[schemas.NodeAttributes, dict, dict]],
schemas.PaginationLinks,
dict,
],
)(patch_route_signature(search, query_registry))
app.get(
"/api/v1/distinct/{path:path}",
response_model=schemas.GetDistinctResponse,
)(patch_route_signature(distinct, query_registry))

@cache
def override_get_authenticators():
return authenticators

@cache
def override_get_root_tree():
return tree

@cache
def override_get_settings():
settings = get_settings()
for item in [
"allow_anonymous_access",
"secret_keys",
"single_user_api_key",
"access_token_max_age",
"refresh_token_max_age",
"session_max_age",
]:
if authentication.get(item) is not None:
setattr(settings, item, authentication[item])
if authentication.get("single_user_api_key") is not None:
settings.single_user_api_key_generated = False
for item in [
"allow_origins",
"response_bytesize_limit",
"reject_undeclared_specs",
"expose_raw_assets",
]:
if server_settings.get(item) is not None:
setattr(settings, item, server_settings[item])
database = server_settings.get("database", {})
if database.get("uri"):
settings.database_uri = database["uri"]
if database.get("pool_size"):
settings.database_pool_size = database["pool_size"]
if database.get("pool_pre_ping"):
settings.database_pool_pre_ping = database["pool_pre_ping"]
if database.get("max_overflow"):
settings.database_max_overflow = database["max_overflow"]
if database.get("init_if_not_exists"):
settings.database_init_if_not_exists = database["init_if_not_exists"]
if authentication.get("providers"):
# If we support authentication providers, we need a database, so if one is
# not set, use a SQLite database in memory. Horizontally scaled deployments
# must specify a persistent database.
settings.database_uri = settings.database_uri or "sqlite://"
return settings
principal_getter = current_principal_getter(authenticators, server_settings)

else:
principal_getter = get_current_principal_from_api_key()

get_session_state = session_state_getter(authenticators, server_settings)

app.include_router(
get_router(
query_registry,
authenticators,
principal_getter,
server_settings.tree,
get_session_state,
serialization_registry,
deserialization_registry,
validation_registry,
),
prefix="/api/v1",
)

async def startup_event():
from .. import __version__

logger.info(f"Tiled version {__version__}")
# Validate the single-user API key.
settings = app.dependency_overrides[get_settings]()
single_user_api_key = settings.single_user_api_key
single_user_api_key = server_settings.single_user_api_key
API_KEY_MSG = """
Here are two ways to generate a good API key:
Expand Down Expand Up @@ -485,13 +444,12 @@ async def startup_event():
asyncio_task = asyncio.create_task(task())
app.state.tasks.append(asyncio_task)

app.state.allow_origins.extend(settings.allow_origins)
app.state.allow_origins.extend(server_settings.allow_origins)
# Expose the root_tree here to make it easier to access it from tests,
# in usages like:
# client.context.app.state.root_tree
app.state.root_tree = app.dependency_overrides[get_root_tree]()

if settings.database_uri is not None:
if server_settings.database_uri is not None:
from sqlalchemy.ext.asyncio import AsyncSession

from ..alembic_utils import (
Expand All @@ -512,7 +470,7 @@ async def startup_event():
# This creates a connection pool and stashes it in a module-global
# registry, keyed on database_settings, where can be retrieved by
# the Dependency get_database_session.
engine = open_database_connection_pool(settings.database_settings)
engine = open_database_connection_pool(server_settings.database_settings)
if not engine.url.database:
# Special-case for in-memory SQLite: Because it is transient we can
# skip over anything related to migrations.
Expand All @@ -523,7 +481,7 @@ async def startup_event():
try:
await check_database(engine, REQUIRED_REVISION, ALL_REVISIONS)
except UninitializedDatabase:
if settings.database_init_if_not_exists:
if server_settings.database_init_if_not_exists:
# The alembic stamping can only be does synchronously.
# The cleanest option available is to start a subprocess
# because SQLite is allergic to threads.
Expand Down Expand Up @@ -616,13 +574,12 @@ async def shutdown_event():
for task in tasks.get("shutdown", []):
await task()

settings = app.dependency_overrides[get_settings]()
if settings.database_uri is not None:
if server_settings.database_uri is not None:
from ..authn_database.connection_pool import close_database_connection_pool

for task in app.state.tasks:
task.cancel()
await close_database_connection_pool(settings.database_settings)
await close_database_connection_pool(server_settings.database_settings)

app.add_middleware(
CompressionMiddleware,
Expand Down Expand Up @@ -714,35 +671,6 @@ async def set_cookies(request: Request, call_next):
return response

app.openapi = partial(custom_openapi, app)
app.dependency_overrides[get_authenticators] = override_get_authenticators
app.dependency_overrides[get_root_tree] = override_get_root_tree
app.dependency_overrides[get_settings] = override_get_settings
if query_registry is not None:

@cache
def override_get_query_registry():
return query_registry

app.dependency_overrides[get_query_registry] = override_get_query_registry
if serialization_registry is not None:

@cache
def override_get_serialization_registry():
return serialization_registry

app.dependency_overrides[
get_serialization_registry
] = override_get_serialization_registry

if validation_registry is not None:

@cache
def override_get_validation_registry():
return validation_registry

app.dependency_overrides[
get_validation_registry
] = override_get_validation_registry

@app.middleware("http")
async def capture_metrics(request: Request, call_next):
Expand Down Expand Up @@ -883,15 +811,16 @@ def __getattr__(name):


def print_admin_api_key_if_generated(
web_app: FastAPI, host: str, port: int, force: bool = False
web_app: FastAPI,
host: str,
port: int,
authenticators: dict[str, Any] | None,
force: bool = False,
):
"Print message to stderr with API key if server-generated (or force=True)."
host = host or "127.0.0.1"
port = port or 8000
settings = web_app.dependency_overrides.get(get_settings, get_settings)()
authenticators = web_app.dependency_overrides.get(
get_authenticators, get_authenticators
)()
settings = get_settings()
if settings.allow_anonymous_access:
print(
"""
Expand Down
Loading

0 comments on commit fd664f0

Please sign in to comment.