From ff14dd1e8f7d449f75cb6da084384847c1326ba9 Mon Sep 17 00:00:00 2001 From: Joseph Ware <53935796+DiamondJoseph@users.noreply.github.com> Date: Fri, 24 Jan 2025 22:01:15 +0000 Subject: [PATCH] Move Schemas for describing server configuration (#862) * Move Schemas for describing server configuration * Update CHANGELOG --------- Co-authored-by: Dan Allan --- CHANGELOG.md | 7 +++++ tiled/_tests/test_pickle.py | 2 +- tiled/client/base.py | 8 ++++-- tiled/client/constructors.py | 6 ++-- tiled/client/context.py | 56 +++++++++++++++++++++--------------- tiled/schemas.py | 35 ++++++++++++++++++++++ tiled/server/router.py | 6 ++-- tiled/server/schemas.py | 37 ------------------------ 8 files changed, 88 insertions(+), 69 deletions(-) create mode 100644 tiled/schemas.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 223f09a6d..84fb89c53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,13 @@ Write the date in place of the "Unreleased" in the case a new version is release # Changelog +## Unreleased + +### Maintenance + +- Make depedencies shared by client and server into core dependencies. +- Use schemas for describing server configuration on the client side too. + ## v0.1.0-b16 (2024-01-23) ### Maintenance diff --git a/tiled/_tests/test_pickle.py b/tiled/_tests/test_pickle.py index aee17896a..3e4bb4215 100644 --- a/tiled/_tests/test_pickle.py +++ b/tiled/_tests/test_pickle.py @@ -32,7 +32,7 @@ def test_pickle_clients(structure_clients, tmpdir): raise pytest.skip(f"Could not connect to {API_URL}") cache = Cache(tmpdir / "http_response_cache.db") with Context(API_URL, cache=cache) as context: - if parse(context.server_info["library_version"]) < parse(MIN_VERSION): + if parse(context.server_info.library_version) < parse(MIN_VERSION): raise pytest.skip( f"Server at {API_URL} is running too old a version to test against." ) diff --git a/tiled/client/base.py b/tiled/client/base.py index 6c9ee3d82..bc9aedb5e 100644 --- a/tiled/client/base.py +++ b/tiled/client/base.py @@ -10,6 +10,8 @@ import orjson from httpx import URL +from tiled.client.context import Context + from ..structures.core import STRUCTURE_TYPES, Spec, StructureFamily from ..structures.data_source import DataSource from ..utils import UNCHANGED, DictView, ListView, patch_mimetypes, safe_json_dump @@ -112,7 +114,7 @@ class BaseClient: def __init__( self, - context, + context: Context, *, item, structure_clients, @@ -415,9 +417,9 @@ def formats(self): "List formats that the server can export this data as." formats = set() for spec in self.item["attributes"]["specs"]: - formats.update(self.context.server_info["formats"].get(spec["name"], [])) + formats.update(self.context.server_info.formats.get(spec["name"], [])) formats.update( - self.context.server_info["formats"][ + self.context.server_info.formats[ self.item["attributes"]["structure_family"] ] ) diff --git a/tiled/client/constructors.py b/tiled/client/constructors.py index 59b881eaa..1b8c8fa1c 100644 --- a/tiled/client/constructors.py +++ b/tiled/client/constructors.py @@ -96,7 +96,7 @@ def from_uri( def from_context( - context, + context: Context, structure_clients="numpy", node_path_parts=None, include_data_sources=False, @@ -126,8 +126,8 @@ def from_context( # 2. If there are cached valid credentials for this server, use them. # 3. If not, and the server requires authentication, prompt for authentication. if context.api_key is None: - auth_is_required = context.server_info["authentication"]["required"] - has_providers = len(context.server_info["authentication"]["providers"]) > 0 + auth_is_required = context.server_info.authentication.required + has_providers = len(context.server_info.authentication.providers) > 0 if auth_is_required and not has_providers: raise RuntimeError( """This server requires API key authentication. diff --git a/tiled/client/context.py b/tiled/client/context.py index a1be82d08..46d401dff 100644 --- a/tiled/client/context.py +++ b/tiled/client/context.py @@ -6,10 +6,14 @@ import urllib.parse import warnings from pathlib import Path +from typing import List from urllib.parse import parse_qs, urlparse import httpx import platformdirs +from pydantic import TypeAdapter + +from tiled.schemas import About, AboutAuthenticationProvider from .._version import __version__ as tiled_version from ..utils import UNSET, DictView, parse_time_string @@ -42,11 +46,13 @@ def raise_if_cannot_prompt(): ) -def identity_provider_input(providers, provider=None): +def identity_provider_input( + providers: List[AboutAuthenticationProvider], +) -> AboutAuthenticationProvider: while True: - print("Authenticaiton providers:") + print("Authentication providers:") for i, spec in enumerate(providers, start=1): - print(f"{i} - {spec['provider']}") + print(f"{i} - {spec.provider}") raw_choice = input( "Choose an authentication provider (or press Enter to cancel): " ) @@ -81,18 +87,18 @@ class PasswordRejected(RuntimeError): pass -def prompt_for_credentials(http_client, providers): +def prompt_for_credentials(http_client, providers: List[AboutAuthenticationProvider]): """ Prompt for credentials or third-party login at an interactive terminal. """ if len(providers) == 1: # There is only one choice, so no need to prompt the user. - (spec,) = providers + spec = providers[0] else: spec = identity_provider_input(providers) - auth_endpoint = spec["links"]["auth_endpoint"] - provider = spec["provider"] - mode = spec["mode"] + auth_endpoint = spec.links["auth_endpoint"] + provider = spec.provider + mode = spec.mode if mode == "password": # Prompt for username, password at terminal. username = username_input() @@ -124,7 +130,7 @@ def prompt_for_credentials(http_client, providers): tokens = device_code_grant(http_client, auth_endpoint) else: raise ValueError(f"Server has unknown authentication mechanism {mode!r}") - confirmation_message = spec.get("confirmation_message") + confirmation_message = spec.confirmation_message if confirmation_message: username = tokens["identity"]["id"] print(confirmation_message.format(id=username)) @@ -252,7 +258,7 @@ def __init__( # (2) Let the server set the CSRF cookie. # No authentication has been set up yet, so these requests will be unauthenticated. # https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#double-submit-cookie - self.server_info = handle_error( + server_info = handle_error( self.http_client.get( self.api_uri, headers={ @@ -261,6 +267,7 @@ def __init__( }, ) ).json() + self.server_info: About = TypeAdapter(About).validate_python(server_info) self.api_key = api_key # property setter sets Authorization header self.admin = Admin(self) # accessor for admin-related requests @@ -270,7 +277,7 @@ def __repr__(self): auth_info.append("(unauthenticated)") else: auth_info.append("authenticated") - if self.server_info["authentication"].get("links"): + if self.server_info.authentication.links: whoami = self.whoami() auth_info.append("as") if whoami["type"] == "service": @@ -435,7 +442,7 @@ def from_app( raise_server_exceptions=raise_server_exceptions, ) if api_key is UNSET: - if not context.server_info["authentication"]["providers"]: + if not context.server_info.authentication.providers: # This is a single-user server. # Extract the API key from the app and set it. from ..server.settings import get_settings @@ -487,7 +494,7 @@ def which_api_key(self): raise RuntimeError("Not API key is configured for the client.") return handle_error( self.http_client.get( - self.server_info["authentication"]["links"]["apikey"], + self.server_info.authentication.links.apikey, headers={"Accept": MSGPACK_MIME_TYPE}, ) ).json() @@ -516,7 +523,7 @@ def create_api_key(self, scopes=None, expires_in=None, note=None): expires_in = parse_time_string(expires_in) return handle_error( self.http_client.post( - self.server_info["authentication"]["links"]["apikey"], + self.server_info.authentication.links.apikey, headers={"Accept": MSGPACK_MIME_TYPE}, json={"scopes": scopes, "expires_in": expires_in, "note": note}, ) @@ -536,7 +543,7 @@ def revoke_api_key(self, first_eight): Identify the API key to be deleted by passing its first 8 characters. (Any additional characters passed will be truncated.) """ - url_path = self.server_info["authentication"]["links"]["apikey"] + url_path = self.server_info.authentication.links.apikey handle_error( self.http_client.delete( url_path, @@ -605,8 +612,11 @@ def authenticate( Next time, try to automatically authenticate using this session. """ # Obtain tokens via OAuth2 unless the caller has passed them. - providers = self.server_info["authentication"]["providers"] - tokens = prompt_for_credentials(self.http_client, providers) + providers = self.server_info.authentication.providers + tokens = prompt_for_credentials( + self.http_client, + providers, + ) self.configure_auth(tokens, remember_me=remember_me) # These two methods are aliased for convenience. @@ -628,7 +638,7 @@ def configure_auth(self, tokens, remember_me=True): ) # Configure an httpx.Auth instance on the http_client, which # will manage refreshing the tokens as needed. - refresh_url = self.server_info["authentication"]["links"]["refresh_session"] + refresh_url = self.server_info.authentication.links.refresh_session csrf_token = self.http_client.cookies["tiled_csrf"] if remember_me: token_directory = self._token_directory() @@ -671,7 +681,7 @@ def use_cached_tokens(self): success : bool Indicating whether valid cached tokens were found """ - refresh_url = self.server_info["authentication"]["links"]["refresh_session"] + refresh_url = self.server_info.authentication.links.refresh_session csrf_token = self.http_client.cookies["tiled_csrf"] # Try automatically authenticating using cached tokens, if any. @@ -731,7 +741,7 @@ def whoami(self): "Return information about the currently-authenticated user or service." return handle_error( self.http_client.get( - self.server_info["authentication"]["links"]["whoami"], + self.server_info.authentication.links.whoami, headers={"Accept": MSGPACK_MIME_TYPE}, ) ).json() @@ -774,7 +784,7 @@ def revoke_session(self, session_id): """ handle_error( self.http_client.delete( - self.server_info["authentication"]["links"]["revoke_session"].format( + self.server_info.authentication.links.revoke_session.format( session_id=session_id ), headers={"x-csrf": self.http_client.cookies["tiled_csrf"]}, @@ -785,9 +795,9 @@ def revoke_session(self, session_id): class Admin: "Accessor for requests that require administrative privileges." - def __init__(self, context): + def __init__(self, context: Context): self.context = context - self.base_url = context.server_info["links"]["self"] + self.base_url = context.server_info.links["self"] def list_principals(self, offset=0, limit=100): "List Principals (users and services) in the authenticaiton database." diff --git a/tiled/schemas.py b/tiled/schemas.py new file mode 100644 index 000000000..256b1100e --- /dev/null +++ b/tiled/schemas.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel + + +class AboutAuthenticationProvider(BaseModel): + provider: str + mode: Literal["password", "external"] + links: Dict[str, str] + confirmation_message: Optional[str] = None + + +class AboutAuthenticationLinks(BaseModel): + whoami: str + apikey: str + refresh_session: str + revoke_session: str + logout: str + + +class AboutAuthentication(BaseModel): + required: bool + providers: List[AboutAuthenticationProvider] + links: Optional[AboutAuthenticationLinks] = None + + +class About(BaseModel): + api_version: int + library_version: str + formats: Dict[str, List[str]] + aliases: Dict[str, Dict[str, List[str]]] + queries: List[str] + authentication: AboutAuthentication + links: Dict[str, str] + meta: Dict[str, Any] diff --git a/tiled/server/router.py b/tiled/server/router.py index ac8e9799d..97e6dd84c 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -26,6 +26,8 @@ HTTP_422_UNPROCESSABLE_ENTITY, ) +from tiled.schemas import About + from .. import __version__ from ..structures.core import Spec, StructureFamily from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri @@ -67,7 +69,7 @@ router = APIRouter() -@router.get("/", response_model=schemas.About) +@router.get("/", response_model=About) async def about( request: Request, settings: BaseSettings = Depends(get_settings), @@ -130,7 +132,7 @@ async def about( return json_or_msgpack( request, - schemas.About( + About( library_version=__version__, api_version=0, formats={ diff --git a/tiled/server/schemas.py b/tiled/server/schemas.py index fa7f039e9..684db5b8c 100644 --- a/tiled/server/schemas.py +++ b/tiled/server/schemas.py @@ -272,43 +272,6 @@ class DeviceCode(pydantic.BaseModel): grant_type: str -class AuthenticationMode(str, enum.Enum): - password = "password" - external = "external" - - -class AboutAuthenticationProvider(pydantic.BaseModel): - provider: str - mode: AuthenticationMode - links: Dict[str, str] - confirmation_message: Optional[str] = None - - -class AboutAuthenticationLinks(pydantic.BaseModel): - whoami: str - apikey: str - refresh_session: str - revoke_session: str - logout: str - - -class AboutAuthentication(pydantic.BaseModel): - required: bool - providers: List[AboutAuthenticationProvider] - links: Optional[AboutAuthenticationLinks] = None - - -class About(pydantic.BaseModel): - api_version: int - library_version: str - formats: Dict[str, List[str]] - aliases: Dict[str, Dict[str, List[str]]] - queries: List[str] - authentication: AboutAuthentication - links: Dict[str, str] - meta: Dict - - class PrincipalType(str, enum.Enum): user = "user" service = "service"