Skip to content

Commit

Permalink
Type hints (#866)
Browse files Browse the repository at this point in the history
* Tighten bounds on Settings

* Replace use of lru_cache with cache
  • Loading branch information
DiamondJoseph authored Jan 27, 2025
1 parent 444515c commit 7e54cd5
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 57 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Write the date in place of the "Unreleased" in the case a new version is release
- Use schemas for describing server configuration on the client side too.
- Refactored Authentication providers to make use of inheritance, adjusted
mode in the `AboutAuthenticationProvider` schema to be `internal`|`external`.
- Improved type hinting and efficiency of caching singleton values

## v0.1.0-b16 (2024-01-23)

Expand Down
4 changes: 2 additions & 2 deletions tiled/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from collections import defaultdict
from datetime import timedelta
from functools import lru_cache
from functools import cache
from pathlib import Path

import jsonschema
Expand All @@ -25,7 +25,7 @@
from .validation_registration import validation_registry as default_validation_registry


@lru_cache(maxsize=1)
@cache
def schema():
"Load the schema for service-side configuration."
import yaml
Expand Down
6 changes: 3 additions & 3 deletions tiled/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import shutil
import sys
import warnings
from functools import lru_cache
from functools import cache
from pathlib import Path

import jsonschema
Expand All @@ -35,7 +35,7 @@
]


@lru_cache(maxsize=1)
@cache
def schema():
"Load the schema for profiles."
import yaml
Expand Down Expand Up @@ -204,7 +204,7 @@ def resolve_precedence(levels):
return combined


@lru_cache(maxsize=1)
@cache
def load_profiles():
"""
Return a mapping of profile_name to (source_path, content).
Expand Down
14 changes: 7 additions & 7 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import urllib.parse
import warnings
from contextlib import asynccontextmanager
from functools import lru_cache, partial
from functools import cache, partial
from pathlib import Path
from typing import List

Expand Down Expand Up @@ -437,15 +437,15 @@ async def unhandled_exception_handler(
response_model=schemas.GetDistinctResponse,
)(patch_route_signature(distinct, query_registry))

@lru_cache(1)
@cache
def override_get_authenticators():
return authenticators

@lru_cache(1)
@cache
def override_get_root_tree():
return tree

@lru_cache(1)
@cache
def override_get_settings():
settings = get_settings()
for item in [
Expand Down Expand Up @@ -771,14 +771,14 @@ async def set_cookies(request: Request, call_next):
app.dependency_overrides[get_settings] = override_get_settings
if query_registry is not None:

@lru_cache(1)
@cache
def override_get_query_registry():
return query_registry

app.dependency_overrides[get_query_registry] = override_get_query_registry
if serialization_registry is not None:

@lru_cache(1)
@cache
def override_get_serialization_registry():
return serialization_registry

Expand All @@ -788,7 +788,7 @@ def override_get_serialization_registry():

if validation_registry is not None:

@lru_cache(1)
@cache
def override_get_validation_registry():
return validation_registry

Expand Down
19 changes: 9 additions & 10 deletions tiled/server/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery
from fastapi.security.utils import get_authorization_scheme_param
from fastapi.templating import Jinja2Templates
from pydantic_settings import BaseSettings
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.sql import func
Expand Down Expand Up @@ -62,7 +61,7 @@
from . import schemas
from .core import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, json_or_msgpack
from .protocols import InternalAuthenticator, UserSessionState
from .settings import get_settings
from .settings import Settings, get_settings
from .utils import API_KEY_COOKIE_NAME, get_authenticators, get_base_url

ALGORITHM = "HS256"
Expand Down Expand Up @@ -214,7 +213,7 @@ async def get_decoded_access_token(
request: Request,
security_scopes: SecurityScopes,
access_token: str = Depends(oauth2_scheme),
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
):
if not access_token:
return None
Expand All @@ -239,7 +238,7 @@ async def get_current_principal(
security_scopes: SecurityScopes,
decoded_access_token: str = Depends(get_decoded_access_token),
api_key: str = Depends(get_api_key),
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
authenticators=Depends(get_authenticators),
db=Depends(get_database_session),
):
Expand Down Expand Up @@ -509,7 +508,7 @@ def build_auth_code_route(authenticator, provider):

async def route(
request: Request,
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
db=Depends(get_database_session),
):
request.state.endpoint = "auth"
Expand Down Expand Up @@ -607,7 +606,7 @@ async def route(
code: str = Form(),
user_code: str = Form(),
state: Optional[str] = None,
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
db=Depends(get_database_session),
):
request.state.endpoint = "auth"
Expand Down Expand Up @@ -670,7 +669,7 @@ def build_device_code_token_route(authenticator, provider):
async def route(
request: Request,
body: schemas.DeviceCode,
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
db=Depends(get_database_session),
):
request.state.endpoint = "auth"
Expand Down Expand Up @@ -710,7 +709,7 @@ def build_handle_credentials_route(authenticator: InternalAuthenticator, provide
async def route(
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
db=Depends(get_database_session),
):
request.state.endpoint = "auth"
Expand Down Expand Up @@ -962,7 +961,7 @@ async def apikey_for_principal(
async def refresh_session(
request: Request,
refresh_token: schemas.RefreshToken,
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
db=Depends(get_database_session),
):
"Obtain a new access token and refresh token."
Expand All @@ -975,7 +974,7 @@ async def refresh_session(
async def revoke_session(
request: Request,
refresh_token: schemas.RefreshToken,
settings: BaseSettings = Depends(get_settings),
settings: Settings = Depends(get_settings),
db=Depends(get_database_session),
):
"Mark a Session as revoked so it cannot be refreshed again."
Expand Down
10 changes: 5 additions & 5 deletions tiled/server/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import lru_cache
from functools import cache
from typing import Optional, Tuple, Union

import pydantic_settings
Expand All @@ -24,25 +24,25 @@
SLICE_REGEX = rf"^{DIM_REGEX}(?:,{DIM_REGEX})*$"


@lru_cache(1)
@cache
def get_query_registry():
"This may be overridden via dependency_overrides."
return default_query_registry


@lru_cache(1)
@cache
def get_deserialization_registry():
"This may be overridden via dependency_overrides."
return default_deserialization_registry


@lru_cache(1)
@cache
def get_serialization_registry():
"This may be overridden via dependency_overrides."
return default_serialization_registry


@lru_cache(1)
@cache
def get_validation_registry():
"This may be overridden via dependency_overrides."
return default_validation_registry
Expand Down
4 changes: 2 additions & 2 deletions tiled/server/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import os
from functools import lru_cache
from functools import cache

from fastapi import APIRouter, Request, Response, Security
from prometheus_client import CONTENT_TYPE_LATEST, Histogram, generate_latest
Expand Down Expand Up @@ -135,7 +135,7 @@ def capture_request_metrics(request, response):
).observe(metrics["compress"]["ratio"])


@lru_cache()
@cache
def prometheus_registry():
"""
Configure prometheus_client.
Expand Down
Loading

0 comments on commit 7e54cd5

Please sign in to comment.