Skip to content

Commit

Permalink
Refactoring and type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Jan 14, 2025
1 parent cd69a67 commit 7a919fa
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 137 deletions.
27 changes: 9 additions & 18 deletions example_configs/external_service/custom.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,36 @@
import numpy
from pydantic import Secret

from tiled.adapters.array import ArrayAdapter
from tiled.authenticators import Mode, UserSessionState
from tiled.structures.core import StructureFamily


class Authenticator:
"This accepts any password and stashes it in session state as 'token'."
mode = Mode.password

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {"token": password})


# This stands in for a secret token issued by the external service.
SERVICE_ISSUED_TOKEN = "secret"


class MockClient:
def __init__(self, base_url):
self.base_url = base_url

def __init__(self, base_url: str, example_token: str = "secret"):
self._base_url = base_url
# This stands in for a secret token issued by the external service.
self._example_token = Secret(example_token)

# This API (get_contents, get_metadata, get_data) is just made up and not important.
# Could be anything.

async def get_metadata(self, url, token):
# This assert stands in for the mocked service
# authenticating a request.
assert token == SERVICE_ISSUED_TOKEN
assert token == self._example_token.get_secret_value()
return {"metadata": str(url)}

async def get_contents(self, url, token):
# This assert stands in for the mocked service
# authenticating a request.
assert token == SERVICE_ISSUED_TOKEN
assert token == self._example_token.get_secret_value()
return ["a", "b", "c"]

async def get_data(self, url, token):
# This assert stands in for the mocked service
# authenticating a request.
assert token == SERVICE_ISSUED_TOKEN
assert token == self._example_token.get_secret_value()
return numpy.ones((3, 3))


Expand Down
10 changes: 6 additions & 4 deletions example_configs/mock-oidc-server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ authentication:
client_secret: secret
well_known_uri: http://localhost:8080/.well-known/openid-configuration
trees:
# Just some arbitrary example data...
# The point of this example is the authenticaiton above.
- tree: tiled.examples.generated_minimal:tree
path: /
- path: /
tree: catalog
args:
uri: "sqlite+aiosqlite:///:memory:"
writable_storage: "/tmp/data"
init_if_not_exists: true
31 changes: 10 additions & 21 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,37 @@

import httpx
from fastapi import APIRouter, Request
from jose import JWTError, jwk, jwt
from jose import JWTError, jwt
from pydantic import Secret
from starlette.responses import RedirectResponse

from .server.authentication import Mode
from .server.protocols import UserSessionState
from .server.protocols import ExternalAuthenticator, UserSessionState, PasswordAuthenticator
from .server.utils import get_root_url
from .utils import modules_available

logger = logging.getLogger(__name__)


class DummyAuthenticator:
class DummyAuthenticator(PasswordAuthenticator):
"""
For test and demo purposes only!
Accept any username and any password.
"""

mode = Mode.password

def __init__(self, confirmation_message=""):
self.confirmation_message = confirmation_message

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class DictionaryAuthenticator:
class DictionaryAuthenticator(PasswordAuthenticator):
"""
For test and demo purposes only!
Check passwords from a dictionary of usernames mapped to passwords.
"""

