diff --git a/tiled/_tests/test_pickle.py b/tiled/_tests/test_pickle.py index aee17896a..3e4bb4215 100644 --- a/tiled/_tests/test_pickle.py +++ b/tiled/_tests/test_pickle.py @@ -32,7 +32,7 @@ def test_pickle_clients(structure_clients, tmpdir): raise pytest.skip(f"Could not connect to {API_URL}") cache = Cache(tmpdir / "http_response_cache.db") with Context(API_URL, cache=cache) as context: - if parse(context.server_info["library_version"]) < parse(MIN_VERSION): + if parse(context.server_info.library_version) < parse(MIN_VERSION): raise pytest.skip( f"Server at {API_URL} is running too old a version to test against." ) diff --git a/tiled/client/base.py b/tiled/client/base.py index 6c9ee3d82..bc9aedb5e 100644 --- a/tiled/client/base.py +++ b/tiled/client/base.py @@ -10,6 +10,8 @@ import orjson from httpx import URL +from tiled.client.context import Context + from ..structures.core import STRUCTURE_TYPES, Spec, StructureFamily from ..structures.data_source import DataSource from ..utils import UNCHANGED, DictView, ListView, patch_mimetypes, safe_json_dump @@ -112,7 +114,7 @@ class BaseClient: def __init__( self, - context, + context: Context, *, item, structure_clients, @@ -415,9 +417,9 @@ def formats(self): "List formats that the server can export this data as." formats = set() for spec in self.item["attributes"]["specs"]: - formats.update(self.context.server_info["formats"].get(spec["name"], [])) + formats.update(self.context.server_info.formats.get(spec["name"], [])) formats.update( - self.context.server_info["formats"][ + self.context.server_info.formats[ self.item["attributes"]["structure_family"] ] ) diff --git a/tiled/client/constructors.py b/tiled/client/constructors.py index 59b881eaa..1b8c8fa1c 100644 --- a/tiled/client/constructors.py +++ b/tiled/client/constructors.py @@ -96,7 +96,7 @@ def from_uri( def from_context( - context, + context: Context, structure_clients="numpy", node_path_parts=None, include_data_sources=False, @@ -126,8 +126,8 @@ def from_context( # 2. If there are cached valid credentials for this server, use them. # 3. If not, and the server requires authentication, prompt for authentication. if context.api_key is None: - auth_is_required = context.server_info["authentication"]["required"] - has_providers = len(context.server_info["authentication"]["providers"]) > 0 + auth_is_required = context.server_info.authentication.required + has_providers = len(context.server_info.authentication.providers) > 0 if auth_is_required and not has_providers: raise RuntimeError( """This server requires API key authentication. diff --git a/tiled/client/context.py b/tiled/client/context.py index a1be82d08..63f7a35ab 100644 --- a/tiled/client/context.py +++ b/tiled/client/context.py @@ -10,6 +10,9 @@ import httpx import platformdirs +from pydantic import TypeAdapter + +from tiled.schemas import About from .._version import __version__ as tiled_version from ..utils import UNSET, DictView, parse_time_string @@ -124,7 +127,7 @@ def prompt_for_credentials(http_client, providers): tokens = device_code_grant(http_client, auth_endpoint) else: raise ValueError(f"Server has unknown authentication mechanism {mode!r}") - confirmation_message = spec.get("confirmation_message") + confirmation_message = spec.confirmation_message if confirmation_message: username = tokens["identity"]["id"] print(confirmation_message.format(id=username)) @@ -252,7 +255,7 @@ def __init__( # (2) Let the server set the CSRF cookie. # No authentication has been set up yet, so these requests will be unauthenticated. # https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#double-submit-cookie - self.server_info = handle_error( + server_info = handle_error( self.http_client.get( self.api_uri, headers={ @@ -261,6 +264,7 @@ def __init__( }, ) ).json() + self.server_info: About = TypeAdapter(About).validate_python(server_info) self.api_key = api_key # property setter sets Authorization header self.admin = Admin(self) # accessor for admin-related requests @@ -270,7 +274,7 @@ def __repr__(self): auth_info.append("(unauthenticated)") else: auth_info.append("authenticated") - if self.server_info["authentication"].get("links"): + if self.server_info.authentication.links: whoami = self.whoami() auth_info.append("as") if whoami["type"] == "service": @@ -435,7 +439,7 @@ def from_app( raise_server_exceptions=raise_server_exceptions, ) if api_key is UNSET: - if not context.server_info["authentication"]["providers"]: + if not context.server_info.authentication.providers: # This is a single-user server. # Extract the API key from the app and set it. from ..server.settings import get_settings @@ -487,7 +491,7 @@ def which_api_key(self): raise RuntimeError("Not API key is configured for the client.") return handle_error( self.http_client.get( - self.server_info["authentication"]["links"]["apikey"], + self.server_info.authentication.links.apikey, headers={"Accept": MSGPACK_MIME_TYPE}, ) ).json() @@ -516,7 +520,7 @@ def create_api_key(self, scopes=None, expires_in=None, note=None): expires_in = parse_time_string(expires_in) return handle_error( self.http_client.post( - self.server_info["authentication"]["links"]["apikey"], + self.server_info.authentication.links.apikey, headers={"Accept": MSGPACK_MIME_TYPE}, json={"scopes": scopes, "expires_in": expires_in, "note": note}, ) @@ -536,7 +540,7 @@ def revoke_api_key(self, first_eight): Identify the API key to be deleted by passing its first 8 characters. (Any additional characters passed will be truncated.) """ - url_path = self.server_info["authentication"]["links"]["apikey"] + url_path = self.server_info.authentication.links.apikey handle_error( self.http_client.delete( url_path, @@ -605,8 +609,11 @@ def authenticate( Next time, try to automatically authenticate using this session. """ # Obtain tokens via OAuth2 unless the caller has passed them. - providers = self.server_info["authentication"]["providers"] - tokens = prompt_for_credentials(self.http_client, providers) + providers = self.server_info.authentication.providers + tokens = prompt_for_credentials( + self.http_client, + providers, + ) self.configure_auth(tokens, remember_me=remember_me) # These two methods are aliased for convenience. @@ -628,7 +635,7 @@ def configure_auth(self, tokens, remember_me=True): ) # Configure an httpx.Auth instance on the http_client, which # will manage refreshing the tokens as needed. - refresh_url = self.server_info["authentication"]["links"]["refresh_session"] + refresh_url = self.server_info.authentication.links.refresh_session csrf_token = self.http_client.cookies["tiled_csrf"] if remember_me: token_directory = self._token_directory() @@ -671,7 +678,7 @@ def use_cached_tokens(self): success : bool Indicating whether valid cached tokens were found """ - refresh_url = self.server_info["authentication"]["links"]["refresh_session"] + refresh_url = self.server_info.authentication.links.refresh_session csrf_token = self.http_client.cookies["tiled_csrf"] # Try automatically authenticating using cached tokens, if any. @@ -731,7 +738,7 @@ def whoami(self): "Return information about the currently-authenticated user or service." return handle_error( self.http_client.get( - self.server_info["authentication"]["links"]["whoami"], + self.server_info.authentication.links.whoami, headers={"Accept": MSGPACK_MIME_TYPE}, ) ).json() @@ -774,7 +781,7 @@ def revoke_session(self, session_id): """ handle_error( self.http_client.delete( - self.server_info["authentication"]["links"]["revoke_session"].format( + self.server_info.authentication.links.revoke_session.format( session_id=session_id ), headers={"x-csrf": self.http_client.cookies["tiled_csrf"]}, @@ -785,9 +792,9 @@ def revoke_session(self, session_id): class Admin: "Accessor for requests that require administrative privileges." - def __init__(self, context): + def __init__(self, context: Context): self.context = context - self.base_url = context.server_info["links"]["self"] + self.base_url = context.server_info.links["self"] def list_principals(self, offset=0, limit=100): "List Principals (users and services) in the authenticaiton database." diff --git a/tiled/schemas.py b/tiled/schemas.py new file mode 100644 index 000000000..256b1100e --- /dev/null +++ b/tiled/schemas.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel + + +class AboutAuthenticationProvider(BaseModel): + provider: str + mode: Literal["password", "external"] + links: Dict[str, str] + confirmation_message: Optional[str] = None + + +class AboutAuthenticationLinks(BaseModel): + whoami: str + apikey: str + refresh_session: str + revoke_session: str + logout: str + + +class AboutAuthentication(BaseModel): + required: bool + providers: List[AboutAuthenticationProvider] + links: Optional[AboutAuthenticationLinks] = None + + +class About(BaseModel): + api_version: int + library_version: str + formats: Dict[str, List[str]] + aliases: Dict[str, Dict[str, List[str]]] + queries: List[str] + authentication: AboutAuthentication + links: Dict[str, str] + meta: Dict[str, Any] diff --git a/tiled/server/router.py b/tiled/server/router.py index ac8e9799d..97e6dd84c 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -26,6 +26,8 @@ HTTP_422_UNPROCESSABLE_ENTITY, ) +from tiled.schemas import About + from .. import __version__ from ..structures.core import Spec, StructureFamily from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri @@ -67,7 +69,7 @@ router = APIRouter() -@router.get("/", response_model=schemas.About) +@router.get("/", response_model=About) async def about( request: Request, settings: BaseSettings = Depends(get_settings), @@ -130,7 +132,7 @@ async def about( return json_or_msgpack( request, - schemas.About( + About( library_version=__version__, api_version=0, formats={ diff --git a/tiled/server/schemas.py b/tiled/server/schemas.py index fa7f039e9..684db5b8c 100644 --- a/tiled/server/schemas.py +++ b/tiled/server/schemas.py @@ -272,43 +272,6 @@ class DeviceCode(pydantic.BaseModel): grant_type: str -class AuthenticationMode(str, enum.Enum): - password = "password" - external = "external" - - -class AboutAuthenticationProvider(pydantic.BaseModel): - provider: str - mode: AuthenticationMode - links: Dict[str, str] - confirmation_message: Optional[str] = None - - -class AboutAuthenticationLinks(pydantic.BaseModel): - whoami: str - apikey: str - refresh_session: str - revoke_session: str - logout: str - - -class AboutAuthentication(pydantic.BaseModel): - required: bool - providers: List[AboutAuthenticationProvider] - links: Optional[AboutAuthenticationLinks] = None - - -class About(pydantic.BaseModel): - api_version: int - library_version: str - formats: Dict[str, List[str]] - aliases: Dict[str, Dict[str, List[str]]] - queries: List[str] - authentication: AboutAuthentication - links: Dict[str, str] - meta: Dict - - class PrincipalType(str, enum.Enum): user = "user" service = "service"