Skip to content

Commit

Permalink
Refactor Authenticators to make use of inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Jan 23, 2025
1 parent 1f5779d commit a04d11a
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 67 deletions.
7 changes: 3 additions & 4 deletions example_configs/external_service/custom.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import numpy

from tiled.adapters.array import ArrayAdapter
from tiled.authenticators import Mode, UserSessionState
from tiled.authenticators import UserSessionState
from tiled.server.protocols import InternalAuthenticator
from tiled.structures.core import StructureFamily


class Authenticator:
class Authenticator(InternalAuthenticator):
"This accepts any password and stashes it in session state as 'token'."
mode = Mode.password

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {"token": password})

Expand Down
65 changes: 28 additions & 37 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,43 @@
import re
import secrets
from collections.abc import Iterable
from typing import Any, cast
from typing import Any, Mapping, cast

import httpx
from fastapi import APIRouter, Request
from jose import JWTError, jwt
from pydantic import Secret
from starlette.responses import RedirectResponse

from .server.authentication import Mode
from .server.protocols import UserSessionState
from .server.protocols import ExternalAuthenticator, InternalAuthenticator, UserSessionState
from .server.utils import get_root_url
from .utils import modules_available

logger = logging.getLogger(__name__)


class DummyAuthenticator:
class DummyAuthenticator(InternalAuthenticator):
"""
For test and demo purposes only!
Accept any username and any password.
"""

mode = Mode.password

def __init__(self, confirmation_message=""):
def __init__(self, confirmation_message: str = ""):
self.confirmation_message = confirmation_message

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class DictionaryAuthenticator:
class DictionaryAuthenticator(InternalAuthenticator):
"""
For test and demo purposes only!
Check passwords from a dictionary of usernames mapped to passwords.
"""

mode = Mode.password
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand All @@ -61,11 +57,11 @@ class DictionaryAuthenticator:
description: May be displayed by client after successful login.
"""

def __init__(self, users_to_passwords, confirmation_message=""):
def __init__(self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""):
self._users_to_passwords = users_to_passwords
self.confirmation_message = confirmation_message

async def authenticate(self, username: str, password: str) -> UserSessionState:
async def authenticate(self, username: str, password: str) -> UserSessionState | None:
true_password = self._users_to_passwords.get(username)
if not true_password:
# Username is not valid.
Expand All @@ -74,8 +70,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class PAMAuthenticator:
mode = Mode.password
class PAMAuthenticator(InternalAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand All @@ -89,7 +84,7 @@ class PAMAuthenticator:
description: May be displayed by client after successful login.
"""