mode = Mode.password
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -74,8 +68,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class PAMAuthenticator:
mode = Mode.password
class PAMAuthenticator(PasswordAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -110,8 +103,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class OIDCAuthenticator:
mode = Mode.external
class OIDCAuthenticator(ExternalAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -164,7 +156,7 @@ def jwks_uri(self) -> str:
def token_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("token_endpoint"))

async def authenticate(self, request: Request) -> UserSessionState:
async def authenticate(self, request: Request) -> UserSessionState | None:
code = request.query_params["code"]
# A proxy in the middle may make the request into something like
# 'http://localhost:8000/...' so we fix the first part but keep
Expand Down Expand Up @@ -228,8 +220,7 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect
return response


class SAMLAuthenticator:
mode = Mode.external
class SAMLAuthenticator(ExternalAuthenticator):

def __init__(
self,
Expand Down Expand Up @@ -271,7 +262,7 @@ async def saml_login(request: Request):

self.include_routers = [router]

async def authenticate(self, request) -> UserSessionState:
async def authenticate(self, request) -> UserSessionState | None:
if not modules_available("onelogin"):
raise ModuleNotFoundError(
"This SAMLAuthenticator requires the module 'oneline' to be installed."
Expand Down Expand Up @@ -323,7 +314,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False):
return rv


class LDAPAuthenticator:
class LDAPAuthenticator(PasswordAuthenticator):
"""
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
The parameter ``use_tls`` was added for convenience of testing.
Expand Down Expand Up @@ -506,8 +497,6 @@ class LDAPAuthenticator:
id: user02
"""

mode = Mode.password

def __init__(
self,
server_address,
Expand Down
4 changes: 2 additions & 2 deletions tiled/authn_database/connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from ..server.settings import get_settings
from ..server.settings import Settings, get_settings
from ..utils import ensure_specified_sql_driver

# A given process probably only has one of these at a time, but we
Expand Down Expand Up @@ -31,7 +31,7 @@ async def close_database_connection_pool(database_settings):
await engine.dispose()


async def get_database_engine(settings=Depends(get_settings)):
async def get_database_engine(settings: Settings = Depends(get_settings)):
# Special case for single-user mode
if settings.database_uri is None:
return None
Expand Down
7 changes: 4 additions & 3 deletions tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from urllib.parse import parse_qs, urlparse

from fastapi import FastAPI
import httpx
import platformdirs

Expand Down Expand Up @@ -414,7 +415,7 @@ def from_any_uri(
@classmethod
def from_app(
cls,
app,
app: FastAPI,
*,
cache=UNSET,
headers=None,
Expand All @@ -438,9 +439,9 @@ def from_app(
if not context.server_info["authentication"]["providers"]:
# This is a single-user server.
# Extract the API key from the app and set it.
from ..server.settings import get_settings
from ..server.settings import get_settings, Settings

settings = app.dependency_overrides[get_settings]()
settings: Settings = app.dependency_overrides[get_settings]()
api_key = settings.single_user_api_key or None
else:
# This is a multi-user server but no API key was passed,
Expand Down
50 changes: 26 additions & 24 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import urllib.parse
import warnings
from contextlib import asynccontextmanager
from functools import lru_cache, partial
from functools import cache, partial
from pathlib import Path
from typing import List
from typing import Any, Dict, List

import anyio
import packaging.version
Expand All @@ -34,7 +34,8 @@
HTTP_500_INTERNAL_SERVER_ERROR,
)

from ..authenticators import Mode
from tiled.server.protocols import Authenticator, ExternalAuthenticator, PasswordAuthenticator

from ..config import construct_build_app_kwargs
from ..media_type_registration import (
compression_registry as default_compression_registry,
Expand All @@ -51,7 +52,7 @@
get_validation_registry,
)
from .router import distinct, patch_route_signature, router, search
from .settings import get_settings
from .settings import Settings, get_settings
from .utils import (
API_KEY_COOKIE_NAME,
CSRF_COOKIE_NAME,
Expand Down Expand Up @@ -81,7 +82,7 @@
current_principal = contextvars.ContextVar("current_principal")


def custom_openapi(app: FastAPI):
def custom_openapi(app: FastAPI) -> Dict[str, Any]:
"""
The app's openapi method will be monkey-patched with this.
Expand Down Expand Up @@ -118,7 +119,7 @@ def build_app(
validation_registry=None,
tasks=None,
scalable=False,
):
) -> FastAPI:
"""
Serve a Tree
Expand All @@ -133,7 +134,7 @@ def build_app(
Dict of other server configuration.
"""
authentication = authentication or {}
authenticators = {
authenticators: Dict[str, Authenticator] = {
spec["provider"]: spec["authenticator"]
for spec in authentication.get("providers", [])
}
Expand Down Expand Up @@ -385,12 +386,11 @@ async def unhandled_exception_handler(
for spec in authentication["providers"]:
provider = spec["provider"]
authenticator = spec["authenticator"]
mode = authenticator.mode
if mode == Mode.password:
if isinstance(authenticator, PasswordAuthenticator):
authentication_router.post(f"/provider/{provider}/token")(
build_handle_credentials_route(authenticator, provider)
)
elif mode == Mode.external:
elif isinstance(authenticator, ExternalAuthenticator):
# Client starts here to create a PendingSession.
authentication_router.post(f"/provider/{provider}/authorize")(
build_device_code_authorize_route(authenticator, provider)
Expand All @@ -415,7 +415,7 @@ async def unhandled_exception_handler(
# build_auth_code_route(authenticator, provider)
# )
else:
raise ValueError(f"unknown authentication mode {mode}")
raise ValueError(f"Unexpected authenticator type {type(authenticator)}")
for custom_router in getattr(authenticator, "include_routers", []):
authentication_router.include_router(
custom_router, prefix=f"/provider/{provider}"
Expand All @@ -438,15 +438,16 @@ async def unhandled_exception_handler(
response_model=schemas.GetDistinctResponse,
)(patch_route_signature(distinct, query_registry))

@lru_cache(1)
def override_get_authenticators():
@cache
def override_get_authenticators() -> Dict[str, Authenticator]:
print(authenticators)
return authenticators

@lru_cache(1)
@cache
def override_get_root_tree():
return tree

@lru_cache(1)
@cache
def override_get_settings():
settings = get_settings()
for item in [
Expand Down Expand Up @@ -492,7 +493,7 @@ async def startup_event():

logger.info(f"Tiled version {__version__}")
# Validate the single-user API key.
settings = app.dependency_overrides[get_settings]()
settings: Settings = app.dependency_overrides[get_settings]()
single_user_api_key = settings.single_user_api_key
API_KEY_MSG = """
Here are two ways to generate a good API key:
Expand Down Expand Up @@ -772,14 +773,14 @@ async def set_cookies(request: Request, call_next):
app.dependency_overrides[get_settings] = override_get_settings
if query_registry is not None:

@lru_cache(1)
@cache
def override_get_query_registry():
return query_registry

app.dependency_overrides[get_query_registry] = override_get_query_registry
if serialization_registry is not None:

@lru_cache(1)
@cache
def override_get_serialization_registry():
return serialization_registry

Expand All @@ -789,7 +790,7 @@ def override_get_serialization_registry():

if validation_registry is not None:

@lru_cache(1)
@cache
def override_get_validation_registry():
return validation_registry

Expand Down Expand Up @@ -936,13 +937,14 @@ def __getattr__(name):


def print_admin_api_key_if_generated(
web_app: FastAPI, host: str, port: int, force: bool = False
web_app: FastAPI,
host: str = "127.0.0.1",
port: int = 8000,
force: bool = False
):
"Print message to stderr with API key if server-generated (or force=True)."
host = host or "127.0.0.1"
port = port or 8000
settings = web_app.dependency_overrides.get(get_settings, get_settings)()
authenticators = web_app.dependency_overrides.get(
settings: Settings = web_app.dependency_overrides.get(get_settings, get_settings)()
authenticators: Dict[str, Authenticator] = web_app.dependency_overrides.get(
get_authenticators, get_authenticators
)()
if settings.allow_anonymous_access:
Expand Down
Loading

0 comments on commit 7a919fa

Please sign in to comment.