diff --git a/example_configs/external_service/custom.py b/example_configs/external_service/custom.py index 01d528716..1320de1ad 100644 --- a/example_configs/external_service/custom.py +++ b/example_configs/external_service/custom.py @@ -1,25 +1,16 @@ import numpy +from pydantic import Secret from tiled.adapters.array import ArrayAdapter -from tiled.authenticators import Mode, UserSessionState from tiled.structures.core import StructureFamily -class Authenticator: - "This accepts any password and stashes it in session state as 'token'." - mode = Mode.password - - async def authenticate(self, username: str, password: str) -> UserSessionState: - return UserSessionState(username, {"token": password}) - - -# This stands in for a secret token issued by the external service. -SERVICE_ISSUED_TOKEN = "secret" - - class MockClient: - def __init__(self, base_url): - self.base_url = base_url + + def __init__(self, base_url: str, example_token: str = "secret"): + self._base_url = base_url + # This stands in for a secret token issued by the external service. + self._example_token = Secret(example_token) # This API (get_contents, get_metadata, get_data) is just made up and not important. # Could be anything. @@ -27,19 +18,19 @@ def __init__(self, base_url): async def get_metadata(self, url, token): # This assert stands in for the mocked service # authenticating a request. - assert token == SERVICE_ISSUED_TOKEN + assert token == self._example_token.get_secret_value() return {"metadata": str(url)} async def get_contents(self, url, token): # This assert stands in for the mocked service # authenticating a request. - assert token == SERVICE_ISSUED_TOKEN + assert token == self._example_token.get_secret_value() return ["a", "b", "c"] async def get_data(self, url, token): # This assert stands in for the mocked service # authenticating a request. - assert token == SERVICE_ISSUED_TOKEN + assert token == self._example_token.get_secret_value() return numpy.ones((3, 3)) diff --git a/example_configs/mock-oidc-server.yml b/example_configs/mock-oidc-server.yml index c8531f4cd..762e5ba58 100644 --- a/example_configs/mock-oidc-server.yml +++ b/example_configs/mock-oidc-server.yml @@ -9,7 +9,9 @@ authentication: client_secret: secret well_known_uri: http://localhost:8080/.well-known/openid-configuration trees: - # Just some arbitrary example data... - # The point of this example is the authenticaiton above. - - tree: tiled.examples.generated_minimal:tree - path: / + - path: / + tree: catalog + args: + uri: "sqlite+aiosqlite:///:memory:" + writable_storage: "/tmp/data" + init_if_not_exists: true diff --git a/tiled/authenticators.py b/tiled/authenticators.py index 759ea91e5..d838ee7fe 100644 --- a/tiled/authenticators.py +++ b/tiled/authenticators.py @@ -9,28 +9,24 @@ import httpx from fastapi import APIRouter, Request -from jose import JWTError, jwk, jwt +from jose import JWTError, jwt from pydantic import Secret from starlette.responses import RedirectResponse -from .server.authentication import Mode -from .server.protocols import UserSessionState +from .server.protocols import ExternalAuthenticator, UserSessionState, PasswordAuthenticator from .server.utils import get_root_url from .utils import modules_available logger = logging.getLogger(__name__) -class DummyAuthenticator: +class DummyAuthenticator(PasswordAuthenticator): """ For test and demo purposes only! Accept any username and any password. """ - - mode = Mode.password - def __init__(self, confirmation_message=""): self.confirmation_message = confirmation_message @@ -38,14 +34,12 @@ async def authenticate(self, username: str, password: str) -> UserSessionState: return UserSessionState(username, {}) -class DictionaryAuthenticator: +class DictionaryAuthenticator(PasswordAuthenticator): """ For test and demo purposes only! Check passwords from a dictionary of usernames mapped to passwords. """ - - mode = Mode.password configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -74,8 +68,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState: return UserSessionState(username, {}) -class PAMAuthenticator: - mode = Mode.password +class PAMAuthenticator(PasswordAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -110,8 +103,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState: return UserSessionState(username, {}) -class OIDCAuthenticator: - mode = Mode.external +class OIDCAuthenticator(ExternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -164,7 +156,7 @@ def jwks_uri(self) -> str: def token_endpoint(self) -> str: return cast(str, self._config_from_oidc_url.get("token_endpoint")) - async def authenticate(self, request: Request) -> UserSessionState: + async def authenticate(self, request: Request) -> UserSessionState | None: code = request.query_params["code"] # A proxy in the middle may make the request into something like # 'http://localhost:8000/...' so we fix the first part but keep @@ -228,8 +220,7 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect return response -class SAMLAuthenticator: - mode = Mode.external +class SAMLAuthenticator(ExternalAuthenticator): def __init__( self, @@ -271,7 +262,7 @@ async def saml_login(request: Request): self.include_routers = [router] - async def authenticate(self, request) -> UserSessionState: + async def authenticate(self, request) -> UserSessionState | None: if not modules_available("onelogin"): raise ModuleNotFoundError( "This SAMLAuthenticator requires the module 'oneline' to be installed." @@ -323,7 +314,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False): return rv -class LDAPAuthenticator: +class LDAPAuthenticator(PasswordAuthenticator): """ The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator The parameter ``use_tls`` was added for convenience of testing. @@ -506,8 +497,6 @@ class LDAPAuthenticator: id: user02 """ - mode = Mode.password - def __init__( self, server_address, diff --git a/tiled/authn_database/connection_pool.py b/tiled/authn_database/connection_pool.py index 4eae74b83..4941eb965 100644 --- a/tiled/authn_database/connection_pool.py +++ b/tiled/authn_database/connection_pool.py @@ -1,7 +1,7 @@ from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from ..server.settings import get_settings +from ..server.settings import Settings, get_settings from ..utils import ensure_specified_sql_driver # A given process probably only has one of these at a time, but we @@ -31,7 +31,7 @@ async def close_database_connection_pool(database_settings): await engine.dispose() -async def get_database_engine(settings=Depends(get_settings)): +async def get_database_engine(settings: Settings = Depends(get_settings)): # Special case for single-user mode if settings.database_uri is None: return None diff --git a/tiled/client/context.py b/tiled/client/context.py index a1be82d08..bb069a4d0 100644 --- a/tiled/client/context.py +++ b/tiled/client/context.py @@ -8,6 +8,7 @@ from pathlib import Path from urllib.parse import parse_qs, urlparse +from fastapi import FastAPI import httpx import platformdirs @@ -414,7 +415,7 @@ def from_any_uri( @classmethod def from_app( cls, - app, + app: FastAPI, *, cache=UNSET, headers=None, @@ -438,9 +439,9 @@ def from_app( if not context.server_info["authentication"]["providers"]: # This is a single-user server. # Extract the API key from the app and set it. - from ..server.settings import get_settings + from ..server.settings import get_settings, Settings - settings = app.dependency_overrides[get_settings]() + settings: Settings = app.dependency_overrides[get_settings]() api_key = settings.single_user_api_key or None else: # This is a multi-user server but no API key was passed, diff --git a/tiled/server/app.py b/tiled/server/app.py index 3a5846063..10d64e001 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -8,9 +8,9 @@ import urllib.parse import warnings from contextlib import asynccontextmanager -from functools import lru_cache, partial +from functools import cache, partial from pathlib import Path -from typing import List +from typing import Any, Dict, List import anyio import packaging.version @@ -34,7 +34,8 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) -from ..authenticators import Mode +from tiled.server.protocols import Authenticator, ExternalAuthenticator, PasswordAuthenticator + from ..config import construct_build_app_kwargs from ..media_type_registration import ( compression_registry as default_compression_registry, @@ -51,7 +52,7 @@ get_validation_registry, ) from .router import distinct, patch_route_signature, router, search -from .settings import get_settings +from .settings import Settings, get_settings from .utils import ( API_KEY_COOKIE_NAME, CSRF_COOKIE_NAME, @@ -81,7 +82,7 @@ current_principal = contextvars.ContextVar("current_principal") -def custom_openapi(app: FastAPI): +def custom_openapi(app: FastAPI) -> Dict[str, Any]: """ The app's openapi method will be monkey-patched with this. @@ -118,7 +119,7 @@ def build_app( validation_registry=None, tasks=None, scalable=False, -): +) -> FastAPI: """ Serve a Tree @@ -133,7 +134,7 @@ def build_app( Dict of other server configuration. """ authentication = authentication or {} - authenticators = { + authenticators: Dict[str, Authenticator] = { spec["provider"]: spec["authenticator"] for spec in authentication.get("providers", []) } @@ -385,12 +386,11 @@ async def unhandled_exception_handler( for spec in authentication["providers"]: provider = spec["provider"] authenticator = spec["authenticator"] - mode = authenticator.mode - if mode == Mode.password: + if isinstance(authenticator, PasswordAuthenticator): authentication_router.post(f"/provider/{provider}/token")( build_handle_credentials_route(authenticator, provider) ) - elif mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): # Client starts here to create a PendingSession. authentication_router.post(f"/provider/{provider}/authorize")( build_device_code_authorize_route(authenticator, provider) @@ -415,7 +415,7 @@ async def unhandled_exception_handler( # build_auth_code_route(authenticator, provider) # ) else: - raise ValueError(f"unknown authentication mode {mode}") + raise ValueError(f"Unexpected authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): authentication_router.include_router( custom_router, prefix=f"/provider/{provider}" @@ -438,15 +438,16 @@ async def unhandled_exception_handler( response_model=schemas.GetDistinctResponse, )(patch_route_signature(distinct, query_registry)) - @lru_cache(1) - def override_get_authenticators(): + @cache + def override_get_authenticators() -> Dict[str, Authenticator]: + print(authenticators) return authenticators - @lru_cache(1) + @cache def override_get_root_tree(): return tree - @lru_cache(1) + @cache def override_get_settings(): settings = get_settings() for item in [ @@ -492,7 +493,7 @@ async def startup_event(): logger.info(f"Tiled version {__version__}") # Validate the single-user API key. - settings = app.dependency_overrides[get_settings]() + settings: Settings = app.dependency_overrides[get_settings]() single_user_api_key = settings.single_user_api_key API_KEY_MSG = """ Here are two ways to generate a good API key: @@ -772,14 +773,14 @@ async def set_cookies(request: Request, call_next): app.dependency_overrides[get_settings] = override_get_settings if query_registry is not None: - @lru_cache(1) + @cache def override_get_query_registry(): return query_registry app.dependency_overrides[get_query_registry] = override_get_query_registry if serialization_registry is not None: - @lru_cache(1) + @cache def override_get_serialization_registry(): return serialization_registry @@ -789,7 +790,7 @@ def override_get_serialization_registry(): if validation_registry is not None: - @lru_cache(1) + @cache def override_get_validation_registry(): return validation_registry @@ -936,13 +937,14 @@ def __getattr__(name): def print_admin_api_key_if_generated( - web_app: FastAPI, host: str, port: int, force: bool = False + web_app: FastAPI, + host: str = "127.0.0.1", + port: int = 8000, + force: bool = False ): "Print message to stderr with API key if server-generated (or force=True)." - host = host or "127.0.0.1" - port = port or 8000 - settings = web_app.dependency_overrides.get(get_settings, get_settings)() - authenticators = web_app.dependency_overrides.get( + settings: Settings = web_app.dependency_overrides.get(get_settings, get_settings)() + authenticators: Dict[str, Authenticator] = web_app.dependency_overrides.get( get_authenticators, get_authenticators )() if settings.allow_anonymous_access: diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index 19b9ad9b7..c69f59632 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -1,12 +1,12 @@ -import enum import hashlib import secrets import uuid as uuid_module import warnings from datetime import datetime, timedelta from pathlib import Path -from typing import Optional +from typing import Dict, Optional, cast +import httpx import sqlalchemy.exc from fastapi import ( APIRouter, @@ -38,6 +38,8 @@ HTTP_409_CONFLICT, ) +from tiled.authenticators import OIDCAuthenticator + # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: # int_from_bytes is deprecated, use int.from_bytes instead @@ -61,7 +63,7 @@ from ..utils import SHARE_TILED_PATH, SpecialUsers from . import schemas from .core import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, json_or_msgpack -from .protocols import UsernamePasswordAuthenticator, UserSessionState +from .protocols import Authenticator, ExternalAuthenticator, PasswordAuthenticator, UserSessionState from .settings import Settings, get_settings from .utils import API_KEY_COOKIE_NAME, get_authenticators, get_base_url @@ -85,11 +87,6 @@ def utcnow(): return datetime.utcnow().replace(microsecond=0) -class Mode(enum.Enum): - password = "password" - external = "external" - - class Token(BaseModel): access_token: str token_type: str @@ -166,6 +163,10 @@ def create_refresh_token(session_id, secret_key, expires_delta): return encoded_jwt +def decode_oidc_token(token: str, authentictor: OIDCAuthenticator): + return jwt.decode(token, httpx.get(authentictor.jwks_uri), algorithms=[ALGORITHM]) + + def decode_token(token: str, secret_keys: list[str]): credentials_exception = HTTPException( status_code=HTTP_401_UNAUTHORIZED, @@ -177,12 +178,15 @@ def decode_token(token: str, secret_keys: list[str]): # fail. They supports key rotation. for secret_key in secret_keys: try: + """ DO NOT MERGE! """ + print(secret_key) # Remove this!!!!!! payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) break except ExpiredSignatureError: # Do not let this be caught below with the other JWTError types. raise - except JWTError: + except JWTError as e: + print(e) # Try the next key in the key rotation. continue else: @@ -221,11 +225,15 @@ async def get_decoded_access_token( access_token: str = Depends(oauth2_scheme), settings: Settings = Depends(get_settings), ): - print("Got access_token") if not access_token: return None try: - payload = decode_token(access_token, settings.secret_keys) + print(settings.authenticator) + if isinstance(settings.authenticator, OIDCAuthenticator): + payload = decode_oidc_token(access_token, settings.authenticator) + print("proof of concept!") + else: + payload = decode_token(access_token, settings.secret_keys) except ExpiredSignatureError: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, @@ -246,7 +254,7 @@ async def get_current_principal( decoded_access_token: str = Depends(get_decoded_access_token), api_key: str = Depends(get_api_key), settings: Settings = Depends(get_settings), - authenticators=Depends(get_authenticators), + authenticators: Dict[str, Authenticator] = Depends(get_authenticators), db=Depends(get_database_session), ): """ @@ -510,7 +518,7 @@ async def create_tokens_from_session(settings: Settings, db, session, provider): } -def build_auth_code_route(authenticator, provider): +def build_auth_code_route(authenticator: ExternalAuthenticator, provider): "Build an auth_code route function for this Authenticator." async def route( @@ -524,6 +532,7 @@ async def route( raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, detail="Authentication failure" ) + user_session_state = cast(UserSessionState, user_session_state) session = await create_session( settings, db, @@ -537,7 +546,7 @@ async def route( return route -def build_device_code_authorize_route(authenticator, provider): +def build_device_code_authorize_route(authenticator: ExternalAuthenticator, provider): "Build an /authorize route function for this Authenticator." async def route( @@ -571,7 +580,7 @@ async def route( return route -def build_device_code_user_code_form_route(authentication, provider): +def build_device_code_user_code_form_route(authentication: ExternalAuthenticator, provider): if not SHARE_TILED_PATH: raise Exception( "Static assets could not be found and are required for " @@ -598,7 +607,7 @@ async def route( return route -def build_device_code_user_code_submit_route(authenticator, provider): +def build_device_code_user_code_submit_route(authenticator: ExternalAuthenticator, provider): "Build an /authorize route function for this Authenticator." if not SHARE_TILED_PATH: @@ -670,7 +679,7 @@ async def route( return route -def build_device_code_token_route(authenticator, provider): +def build_device_code_token_route(authenticator: ExternalAuthenticator, provider): "Build an /authorize route function for this Authenticator." async def route( @@ -711,7 +720,7 @@ async def route( def build_handle_credentials_route( - authenticator: UsernamePasswordAuthenticator, provider + authenticator: PasswordAuthenticator, provider ): "Register a handle_credentials route function for this Authenticator." @@ -988,7 +997,11 @@ async def revoke_session( ): "Mark a Session as revoked so it cannot be refreshed again." request.state.endpoint = "auth" - payload = decode_token(refresh_token.refresh_token, settings.secret_keys) + if isinstance(settings.authenticator, OIDCAuthenticator): + payload = decode_oidc_token(refresh_token.refresh_token, settings.authenticator) + print("proof of concept!") + else: + payload = decode_token(refresh_token.refresh_token, settings.secret_keys) session_id = payload["sid"] # Find this session in the database. session = await lookup_valid_session(db, session_id) @@ -1027,7 +1040,11 @@ async def revoke_session_by_id( async def slide_session(refresh_token, settings: Settings, db): try: - payload = decode_token(refresh_token, settings.secret_keys) + if isinstance(settings.authenticator, OIDCAuthenticator): + payload = decode_oidc_token(refresh_token, settings.authenticator) + print("proof of concept!") + else: + payload = decode_token(refresh_token, settings.secret_keys) except ExpiredSignatureError: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, diff --git a/tiled/server/protocols.py b/tiled/server/protocols.py index 232a5145a..728f40d00 100644 --- a/tiled/server/protocols.py +++ b/tiled/server/protocols.py @@ -1,5 +1,6 @@ +from abc import abstractmethod, ABC from dataclasses import dataclass -from typing import Protocol +from typing import Protocol, runtime_checkable from fastapi import Request @@ -10,13 +11,20 @@ class UserSessionState: user_name: str state: dict = None + + +@runtime_checkable # Required to be a field on a BaseSettings +class Authenticator(Protocol): + ... -class UsernamePasswordAuthenticator(Protocol): - def authenticate(self, username: str, password: str) -> UserSessionState: +class PasswordAuthenticator(Authenticator, ABC): + @abstractmethod + def authenticate(self, username: str, password: str) -> UserSessionState | None: pass -class Authenticator(Protocol): - def authenticate(self, request: Request) -> UserSessionState: +class ExternalAuthenticator(Authenticator, ABC): + @abstractmethod + def authenticate(self, request: Request) -> UserSessionState | None: pass diff --git a/tiled/server/router.py b/tiled/server/router.py index 050baf62f..1d4850b3f 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -6,14 +6,13 @@ from datetime import datetime, timedelta from functools import partial from pathlib import Path -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import anyio from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Security from jmespath.exceptions import JMESPathError from json_merge_patch import merge as apply_merge_patch from jsonpatch import apply_patch as apply_json_patch -from pydantic_settings import BaseSettings from starlette.status import ( HTTP_200_OK, HTTP_206_PARTIAL_CONTENT, @@ -26,12 +25,14 @@ HTTP_422_UNPROCESSABLE_ENTITY, ) +from tiled.server.protocols import Authenticator, ExternalAuthenticator, PasswordAuthenticator + from .. import __version__ from ..structures.core import Spec, StructureFamily from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri from ..validation_registration import ValidationError from . import schemas -from .authentication import Mode, get_authenticators, get_current_principal +from .authentication import get_authenticators, get_current_principal from .core import ( DEFAULT_PAGE_SIZE, DEPTH_LIMIT, @@ -61,7 +62,7 @@ ) from .file_response_with_range import FileResponseWithRange from .links import links_for_node -from .settings import get_settings +from .settings import Settings, get_settings from .utils import filter_for_access, get_base_url, record_timing router = APIRouter() @@ -70,8 +71,8 @@ @router.get("/", response_model=schemas.About) async def about( request: Request, - settings: BaseSettings = Depends(get_settings), - authenticators=Depends(get_authenticators), + settings: Settings = Depends(get_settings), + authenticators: Dict[str, Authenticator] = Depends(get_authenticators), serialization_registry=Depends(get_serialization_registry), query_registry=Depends(get_query_registry), # This dependency is here because it runs the code that moves @@ -90,10 +91,9 @@ async def about( } provider_specs = [] for provider, authenticator in authenticators.items(): - if authenticator.mode == Mode.password: + if isinstance(authenticator, PasswordAuthenticator): spec = { "provider": provider, - "mode": authenticator.mode.value, "links": { "auth_endpoint": f"{base_url}/auth/provider/{provider}/token" }, @@ -101,10 +101,9 @@ async def about( authenticator, "confirmation_message", None ), } - elif authenticator.mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): spec = { "provider": provider, - "mode": authenticator.mode.value, "links": { "auth_endpoint": f"{base_url}/auth/provider/{provider}/authorize" }, @@ -372,7 +371,7 @@ async def array_block( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a chunk of array-like data. @@ -450,7 +449,7 @@ async def array_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of array-like data. @@ -513,7 +512,7 @@ async def get_table_partition( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a partition (continuous block of rows) from a DataFrame [GET route]. @@ -562,7 +561,7 @@ async def post_table_partition( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a partition (continuous block of rows) from a DataFrame [POST route]. @@ -587,7 +586,7 @@ async def table_partition( format: Optional[str], filename: Optional[str], serialization_registry, - settings: BaseSettings, + settings: Settings, ): """ Fetch a partition (continuous block of rows) from a DataFrame. @@ -644,7 +643,7 @@ async def get_table_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch the data for the given table [GET route]. @@ -672,7 +671,7 @@ async def post_table_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch the data for the given table [POST route]. @@ -695,7 +694,7 @@ async def table_full( format: Optional[str], filename: Optional[str], serialization_registry, - settings: BaseSettings, + settings: Settings, ): """ Fetch the data for the given table. @@ -857,7 +856,7 @@ async def node_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch the data below the given node. @@ -923,7 +922,7 @@ async def get_awkward_buffers( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of AwkwardArray data. @@ -960,7 +959,7 @@ async def post_awkward_buffers( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of AwkwardArray data. @@ -990,7 +989,7 @@ async def _awkward_buffers( format: Optional[str], filename: Optional[str], serialization_registry, - settings: BaseSettings, + settings: Settings, ): structure_family = entry.structure_family structure = entry.structure() @@ -1041,7 +1040,7 @@ async def awkward_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of AwkwardArray data. @@ -1087,7 +1086,7 @@ async def post_metadata( path: str, body: schemas.PostMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata", "create"]), ): for data_source in body.data_sources: @@ -1117,7 +1116,7 @@ async def post_register( path: str, body: schemas.PostMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata", "create", "register"]), ): return await _create_node( @@ -1135,7 +1134,7 @@ async def _create_node( path: str, body: schemas.PostMetadataRequest, validation_registry, - settings: BaseSettings, + settings: Settings, entry, ): metadata, structure_family, specs = ( @@ -1186,7 +1185,7 @@ async def put_data_source( path: str, data_source: int, body: schemas.PutDataSourceRequest, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata", "register"]), ): await entry.put_data_source( @@ -1407,7 +1406,7 @@ async def patch_metadata( request: Request, body: schemas.PatchMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata"]), ): if not hasattr(entry, "replace_metadata"): @@ -1470,7 +1469,7 @@ async def put_metadata( request: Request, body: schemas.PutMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata"]), ): if not hasattr(entry, "replace_metadata"): @@ -1560,7 +1559,7 @@ async def get_asset( id: int, relative_path: Optional[Path] = None, entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): if not settings.expose_raw_assets: raise HTTPException( @@ -1658,7 +1657,7 @@ async def get_asset_manifest( request: Request, id: int, entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): if not settings.expose_raw_assets: raise HTTPException( @@ -1705,7 +1704,7 @@ async def validate_metadata( structure, specs: List[Spec], validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): metadata_modified = False diff --git a/tiled/server/schemas.py b/tiled/server/schemas.py index fa7f039e9..05cf7c827 100644 --- a/tiled/server/schemas.py +++ b/tiled/server/schemas.py @@ -272,14 +272,8 @@ class DeviceCode(pydantic.BaseModel): grant_type: str -class AuthenticationMode(str, enum.Enum): - password = "password" - external = "external" - - class AboutAuthenticationProvider(pydantic.BaseModel): provider: str - mode: AuthenticationMode links: Dict[str, str] confirmation_message: Optional[str] = None diff --git a/tiled/server/settings.py b/tiled/server/settings.py index e68c9b283..535052784 100644 --- a/tiled/server/settings.py +++ b/tiled/server/settings.py @@ -7,6 +7,8 @@ from pydantic_settings import BaseSettings +from tiled.server.protocols import Authenticator + DatabaseSettings = collections.namedtuple( "DatabaseSettings", "uri pool_size pool_pre_ping max_overflow" ) @@ -20,7 +22,7 @@ class Settings(BaseSettings): allow_origins: List[str] = [ item for item in os.getenv("TILED_ALLOW_ORIGINS", "").split() if item ] - authenticator: Any = None + authenticator: Authenticator | None = None # These 'single user' settings are only applicable if authenticator is None. single_user_api_key: str = os.getenv( "TILED_SINGLE_USER_API_KEY", secrets.token_hex(32) diff --git a/tiled/server/utils.py b/tiled/server/utils.py index d28611471..4eabd62b2 100644 --- a/tiled/server/utils.py +++ b/tiled/server/utils.py @@ -1,5 +1,8 @@ import contextlib import time +from typing import Dict + +from tiled.server.protocols import Authenticator from ..access_policies import NO_ACCESS from ..adapters.mapping import MapAdapter @@ -10,7 +13,7 @@ CSRF_COOKIE_NAME = "tiled_csrf" -def get_authenticators(): +def get_authenticators() -> Dict[str, Authenticator]: raise NotImplementedError( "This should be overridden via dependency_overrides. " "See tiled.server.app.build_app()."