diff --git a/oembedpy/application.py b/oembedpy/application.py index 572f08f..41f0314 100644 --- a/oembedpy/application.py +++ b/oembedpy/application.py @@ -2,8 +2,9 @@ import json import logging +import pickle import time -from typing import Optional +from typing import Dict, Optional import httpx @@ -11,7 +12,7 @@ from oembedpy import consumer, discovery from oembedpy.provider import ProviderRegistry -from oembedpy.types import Content +from oembedpy.types import CachedContent, Content logger = logging.getLogger(__name__) @@ -20,6 +21,7 @@ class Oembed: """Application of oEmbed.""" _registry: ProviderRegistry + _cache: Dict[str, CachedContent] def __init__(self): # noqa: D107 pass @@ -27,6 +29,7 @@ def __init__(self): # noqa: D107 def init(self): resp = httpx.get("https://oembed.com/providers.json") self._registry = ProviderRegistry.from_dict(resp.json()) + self._cache = {} def fetch( self, @@ -46,7 +49,13 @@ def fetch( params.max_width = max_width if max_height: params.max_height = max_height + # + now = time.mktime(time.localtime()) + if params in self._cache and now <= self._cache[params].expired: + return self._cache[params].content content = consumer.fetch_content(api_url, params) + if content.cache_age: + self._cache[params] = CachedContent(now + int(content.cache_age), content) return content @@ -55,12 +64,21 @@ class Workspace(Oembed): def __init__(self): self._dirs = PlatformDirs("oembedpy") + self._cache = {} + + def __del__(self): + cache_db = self.cache_dir / "db.pickle" + cache_db.write_bytes(pickle.dumps(self._cache)) @property def cache_dir(self): return self._dirs.user_data_path def init(self): + self.init_providers() + self.init_caches() + + def init_providers(self): providers_json = self.cache_dir / "providers.json" use_cache = providers_json.exists() if use_cache: @@ -79,3 +97,8 @@ def init(self): providers_json.write_text(resp.text) self._registry = ProviderRegistry.from_dict(providers_data) + + def init_caches(self): + cache_db = self.cache_dir / "db.pickle" + if cache_db.exists(): + self._cache = pickle.loads(cache_db.read_bytes()) diff --git a/oembedpy/consumer.py b/oembedpy/consumer.py index 1b8eaca..9c5faea 100644 --- a/oembedpy/consumer.py +++ b/oembedpy/consumer.py @@ -21,6 +21,9 @@ class RequestParameters: max_width: Optional[int] = None max_height: Optional[int] = None + def __hash__(self): + return hash((self.url, self.format, self.max_width, self.max_height)) + def to_dict(self) -> Dict[str, str]: """Make dict object from properties.""" data = {"url": self.url} diff --git a/oembedpy/types.py b/oembedpy/types.py index 246248d..a8cdc27 100644 --- a/oembedpy/types.py +++ b/oembedpy/types.py @@ -6,7 +6,7 @@ from dataclasses import asdict, dataclass from inspect import signature -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, Dict, NamedTuple, Optional, Type, TypeVar, Union T = TypeVar("T", bound="_Required") @@ -103,3 +103,8 @@ class Rich(_Optionals, _Rich, _Required): Content = Union[Photo, Video, Link, Rich] """Collection of oEmbed content types.""" + + +class CachedContent(NamedTuple): + expired: float + content: Content diff --git a/tests/test_application.py b/tests/test_application.py index 0849bba..a5db5dc 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,4 +1,5 @@ # flake8: noqa +import pickle import subprocess from pathlib import Path @@ -70,3 +71,24 @@ def test_purge_cached_providers_json( other_workspace = application.Workspace() other_workspace.init() assert spy.call_count == 2 + + def test_initialized_response_cache( + self, mocked_workspace: application.Workspace, tmp_path + ): + mocked_workspace.init() + assert (tmp_path / "providers.json").exists() + mocked_workspace.__del__() + assert (tmp_path / "db.pickle").exists() + + def test_initialized_response_cache( + self, mocked_workspace: application.Workspace, tmp_path + ): + mocked_workspace.init() + mocked_workspace.fetch( + "https://bsky.app/profile/attakei.dev/post/3kr76heazfp2i" + ) + mocked_workspace.__del__() + db_path = tmp_path / "db.pickle" + assert db_path.exists() + db = pickle.loads(db_path.read_bytes()) + assert len(db) == 1