Skip to content

Commit

Permalink
Move Schemas for describing server configuration to common, use type …
Browse files Browse the repository at this point in the history
…in client
  • Loading branch information
DiamondJoseph committed Jan 24, 2025
1 parent 1de9518 commit d990f1b
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 66 deletions.
2 changes: 1 addition & 1 deletion tiled/_tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_pickle_clients(structure_clients, tmpdir):
raise pytest.skip(f"Could not connect to {API_URL}")
cache = Cache(tmpdir / "http_response_cache.db")
with Context(API_URL, cache=cache) as context:
if parse(context.server_info["library_version"]) < parse(MIN_VERSION):
if parse(context.server_info.library_version) < parse(MIN_VERSION):
raise pytest.skip(
f"Server at {API_URL} is running too old a version to test against."
)
Expand Down
8 changes: 5 additions & 3 deletions tiled/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import orjson
from httpx import URL

from tiled.client.context import Context

from ..structures.core import STRUCTURE_TYPES, Spec, StructureFamily
from ..structures.data_source import DataSource
from ..utils import UNCHANGED, DictView, ListView, patch_mimetypes, safe_json_dump
Expand Down Expand Up @@ -112,7 +114,7 @@ class BaseClient:

def __init__(
self,
context,
context: Context,
*,
item,
structure_clients,
Expand Down Expand Up @@ -415,9 +417,9 @@ def formats(self):
"List formats that the server can export this data as."
formats = set()
for spec in self.item["attributes"]["specs"]:
formats.update(self.context.server_info["formats"].get(spec["name"], []))
formats.update(self.context.server_info.formats.get(spec["name"], []))
formats.update(
self.context.server_info["formats"][
self.context.server_info.formats[
self.item["attributes"]["structure_family"]
]
)
Expand Down
6 changes: 3 additions & 3 deletions tiled/client/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def from_uri(


def from_context(
context,
context: Context,
structure_clients="numpy",
node_path_parts=None,
include_data_sources=False,
Expand Down Expand Up @@ -126,8 +126,8 @@ def from_context(
# 2. If there are cached valid credentials for this server, use them.
# 3. If not, and the server requires authentication, prompt for authentication.
if context.api_key is None:
auth_is_required = context.server_info["authentication"]["required"]
has_providers = len(context.server_info["authentication"]["providers"]) > 0
auth_is_required = context.server_info.authentication.required
has_providers = len(context.server_info.authentication.providers) > 0
if auth_is_required and not has_providers:
raise RuntimeError(
"""This server requires API key authentication.
Expand Down
35 changes: 20 additions & 15 deletions tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

import httpx
import platformdirs
from pydantic import TypeAdapter

from tiled.server.schemas import AboutAuthenticationProvider
from tiled.schemas import About, AboutAuthenticationProvider

from .._version import __version__ as tiled_version
from ..utils import UNSET, DictView, parse_time_string
Expand Down Expand Up @@ -257,7 +258,7 @@ def __init__(
# (2) Let the server set the CSRF cookie.
# No authentication has been set up yet, so these requests will be unauthenticated.
# https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#double-submit-cookie
self.server_info = handle_error(
server_info = handle_error(
self.http_client.get(
self.api_uri,
headers={
Expand All @@ -266,6 +267,7 @@ def __init__(
},
)
).json()
self.server_info: About = TypeAdapter(About).validate_python(server_info)
self.api_key = api_key # property setter sets Authorization header
self.admin = Admin(self) # accessor for admin-related requests

Expand All @@ -275,7 +277,7 @@ def __repr__(self):
auth_info.append("(unauthenticated)")
else:
auth_info.append("authenticated")
if self.server_info["authentication"].get("links"):
if self.server_info.authentication.links:
whoami = self.whoami()
auth_info.append("as")
if whoami["type"] == "service":
Expand Down Expand Up @@ -440,7 +442,7 @@ def from_app(
raise_server_exceptions=raise_server_exceptions,
)
if api_key is UNSET:
if not context.server_info["authentication"]["providers"]:
if not context.server_info.authentication.providers:
# This is a single-user server.
# Extract the API key from the app and set it.
from ..server.settings import get_settings
Expand Down Expand Up @@ -492,7 +494,7 @@ def which_api_key(self):
raise RuntimeError("Not API key is configured for the client.")
return handle_error(
self.http_client.get(
self.server_info["authentication"]["links"]["apikey"],
self.server_info.authentication.links.apikey,
headers={"Accept": MSGPACK_MIME_TYPE},
)
).json()
Expand Down Expand Up @@ -521,7 +523,7 @@ def create_api_key(self, scopes=None, expires_in=None, note=None):
expires_in = parse_time_string(expires_in)
return handle_error(
self.http_client.post(
self.server_info["authentication"]["links"]["apikey"],
self.server_info.authentication.links.apikey,
headers={"Accept": MSGPACK_MIME_TYPE},
json={"scopes": scopes, "expires_in": expires_in, "note": note},
)
Expand All @@ -541,7 +543,7 @@ def revoke_api_key(self, first_eight):
Identify the API key to be deleted by passing its first 8 characters.
(Any additional characters passed will be truncated.)
"""
url_path = self.server_info["authentication"]["links"]["apikey"]
url_path = self.server_info.authentication.links.apikey
handle_error(
self.http_client.delete(
url_path,
Expand Down Expand Up @@ -610,8 +612,11 @@ def authenticate(
Next time, try to automatically authenticate using this session.
"""
# Obtain tokens via OAuth2 unless the caller has passed them.
providers = self.server_info["authentication"]["providers"]
tokens = prompt_for_credentials(self.http_client, providers)
providers = self.server_info.authentication.providers
tokens = prompt_for_credentials(
self.http_client,
providers,
)
self.configure_auth(tokens, remember_me=remember_me)

# These two methods are aliased for convenience.
Expand All @@ -633,7 +638,7 @@ def configure_auth(self, tokens, remember_me=True):
)
# Configure an httpx.Auth instance on the http_client, which
# will manage refreshing the tokens as needed.
refresh_url = self.server_info["authentication"]["links"]["refresh_session"]
refresh_url = self.server_info.authentication.links.refresh_session
csrf_token = self.http_client.cookies["tiled_csrf"]
if remember_me:
token_directory = self._token_directory()
Expand Down Expand Up @@ -676,7 +681,7 @@ def use_cached_tokens(self):
success : bool
Indicating whether valid cached tokens were found
"""
refresh_url = self.server_info["authentication"]["links"]["refresh_session"]
refresh_url = self.server_info.authentication.links.refresh_session
csrf_token = self.http_client.cookies["tiled_csrf"]

# Try automatically authenticating using cached tokens, if any.
Expand Down Expand Up @@ -736,7 +741,7 @@ def whoami(self):
"Return information about the currently-authenticated user or service."
return handle_error(
self.http_client.get(
self.server_info["authentication"]["links"]["whoami"],
self.server_info.authentication.links.whoami,
headers={"Accept": MSGPACK_MIME_TYPE},
)
).json()
Expand Down Expand Up @@ -779,7 +784,7 @@ def revoke_session(self, session_id):
"""
handle_error(
self.http_client.delete(
self.server_info["authentication"]["links"]["revoke_session"].format(
self.server_info.authentication.links.revoke_session.format(
session_id=session_id
),
headers={"x-csrf": self.http_client.cookies["tiled_csrf"]},
Expand All @@ -790,9 +795,9 @@ def revoke_session(self, session_id):
class Admin:
"Accessor for requests that require administrative privileges."

def __init__(self, context):
def __init__(self, context: Context):
self.context = context
self.base_url = context.server_info["links"]["self"]
self.base_url = context.server_info.links["self"]

def list_principals(self, offset=0, limit=100):
"List Principals (users and services) in the authenticaiton database."
Expand Down
34 changes: 34 additions & 0 deletions tiled/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel


class AboutAuthenticationProvider(BaseModel):
provider: str
mode: Literal["internal", "external"]
links: Dict[str, str]
confirmation_message: Optional[str] = None


class AboutAuthenticationLinks(BaseModel):
whoami: str
apikey: str
refresh_session: str
revoke_session: str
logout: str


class AboutAuthentication(BaseModel):
required: bool
providers: List[AboutAuthenticationProvider]
links: Optional[AboutAuthenticationLinks] = None


class About(BaseModel):
api_version: int
library_version: str
formats: Dict[str, List[str]]
aliases: Dict[str, Dict[str, List[str]]]
queries: List[str]
authentication: AboutAuthentication
links: Dict[str, str]
meta: Dict[str, Any]
3 changes: 2 additions & 1 deletion tiled/server/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
HTTP_422_UNPROCESSABLE_ENTITY,
)

from tiled.schemas import About
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator

from .. import __version__
Expand Down Expand Up @@ -132,7 +133,7 @@ async def about(

return json_or_msgpack(
request,
schemas.About(
About(
library_version=__version__,
api_version=0,
formats={
Expand Down
44 changes: 1 addition & 43 deletions tiled/server/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,7 @@
import enum
import uuid
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Literal,
Optional,
TypeVar,
Union,
)
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union

import pydantic.generics
from pydantic import ConfigDict, Field, StringConstraints
Expand Down Expand Up @@ -282,38 +272,6 @@ class DeviceCode(pydantic.BaseModel):
grant_type: str


class AboutAuthenticationProvider(pydantic.BaseModel):
provider: str
mode: Literal["internal", "external"]
links: Dict[str, str]
confirmation_message: Optional[str] = None


class AboutAuthenticationLinks(pydantic.BaseModel):
whoami: str
apikey: str
refresh_session: str
revoke_session: str
logout: str


class AboutAuthentication(pydantic.BaseModel):
required: bool
providers: List[AboutAuthenticationProvider]
links: Optional[AboutAuthenticationLinks] = None


class About(pydantic.BaseModel):
api_version: int
library_version: str
formats: Dict[str, List[str]]
aliases: Dict[str, Dict[str, List[str]]]
queries: List[str]
authentication: AboutAuthentication
links: Dict[str, str]
meta: Dict


class PrincipalType(str, enum.Enum):
user = "user"
service = "service"
Expand Down

0 comments on commit d990f1b

Please sign in to comment.