def __init__(self, service="login", confirmation_message=""):
def __init__(self, service: str = "login", confirmation_message: str = ""):
if not modules_available("pamela"):
raise ModuleNotFoundError(
"This PAMAuthenticator requires the module 'pamela' to be installed."
Expand All @@ -98,20 +93,18 @@ def __init__(self, service="login", confirmation_message=""):
self.confirmation_message = confirmation_message
# TODO Try to open a PAM session.

async def authenticate(self, username: str, password: str) -> UserSessionState:
async def authenticate(self, username: str, password: str) -> UserSessionState | None:
import pamela

try:
pamela.authenticate(username, password, service=self.service)
return UserSessionState(username, {})
except pamela.PAMError:
# Authentication failed.
return
else:
return UserSessionState(username, {})


class OIDCAuthenticator:
mode = Mode.external
class OIDCAuthenticator(ExternalAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -178,7 +171,7 @@ def authorization_endpoint(self) -> httpx.URL:
cast(str, self._config_from_oidc_url.get("authorization_endpoint"))
)

async def authenticate(self, request: Request) -> UserSessionState:
async def authenticate(self, request: Request) -> UserSessionState | None:
code = request.query_params["code"]
# A proxy in the middle may make the request into something like
# 'http://localhost:8000/...' so we fix the first part but keep
Expand Down Expand Up @@ -216,11 +209,13 @@ async def authenticate(self, request: Request) -> UserSessionState:
return UserSessionState(verified_body["sub"], {})


class KeyNotFoundError(Exception):
pass


async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri):
async def exchange_code(
token_uri: str,
auth_code: str,
client_id: str,
client_secret: str,
redirect_uri: str
) -> httpx.Response:
"""Method that talks to an IdP to exchange a code for an access_token and/or id_token
Args:
token_url ([type]): [description]
Expand All @@ -241,14 +236,13 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect
return response


class SAMLAuthenticator:
mode = Mode.external
class SAMLAuthenticator(ExternalAuthenticator):

def __init__(
self,
saml_settings, # See EXAMPLE_SAML_SETTINGS below.
attribute_name, # which SAML attribute to use as 'id' for Idenity
confirmation_message="",
attribute_name: str, # which SAML attribute to use as 'id' for Idenity
confirmation_message: str = "",
):
self.saml_settings = saml_settings
self.attribute_name = attribute_name
Expand All @@ -268,7 +262,7 @@ def __init__(
from onelogin.saml2.auth import OneLogin_Saml2_Auth

@router.get("/login")
async def saml_login(request: Request):
async def saml_login(request: httpx.Request) -> RedirectResponse:
req = await prepare_saml_from_fastapi_request(request)
auth = OneLogin_Saml2_Auth(req, self.saml_settings)
# saml_settings = auth.get_settings()
Expand All @@ -279,12 +273,11 @@ async def saml_login(request: Request):
# else:
# print("Error found on Metadata: %s" % (', '.join(errors)))
callback_url = auth.login()
response = RedirectResponse(url=callback_url)
return response
return RedirectResponse(url=callback_url)

self.include_routers = [router]

async def authenticate(self, request) -> UserSessionState:
async def authenticate(self, request: Request) -> UserSessionState | None:
if not modules_available("onelogin"):
raise ModuleNotFoundError(
"This SAMLAuthenticator requires the module 'oneline' to be installed."
Expand All @@ -310,7 +303,7 @@ async def authenticate(self, request) -> UserSessionState:
return None


async def prepare_saml_from_fastapi_request(request, debug=False):
async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, str]:
form_data = await request.form()
rv = {
"http_host": request.client.host,
Expand All @@ -336,7 +329,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False):
return rv


class LDAPAuthenticator:
class LDAPAuthenticator(InternalAuthenticator):
"""
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
The parameter ``use_tls`` was added for convenience of testing.
Expand Down Expand Up @@ -519,8 +512,6 @@ class LDAPAuthenticator:
id: user02
"""

mode = Mode.password

def __init__(
self,
server_address,
Expand Down
10 changes: 5 additions & 5 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
HTTP_500_INTERNAL_SERVER_ERROR,
)

from ..authenticators import Mode
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator

from ..config import construct_build_app_kwargs
from ..media_type_registration import (
compression_registry as default_compression_registry,
Expand Down Expand Up @@ -384,12 +385,11 @@ async def unhandled_exception_handler(
for spec in authentication["providers"]:
provider = spec["provider"]
authenticator = spec["authenticator"]
mode = authenticator.mode
if mode == Mode.password:
if isinstance(authenticator, InternalAuthenticator):
authentication_router.post(f"/provider/{provider}/token")(
build_handle_credentials_route(authenticator, provider)
)
elif mode == Mode.external:
elif isinstance(authenticator, ExternalAuthenticator):
# Client starts here to create a PendingSession.
authentication_router.post(f"/provider/{provider}/authorize")(
build_device_code_authorize_route(authenticator, provider)
Expand All @@ -414,7 +414,7 @@ async def unhandled_exception_handler(
# build_auth_code_route(authenticator, provider)
# )
else:
raise ValueError(f"unknown authentication mode {mode}")
raise ValueError(f"unknown authenticator type {type(authenticator)}")
for custom_router in getattr(authenticator, "include_routers", []):
authentication_router.include_router(
custom_router, prefix=f"/provider/{provider}"
Expand Down
10 changes: 3 additions & 7 deletions tiled/server/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from ..utils import SHARE_TILED_PATH, SpecialUsers
from . import schemas
from .core import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, json_or_msgpack
from .protocols import UsernamePasswordAuthenticator, UserSessionState
from .protocols import InternalAuthenticator, UserSessionState
from .settings import get_settings
from .utils import API_KEY_COOKIE_NAME, get_authenticators, get_base_url

Expand All @@ -86,11 +86,6 @@ def utcnow():
return datetime.now(timezone.utc).replace(microsecond=0)


class Mode(enum.Enum):
password = "password"
external = "external"


class Token(BaseModel):
access_token: str
token_type: str
Expand Down Expand Up @@ -711,7 +706,8 @@ async def route(


def build_handle_credentials_route(
authenticator: UsernamePasswordAuthenticator, provider
authenticator: InternalAuthenticator,
provider
):
"Register a handle_credentials route function for this Authenticator."

Expand Down
17 changes: 8 additions & 9 deletions tiled/server/protocols.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from dataclasses import dataclass
from typing import Protocol

from abc import ABC
from fastapi import Request


@dataclass
class UserSessionState:
"""Data transfer class to communicate custom session state infromation."""
"""Data transfer class to communicate custom session state information."""

user_name: str
state: dict = None


class UsernamePasswordAuthenticator(Protocol):
def authenticate(self, username: str, password: str) -> UserSessionState:
pass
class InternalAuthenticator(ABC):
def authenticate(self, username: str, password: str) -> UserSessionState | None:
raise NotImplemented


class Authenticator(Protocol):
def authenticate(self, request: Request) -> UserSessionState:
pass
class ExternalAuthenticator(ABC):
def authenticate(self, request: Request) -> UserSessionState | None:
raise NotImplemented
10 changes: 5 additions & 5 deletions tiled/server/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
HTTP_422_UNPROCESSABLE_ENTITY,
)

from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator

from .. import __version__
from ..structures.core import Spec, StructureFamily
from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri
from ..validation_registration import ValidationError
from . import schemas
from .authentication import Mode, get_authenticators, get_current_principal
from .authentication import get_authenticators, get_current_principal
from .core import (
DEFAULT_PAGE_SIZE,
DEPTH_LIMIT,
Expand Down Expand Up @@ -90,21 +92,19 @@ async def about(
}
provider_specs = []
for provider, authenticator in authenticators.items():
if authenticator.mode == Mode.password:
if isinstance(authenticator, InternalAuthenticator):
spec = {
"provider": provider,
"mode": authenticator.mode.value,
"links": {
"auth_endpoint": f"{base_url}/auth/provider/{provider}/token"
},
"confirmation_message": getattr(
authenticator, "confirmation_message", None
),
}
elif authenticator.mode == Mode.external:
elif isinstance(authenticator, ExternalAuthenticator):
spec = {
"provider": provider,
"mode": authenticator.mode.value,
"links": {
"auth_endpoint": f"{base_url}/auth/provider/{provider}/authorize"
},
Expand Down

0 comments on commit a04d11a

Please sign in to comment.