diff --git a/CHANGELOG.md b/CHANGELOG.md index af88283ad..6bd2697d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Write the date in place of the "Unreleased" in the case a new version is release - Use schemas for describing server configuration on the client side too. - Refactored Authentication providers to make use of inheritance, adjusted mode in the `AboutAuthenticationProvider` schema to be `internal`|`external`. +- Improved type hinting and efficiency of caching singleton values ## v0.1.0-b16 (2024-01-23) diff --git a/tiled/config.py b/tiled/config.py index 98248f975..991d2bf48 100644 --- a/tiled/config.py +++ b/tiled/config.py @@ -8,7 +8,7 @@ import warnings from collections import defaultdict from datetime import timedelta -from functools import lru_cache +from functools import cache from pathlib import Path import jsonschema @@ -25,7 +25,7 @@ from .validation_registration import validation_registry as default_validation_registry -@lru_cache(maxsize=1) +@cache def schema(): "Load the schema for service-side configuration." import yaml diff --git a/tiled/profiles.py b/tiled/profiles.py index 35afa193c..b4e35755e 100644 --- a/tiled/profiles.py +++ b/tiled/profiles.py @@ -13,7 +13,7 @@ import shutil import sys import warnings -from functools import lru_cache +from functools import cache from pathlib import Path import jsonschema @@ -35,7 +35,7 @@ ] -@lru_cache(maxsize=1) +@cache def schema(): "Load the schema for profiles." import yaml @@ -204,7 +204,7 @@ def resolve_precedence(levels): return combined -@lru_cache(maxsize=1) +@cache def load_profiles(): """ Return a mapping of profile_name to (source_path, content). diff --git a/tiled/server/app.py b/tiled/server/app.py index 468a6203a..54429b2c5 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -8,7 +8,7 @@ import urllib.parse import warnings from contextlib import asynccontextmanager -from functools import lru_cache, partial +from functools import cache, partial from pathlib import Path from typing import List @@ -437,15 +437,15 @@ async def unhandled_exception_handler( response_model=schemas.GetDistinctResponse, )(patch_route_signature(distinct, query_registry)) - @lru_cache(1) + @cache def override_get_authenticators(): return authenticators - @lru_cache(1) + @cache def override_get_root_tree(): return tree - @lru_cache(1) + @cache def override_get_settings(): settings = get_settings() for item in [ @@ -771,14 +771,14 @@ async def set_cookies(request: Request, call_next): app.dependency_overrides[get_settings] = override_get_settings if query_registry is not None: - @lru_cache(1) + @cache def override_get_query_registry(): return query_registry app.dependency_overrides[get_query_registry] = override_get_query_registry if serialization_registry is not None: - @lru_cache(1) + @cache def override_get_serialization_registry(): return serialization_registry @@ -788,7 +788,7 @@ def override_get_serialization_registry(): if validation_registry is not None: - @lru_cache(1) + @cache def override_get_validation_registry(): return validation_registry diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index c9360a8a3..dc974fe04 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -26,7 +26,6 @@ from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery from fastapi.security.utils import get_authorization_scheme_param from fastapi.templating import Jinja2Templates -from pydantic_settings import BaseSettings from sqlalchemy.future import select from sqlalchemy.orm import selectinload from sqlalchemy.sql import func @@ -62,7 +61,7 @@ from . import schemas from .core import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, json_or_msgpack from .protocols import InternalAuthenticator, UserSessionState -from .settings import get_settings +from .settings import Settings, get_settings from .utils import API_KEY_COOKIE_NAME, get_authenticators, get_base_url ALGORITHM = "HS256" @@ -214,7 +213,7 @@ async def get_decoded_access_token( request: Request, security_scopes: SecurityScopes, access_token: str = Depends(oauth2_scheme), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): if not access_token: return None @@ -239,7 +238,7 @@ async def get_current_principal( security_scopes: SecurityScopes, decoded_access_token: str = Depends(get_decoded_access_token), api_key: str = Depends(get_api_key), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), authenticators=Depends(get_authenticators), db=Depends(get_database_session), ): @@ -509,7 +508,7 @@ def build_auth_code_route(authenticator, provider): async def route( request: Request, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), db=Depends(get_database_session), ): request.state.endpoint = "auth" @@ -607,7 +606,7 @@ async def route( code: str = Form(), user_code: str = Form(), state: Optional[str] = None, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), db=Depends(get_database_session), ): request.state.endpoint = "auth" @@ -670,7 +669,7 @@ def build_device_code_token_route(authenticator, provider): async def route( request: Request, body: schemas.DeviceCode, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), db=Depends(get_database_session), ): request.state.endpoint = "auth" @@ -710,7 +709,7 @@ def build_handle_credentials_route(authenticator: InternalAuthenticator, provide async def route( request: Request, form_data: OAuth2PasswordRequestForm = Depends(), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), db=Depends(get_database_session), ): request.state.endpoint = "auth" @@ -962,7 +961,7 @@ async def apikey_for_principal( async def refresh_session( request: Request, refresh_token: schemas.RefreshToken, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), db=Depends(get_database_session), ): "Obtain a new access token and refresh token." @@ -975,7 +974,7 @@ async def refresh_session( async def revoke_session( request: Request, refresh_token: schemas.RefreshToken, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), db=Depends(get_database_session), ): "Mark a Session as revoked so it cannot be refreshed again." diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index fd2d4d4e8..09b8d8233 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -1,4 +1,4 @@ -from functools import lru_cache +from functools import cache from typing import Optional, Tuple, Union import pydantic_settings @@ -24,25 +24,25 @@ SLICE_REGEX = rf"^{DIM_REGEX}(?:,{DIM_REGEX})*$" -@lru_cache(1) +@cache def get_query_registry(): "This may be overridden via dependency_overrides." return default_query_registry -@lru_cache(1) +@cache def get_deserialization_registry(): "This may be overridden via dependency_overrides." return default_deserialization_registry -@lru_cache(1) +@cache def get_serialization_registry(): "This may be overridden via dependency_overrides." return default_serialization_registry -@lru_cache(1) +@cache def get_validation_registry(): "This may be overridden via dependency_overrides." return default_validation_registry diff --git a/tiled/server/metrics.py b/tiled/server/metrics.py index 2139fe132..d6e0d79ee 100644 --- a/tiled/server/metrics.py +++ b/tiled/server/metrics.py @@ -5,7 +5,7 @@ """ import os -from functools import lru_cache +from functools import cache from fastapi import APIRouter, Request, Response, Security from prometheus_client import CONTENT_TYPE_LATEST, Histogram, generate_latest @@ -135,7 +135,7 @@ def capture_request_metrics(request, response): ).observe(metrics["compress"]["ratio"]) -@lru_cache() +@cache def prometheus_registry(): """ Configure prometheus_client. diff --git a/tiled/server/router.py b/tiled/server/router.py index a1a5293be..836e8b801 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -13,7 +13,6 @@ from jmespath.exceptions import JMESPathError from json_merge_patch import merge as apply_merge_patch from jsonpatch import apply_patch as apply_json_patch -from pydantic_settings import BaseSettings from starlette.status import ( HTTP_200_OK, HTTP_206_PARTIAL_CONTENT, @@ -64,7 +63,7 @@ ) from .file_response_with_range import FileResponseWithRange from .links import links_for_node -from .settings import get_settings +from .settings import Settings, get_settings from .utils import filter_for_access, get_base_url, record_timing router = APIRouter() @@ -73,7 +72,7 @@ @router.get("/", response_model=About) async def about( request: Request, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), authenticators=Depends(get_authenticators), serialization_registry=Depends(get_serialization_registry), query_registry=Depends(get_query_registry), @@ -375,7 +374,7 @@ async def array_block( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a chunk of array-like data. @@ -453,7 +452,7 @@ async def array_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of array-like data. @@ -516,7 +515,7 @@ async def get_table_partition( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a partition (continuous block of rows) from a DataFrame [GET route]. @@ -565,7 +564,7 @@ async def post_table_partition( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a partition (continuous block of rows) from a DataFrame [POST route]. @@ -590,7 +589,7 @@ async def table_partition( format: Optional[str], filename: Optional[str], serialization_registry, - settings: BaseSettings, + settings: Settings, ): """ Fetch a partition (continuous block of rows) from a DataFrame. @@ -647,7 +646,7 @@ async def get_table_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch the data for the given table [GET route]. @@ -675,7 +674,7 @@ async def post_table_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch the data for the given table [POST route]. @@ -698,7 +697,7 @@ async def table_full( format: Optional[str], filename: Optional[str], serialization_registry, - settings: BaseSettings, + settings: Settings, ): """ Fetch the data for the given table. @@ -860,7 +859,7 @@ async def node_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch the data below the given node. @@ -926,7 +925,7 @@ async def get_awkward_buffers( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of AwkwardArray data. @@ -963,7 +962,7 @@ async def post_awkward_buffers( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of AwkwardArray data. @@ -993,7 +992,7 @@ async def _awkward_buffers( format: Optional[str], filename: Optional[str], serialization_registry, - settings: BaseSettings, + settings: Settings, ): structure_family = entry.structure_family structure = entry.structure() @@ -1044,7 +1043,7 @@ async def awkward_full( format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): """ Fetch a slice of AwkwardArray data. @@ -1090,7 +1089,7 @@ async def post_metadata( path: str, body: schemas.PostMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata", "create"]), ): for data_source in body.data_sources: @@ -1120,7 +1119,7 @@ async def post_register( path: str, body: schemas.PostMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata", "create", "register"]), ): return await _create_node( @@ -1138,7 +1137,7 @@ async def _create_node( path: str, body: schemas.PostMetadataRequest, validation_registry, - settings: BaseSettings, + settings: Settings, entry, ): metadata, structure_family, specs = ( @@ -1189,7 +1188,7 @@ async def put_data_source( path: str, data_source: int, body: schemas.PutDataSourceRequest, - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata", "register"]), ): await entry.put_data_source( @@ -1410,7 +1409,7 @@ async def patch_metadata( request: Request, body: schemas.PatchMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata"]), ): if not hasattr(entry, "replace_metadata"): @@ -1473,7 +1472,7 @@ async def put_metadata( request: Request, body: schemas.PutMetadataRequest, validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), entry=SecureEntry(scopes=["write:metadata"]), ): if not hasattr(entry, "replace_metadata"): @@ -1563,7 +1562,7 @@ async def get_asset( id: int, relative_path: Optional[Path] = None, entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): if not settings.expose_raw_assets: raise HTTPException( @@ -1660,7 +1659,7 @@ async def get_asset_manifest( request: Request, id: int, entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): if not settings.expose_raw_assets: raise HTTPException( @@ -1707,7 +1706,7 @@ async def validate_metadata( structure, specs: List[Spec], validation_registry=Depends(get_validation_registry), - settings: BaseSettings = Depends(get_settings), + settings: Settings = Depends(get_settings), ): metadata_modified = False diff --git a/tiled/server/settings.py b/tiled/server/settings.py index 57015f166..e68c9b283 100644 --- a/tiled/server/settings.py +++ b/tiled/server/settings.py @@ -2,7 +2,7 @@ import os import secrets from datetime import timedelta -from functools import lru_cache +from functools import cache from typing import Any, List, Optional from pydantic_settings import BaseSettings @@ -79,6 +79,6 @@ def database_settings(self): ) -@lru_cache() -def get_settings(): +@cache +def get_settings() -> Settings: return Settings()