Skip to content

Commit

Permalink
Move Schemas for describing server configuration (#862)
Browse files Browse the repository at this point in the history
* Move Schemas for describing server configuration

* Update CHANGELOG

---------

Co-authored-by: Dan Allan <[email protected]>
  • Loading branch information
DiamondJoseph and danielballan authored Jan 24, 2025
1 parent bed06d6 commit ff14dd1
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 69 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ Write the date in place of the "Unreleased" in the case a new version is release

# Changelog

## Unreleased

### Maintenance

- Make depedencies shared by client and server into core dependencies.
- Use schemas for describing server configuration on the client side too.

## v0.1.0-b16 (2024-01-23)

### Maintenance
Expand Down
2 changes: 1 addition & 1 deletion tiled/_tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_pickle_clients(structure_clients, tmpdir):
raise pytest.skip(f"Could not connect to {API_URL}")
cache = Cache(tmpdir / "http_response_cache.db")
with Context(API_URL, cache=cache) as context:
if parse(context.server_info["library_version"]) < parse(MIN_VERSION):
if parse(context.server_info.library_version) < parse(MIN_VERSION):
raise pytest.skip(
f"Server at {API_URL} is running too old a version to test against."
)
Expand Down
8 changes: 5 additions & 3 deletions tiled/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import orjson
from httpx import URL

from tiled.client.context import Context

from ..structures.core import STRUCTURE_TYPES, Spec, StructureFamily
from ..structures.data_source import DataSource
from ..utils import UNCHANGED, DictView, ListView, patch_mimetypes, safe_json_dump
Expand Down Expand Up @@ -112,7 +114,7 @@ class BaseClient:

def __init__(
self,
context,
context: Context,
*,
item,
structure_clients,
Expand Down Expand Up @@ -415,9 +417,9 @@ def formats(self):
"List formats that the server can export this data as."
formats = set()
for spec in self.item["attributes"]["specs"]:
formats.update(self.context.server_info["formats"].get(spec["name"], []))
formats.update(self.context.server_info.formats.get(spec["name"], []))
formats.update(
self.context.server_info["formats"][
self.context.server_info.formats[
self.item["attributes"]["structure_family"]
]
)
Expand Down
6 changes: 3 additions & 3 deletions tiled/client/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def from_uri(


def from_context(
context,
context: Context,
structure_clients="numpy",
node_path_parts=None,
include_data_sources=False,
Expand Down Expand Up @@ -126,8 +126,8 @@ def from_context(
# 2. If there are cached valid credentials for this server, use them.
# 3. If not, and the server requires authentication, prompt for authentication.
if context.api_key is None:
auth_is_required = context.server_info["authentication"]["required"]
has_providers = len(context.server_info["authentication"]["providers"]) > 0
auth_is_required = context.server_info.authentication.required
has_providers = len(context.server_info.authentication.providers) > 0
if auth_is_required and not has_providers:
raise RuntimeError(
"""This server requires API key authentication.
Expand Down
56 changes: 33 additions & 23 deletions tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import urllib.parse
import warnings
from pathlib import Path
from typing import List
from urllib.parse import parse_qs, urlparse

import httpx
import platformdirs
from pydantic import TypeAdapter

from tiled.schemas import About, AboutAuthenticationProvider

from .._version import __version__ as tiled_version
from ..utils import UNSET, DictView, parse_time_string
Expand Down Expand Up @@ -42,11 +46,13 @@ def raise_if_cannot_prompt():
)


def identity_provider_input(providers, provider=None):
def identity_provider_input(
providers: List[AboutAuthenticationProvider],
) -> AboutAuthenticationProvider:
while True:
print("Authenticaiton providers:")
print("Authentication providers:")
for i, spec in enumerate(providers, start=1):
print(f"{i} - {spec['provider']}")
print(f"{i} - {spec.provider}")
raw_choice = input(
"Choose an authentication provider (or press Enter to cancel): "
)
Expand Down Expand Up @@ -81,18 +87,18 @@ class PasswordRejected(RuntimeError):
pass


def prompt_for_credentials(http_client, providers):
def prompt_for_credentials(http_client, providers: List[AboutAuthenticationProvider]):
"""
Prompt for credentials or third-party login at an interactive terminal.
"""
if len(providers) == 1:
# There is only one choice, so no need to prompt the user.
(spec,) = providers
spec = providers[0]
else:
spec = identity_provider_input(providers)
auth_endpoint = spec["links"]["auth_endpoint"]
provider = spec["provider"]
mode = spec["mode"]
auth_endpoint = spec.links["auth_endpoint"]
provider = spec.provider
mode = spec.mode
if mode == "password":
# Prompt for username, password at terminal.
username = username_input()
Expand Down Expand Up @@ -124,7 +130,7 @@ def prompt_for_credentials(http_client, providers):
tokens = device_code_grant(http_client, auth_endpoint)
else:
raise ValueError(f"Server has unknown authentication mechanism {mode!r}")
confirmation_message = spec.get("confirmation_message")
confirmation_message = spec.confirmation_message
if confirmation_message:
username = tokens["identity"]["id"]
print(confirmation_message.format(id=username))
Expand Down Expand Up @@ -252,7 +258,7 @@ def __init__(
# (2) Let the server set the CSRF cookie.
# No authentication has been set up yet, so these requests will be unauthenticated.
# https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#double-submit-cookie
self.server_info = handle_error(
server_info = handle_error(
self.http_client.get(
self.api_uri,
headers={
Expand All @@ -261,6 +267,7 @@ def __init__(
},
)
).json()
self.server_info: About = TypeAdapter(About).validate_python(server_info)
self.api_key = api_key # property setter sets Authorization header
self.admin = Admin(self) # accessor for admin-related requests

Expand All @@ -270,7 +277,7 @@ def __repr__(self):
auth_info.append("(unauthenticated)")
else:
auth_info.append("authenticated")
if self.server_info["authentication"].get("links"):
if self.server_info.authentication.links:
whoami = self.whoami()
auth_info.append("as")
if whoami["type"] == "service":
Expand Down Expand Up @@ -435,7 +442,7 @@ def from_app(
raise_server_exceptions=raise_server_exceptions,
)
if api_key is UNSET:
if not context.server_info["authentication"]["providers"]:
if not context.server_info.authentication.providers:
# This is a single-user server.
# Extract the API key from the app and set it.
from ..server.settings import get_settings
Expand Down Expand Up @@ -487,7 +494,7 @@ def which_api_key(self):
raise RuntimeError("Not API key is configured for the client.")
return handle_error(
self.http_client.get(
self.server_info["authentication"]["links"]["apikey"],
self.server_info.authentication.links.apikey,
headers={"Accept": MSGPACK_MIME_TYPE},
)
).json()
Expand Down Expand Up @@ -516,7 +523,7 @@ def create_api_key(self, scopes=None, expires_in=None, note=None):
expires_in = parse_time_string(expires_in)
return handle_error(
self.http_client.post(
self.server_info["authentication"]["links"]["apikey"],
self.server_info.authentication.links.apikey,
headers={"Accept": MSGPACK_MIME_TYPE},
json={"scopes": scopes, "expires_in": expires_in, "note": note},
)
Expand All @@ -536,7 +543,7 @@ def revoke_api_key(self, first_eight):
Identify the API key to be deleted by passing its first 8 characters.
(Any additional characters passed will be truncated.)
"""
url_path = self.server_info["authentication"]["links"]["apikey"]
url_path = self.server_info.authentication.links.apikey
handle_error(
self.http_client.delete(
url_path,
Expand Down Expand Up @@ -605,8 +612,11 @@ def authenticate(
Next time, try to automatically authenticate using this session.
"""
# Obtain tokens via OAuth2 unless the caller has passed them.
providers = self.server_info["authentication"]["providers"]
tokens = prompt_for_credentials(self.http_client, providers)
providers = self.server_info.authentication.providers
tokens = prompt_for_credentials(
self.http_client,
providers,
)
self.configure_auth(tokens, remember_me=remember_me)

# These two methods are aliased for convenience.
Expand All @@ -628,7 +638,7 @@ def configure_auth(self, tokens, remember_me=True):
)
# Configure an httpx.Auth instance on the http_client, which
# will manage refreshing the tokens as needed.
refresh_url = self.server_info["authentication"]["links"]["refresh_session"]
refresh_url = self.server_info.authentication.links.refresh_session
csrf_token = self.http_client.cookies["tiled_csrf"]
if remember_me:
token_directory = self._token_directory()
Expand Down Expand Up @@ -671,7 +681,7 @@ def use_cached_tokens(self):
success : bool
Indicating whether valid cached tokens were found
"""
refresh_url = self.server_info["authentication"]["links"]["refresh_session"]
refresh_url = self.server_info.authentication.links.refresh_session
csrf_token = self.http_client.cookies["tiled_csrf"]

# Try automatically authenticating using cached tokens, if any.
Expand Down Expand Up @@ -731,7 +741,7 @@ def whoami(self):
"Return information about the currently-authenticated user or service."
return handle_error(
self.http_client.get(
self.server_info["authentication"]["links"]["whoami"],
self.server_info.authentication.links.whoami,
headers={"Accept": MSGPACK_MIME_TYPE},
)
).json()
Expand Down Expand Up @@ -774,7 +784,7 @@ def revoke_session(self, session_id):
"""
handle_error(
self.http_client.delete(
self.server_info["authentication"]["links"]["revoke_session"].format(
self.server_info.authentication.links.revoke_session.format(
session_id=session_id
),
headers={"x-csrf": self.http_client.cookies["tiled_csrf"]},
Expand All @@ -785,9 +795,9 @@ def revoke_session(self, session_id):
class Admin:
"Accessor for requests that require administrative privileges."

def __init__(self, context):
def __init__(self, context: Context):
self.context = context
self.base_url = context.server_info["links"]["self"]
self.base_url = context.server_info.links["self"]

def list_principals(self, offset=0, limit=100):
"List Principals (users and services) in the authenticaiton database."
Expand Down
35 changes: 35 additions & 0 deletions tiled/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel


class AboutAuthenticationProvider(BaseModel):
provider: str
mode: Literal["password", "external"]
links: Dict[str, str]
confirmation_message: Optional[str] = None


class AboutAuthenticationLinks(BaseModel):
whoami: str
apikey: str
refresh_session: str
revoke_session: str
logout: str


class AboutAuthentication(BaseModel):
required: bool
providers: List[AboutAuthenticationProvider]
links: Optional[AboutAuthenticationLinks] = None


class About(BaseModel):
api_version: int
library_version: str
formats: Dict[str, List[str]]
aliases: Dict[str, Dict[str, List[str]]]
queries: List[str]
authentication: AboutAuthentication
links: Dict[str, str]
meta: Dict[str, Any]
6 changes: 4 additions & 2 deletions tiled/server/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
HTTP_422_UNPROCESSABLE_ENTITY,
)

from tiled.schemas import About

from .. import __version__
from ..structures.core import Spec, StructureFamily
from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri
Expand Down Expand Up @@ -67,7 +69,7 @@
router = APIRouter()


@router.get("/", response_model=schemas.About)
@router.get("/", response_model=About)
async def about(
request: Request,
settings: BaseSettings = Depends(get_settings),
Expand Down Expand Up @@ -130,7 +132,7 @@ async def about(

return json_or_msgpack(
request,
schemas.About(
About(
library_version=__version__,
api_version=0,
formats={
Expand Down
37 changes: 0 additions & 37 deletions tiled/server/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,43 +272,6 @@ class DeviceCode(pydantic.BaseModel):
grant_type: str


class AuthenticationMode(str, enum.Enum):
password = "password"
external = "external"


class AboutAuthenticationProvider(pydantic.BaseModel):
provider: str
mode: AuthenticationMode
links: Dict[str, str]
confirmation_message: Optional[str] = None


class AboutAuthenticationLinks(pydantic.BaseModel):
whoami: str
apikey: str
refresh_session: str
revoke_session: str
logout: str


class AboutAuthentication(pydantic.BaseModel):
required: bool
providers: List[AboutAuthenticationProvider]
links: Optional[AboutAuthenticationLinks] = None


class About(pydantic.BaseModel):
api_version: int
library_version: str
formats: Dict[str, List[str]]
aliases: Dict[str, Dict[str, List[str]]]
queries: List[str]
authentication: AboutAuthentication
links: Dict[str, str]
meta: Dict


class PrincipalType(str, enum.Enum):
user = "user"
service = "service"
Expand Down

0 comments on commit ff14dd1

Please sign in to comment.