From 9057f07394f0a61cf77c4c4dd7e9fd9dabc503f6 Mon Sep 17 00:00:00 2001 From: Jordan Sanders Date: Tue, 27 Dec 2022 09:51:21 -0600 Subject: [PATCH 1/4] Test expiration I don't think expiration works as intended (or at least as I would expect it to). The response keys are correctly expired, but nothing is ever removed from the idempotency keys set. This has two consequences: - once a response key expires, requests using the same idempotency key will return 409 forever - the idempotency keys set grows without bound Instead, I would expect expiration to also apply to the idempotency keys set. Once the expiration window has passed, a new request with the same idempotency key should issue a new request. This commit only introduces the tests without implementing a solution. For MemoryBackend, I'd propose storing the expiry alongside the idempotency key. For RedisBackend, I'd propose something along the lines of https://github.com/redis/redis/issues/135#issuecomment-2361996 using a sorted set to purge expired idempotency keys. But I wanted to run the change by you before taking action on either approach. 
--- idempotency_header_middleware/backends/base.py | 3 ++- idempotency_header_middleware/backends/memory.py | 4 ++-- idempotency_header_middleware/backends/redis.py | 6 +++--- tests/conftest.py | 14 ++++++++++++-- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/idempotency_header_middleware/backends/base.py b/idempotency_header_middleware/backends/base.py index 0c0aade..f0429ae 100644 --- a/idempotency_header_middleware/backends/base.py +++ b/idempotency_header_middleware/backends/base.py @@ -3,9 +3,10 @@ from starlette.responses import Response +DEFAULT_EXPIRY = 60 * 60 * 24 class Backend(ABC): - expiry: Optional[int] = 60 * 60 * 24 + expiry: Optional[int] = DEFAULT_EXPIRY @abstractmethod async def get_stored_response(self, idempotency_key: str) -> Optional[Response]: diff --git a/idempotency_header_middleware/backends/memory.py b/idempotency_header_middleware/backends/memory.py index a162b8a..a20a8f1 100644 --- a/idempotency_header_middleware/backends/memory.py +++ b/idempotency_header_middleware/backends/memory.py @@ -4,7 +4,7 @@ from starlette.responses import JSONResponse -from idempotency_header_middleware.backends.base import Backend +from idempotency_header_middleware.backends.base import Backend, DEFAULT_EXPIRY @dataclass() @@ -19,7 +19,7 @@ class MemoryBackend(Backend): The backend is mainly here for local development or testing. 
""" - expiry: Optional[int] = 60 * 60 * 24 + expiry: Optional[int] = DEFAULT_EXPIRY response_store: Dict[str, Dict[str, Any]] = field(default_factory=dict) keys: Set[str] = field(default_factory=set) diff --git a/idempotency_header_middleware/backends/redis.py b/idempotency_header_middleware/backends/redis.py index 46bba0a..1294093 100644 --- a/idempotency_header_middleware/backends/redis.py +++ b/idempotency_header_middleware/backends/redis.py @@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse from redis.asyncio import Redis -from idempotency_header_middleware.backends.base import Backend +from idempotency_header_middleware.backends.base import Backend, DEFAULT_EXPIRY @dataclass() @@ -15,7 +15,7 @@ def __init__( redis: Redis, keys_key: str = 'idempotency-key-keys', response_key: str = 'idempotency-key-responses', - expiry: int = 60 * 60 * 24, + expiry: int = DEFAULT_EXPIRY, ): self.redis = redis self.KEYS_KEY = keys_key @@ -49,7 +49,7 @@ async def store_response_data(self, idempotency_key: str, payload: dict, status_ await self.redis.set(payload_key, json.dumps(payload)) await self.redis.set(status_code_key, status_code) - if self.expiry: + if self.expiry >= 0: await self.redis.expire(payload_key, self.expiry) await self.redis.expire(status_code_key, self.expiry) diff --git a/tests/conftest.py b/tests/conftest.py index e97847b..5236fe3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ StreamingResponse, ) +from idempotency_header_middleware.backends.base import DEFAULT_EXPIRY from idempotency_header_middleware.backends.redis import RedisBackend from idempotency_header_middleware.middleware import IdempotencyHeaderMiddleware @@ -70,13 +71,22 @@ def _configure_logging(): def method_config(request): return request.param +expirations = { + 'default': DEFAULT_EXPIRY, + 'immediately': 0, +} + +@pytest.fixture(scope='session', ids=expirations.keys(), params=expirations.values()) +def expiry(request): + return request.param + 
@pytest.fixture(scope='session', autouse=True) -def app_with_middleware(method_config): +def app_with_middleware(method_config, expiry): app.add_middleware( IdempotencyHeaderMiddleware, enforce_uuid4_formatting=True, - backend=RedisBackend(redis=fakeredis.aioredis.FakeRedis(decode_responses=True)), + backend=RedisBackend(redis=fakeredis.aioredis.FakeRedis(decode_responses=True), expiry=expiry), applicable_methods=method_config['setting'], ) yield app From e63463e601482aa9ee1d31edbd4e17492bd63eee Mon Sep 17 00:00:00 2001 From: Jordan Sanders Date: Wed, 4 Jan 2023 10:03:38 -0600 Subject: [PATCH 2/4] Revert "Test expiration" This reverts commit 9057f07394f0a61cf77c4c4dd7e9fd9dabc503f6. --- idempotency_header_middleware/backends/base.py | 3 +-- idempotency_header_middleware/backends/memory.py | 4 ++-- idempotency_header_middleware/backends/redis.py | 6 +++--- tests/conftest.py | 14 ++------------ 4 files changed, 8 insertions(+), 19 deletions(-) diff --git a/idempotency_header_middleware/backends/base.py b/idempotency_header_middleware/backends/base.py index f0429ae..0c0aade 100644 --- a/idempotency_header_middleware/backends/base.py +++ b/idempotency_header_middleware/backends/base.py @@ -3,10 +3,9 @@ from starlette.responses import Response -DEFAULT_EXPIRY = 60 * 60 * 24 class Backend(ABC): - expiry: Optional[int] = DEFAULT_EXPIRY + expiry: Optional[int] = 60 * 60 * 24 @abstractmethod async def get_stored_response(self, idempotency_key: str) -> Optional[Response]: diff --git a/idempotency_header_middleware/backends/memory.py b/idempotency_header_middleware/backends/memory.py index a20a8f1..a162b8a 100644 --- a/idempotency_header_middleware/backends/memory.py +++ b/idempotency_header_middleware/backends/memory.py @@ -4,7 +4,7 @@ from starlette.responses import JSONResponse -from idempotency_header_middleware.backends.base import Backend, DEFAULT_EXPIRY +from idempotency_header_middleware.backends.base import Backend @dataclass() @@ -19,7 +19,7 @@ class 
MemoryBackend(Backend): The backend is mainly here for local development or testing. """ - expiry: Optional[int] = DEFAULT_EXPIRY + expiry: Optional[int] = 60 * 60 * 24 response_store: Dict[str, Dict[str, Any]] = field(default_factory=dict) keys: Set[str] = field(default_factory=set) diff --git a/idempotency_header_middleware/backends/redis.py b/idempotency_header_middleware/backends/redis.py index 1294093..46bba0a 100644 --- a/idempotency_header_middleware/backends/redis.py +++ b/idempotency_header_middleware/backends/redis.py @@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse from redis.asyncio import Redis -from idempotency_header_middleware.backends.base import Backend, DEFAULT_EXPIRY +from idempotency_header_middleware.backends.base import Backend @dataclass() @@ -15,7 +15,7 @@ def __init__( redis: Redis, keys_key: str = 'idempotency-key-keys', response_key: str = 'idempotency-key-responses', - expiry: int = DEFAULT_EXPIRY, + expiry: int = 60 * 60 * 24, ): self.redis = redis self.KEYS_KEY = keys_key @@ -49,7 +49,7 @@ async def store_response_data(self, idempotency_key: str, payload: dict, status_ await self.redis.set(payload_key, json.dumps(payload)) await self.redis.set(status_code_key, status_code) - if self.expiry >= 0: + if self.expiry: await self.redis.expire(payload_key, self.expiry) await self.redis.expire(status_code_key, self.expiry) diff --git a/tests/conftest.py b/tests/conftest.py index 5236fe3..e97847b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,6 @@ StreamingResponse, ) -from idempotency_header_middleware.backends.base import DEFAULT_EXPIRY from idempotency_header_middleware.backends.redis import RedisBackend from idempotency_header_middleware.middleware import IdempotencyHeaderMiddleware @@ -71,22 +70,13 @@ def _configure_logging(): def method_config(request): return request.param -expirations = { - 'default': DEFAULT_EXPIRY, - 'immediately': 0, -} - -@pytest.fixture(scope='session', ids=expirations.keys(), 
params=expirations.values()) -def expiry(request): - return request.param - @pytest.fixture(scope='session', autouse=True) -def app_with_middleware(method_config, expiry): +def app_with_middleware(method_config): app.add_middleware( IdempotencyHeaderMiddleware, enforce_uuid4_formatting=True, - backend=RedisBackend(redis=fakeredis.aioredis.FakeRedis(decode_responses=True), expiry=expiry), + backend=RedisBackend(redis=fakeredis.aioredis.FakeRedis(decode_responses=True)), applicable_methods=method_config['setting'], ) yield app From 1532b90ad921e3178e336192f42e1fc6d988d793 Mon Sep 17 00:00:00 2001 From: Jordan Sanders Date: Wed, 4 Jan 2023 10:26:33 -0600 Subject: [PATCH 3/4] Expire idempotency keys Previously, idempotency keys were never being expired. This had two consequences: - once a response key expires, requests using the same idempotency key will return 409 forever - the idempotency keys set grows without bound There's a new `expire_idempotency_keys` abstract method that implementations of Backend are responsible for implementing. The middleware calls it on every request. For the Memory backend, we now store idempotency keys in a dict instead of a set. The value of each key is its expiration time. We then iterate over all of the items in the dict and remove the ones with an expiration earlier than the current time. This operation isn't very efficient because it iterates over the entire dict, but given that the Memory backend is only meant to be used in local testing, I don't think it's worth optimizing further. For the RedisBackend, we now store idempotency keys in a sortedset instead of a set. The "score" of each key is its expiration time. This matches the Redis project's official recommendation on how to expire items in a set: https://github.com/redis/redis/issues/135#issuecomment-2361996 We then delete any items from the sorted set with a "score" lower than the current time. 
The Redis docs describe this as a potentially slow operation: https://redis.io/commands/zremrangebyscore/ If we don't want to call this on every request, we could instead add some sort of random fuzzing to only occasionally call it. Or we could abandon using a sorted set entirely and just store the idempotency key as its own key in Redis using the standard Redis expiration functionality. This would mean each request would store potentially three keys (idempotency key, response key, status code key) instead of the current two (response key, status code key). Users who upgrade to a version of the library that includes this change will see a failure in Redis when it tries to use a sortedset operation on an existing set. At a minimum, this should be called out in a changelog. Additionally, we could change the default idempotency key name from `idempotency-key-keys` to something else to avoid a collision. --- .../backends/base.py | 8 +++++ .../backends/memory.py | 22 +++++++++++--- .../backends/redis.py | 15 ++++++++-- idempotency_header_middleware/middleware.py | 1 + tests/conftest.py | 2 ++ tests/test_backends.py | 29 ++++++++++++++++--- tests/test_middleware.py | 4 +-- 7 files changed, 67 insertions(+), 14 deletions(-) diff --git a/idempotency_header_middleware/backends/base.py b/idempotency_header_middleware/backends/base.py index 0c0aade..ebad861 100644 --- a/idempotency_header_middleware/backends/base.py +++ b/idempotency_header_middleware/backends/base.py @@ -45,3 +45,11 @@ async def clear_idempotency_key(self, idempotency_key: str) -> None: key stored in 'store_idempotency_key'. """ ... + + @abstractmethod + async def expire_idempotency_keys(self) -> None: + """ + Remove any expired idempotency keys to avoid returning 409s + after the response expires. + """ + ... 
diff --git a/idempotency_header_middleware/backends/memory.py b/idempotency_header_middleware/backends/memory.py index a162b8a..0c3f52b 100644 --- a/idempotency_header_middleware/backends/memory.py +++ b/idempotency_header_middleware/backends/memory.py @@ -22,7 +22,7 @@ class MemoryBackend(Backend): expiry: Optional[int] = 60 * 60 * 24 response_store: Dict[str, Dict[str, Any]] = field(default_factory=dict) - keys: Set[str] = field(default_factory=set) + idempotency_keys: Dict[str, Optional[int]] = field(default_factory=dict) async def get_stored_response(self, idempotency_key: str) -> Optional[JSONResponse]: """ @@ -54,14 +54,28 @@ async def store_idempotency_key(self, idempotency_key: str) -> bool: """ Store an idempotency key header value in a set. """ - if idempotency_key in self.keys: + if idempotency_key in self.idempotency_keys.keys(): return True - self.keys.add(idempotency_key) + self.idempotency_keys[idempotency_key] = time.time() + self.expiry if self.expiry else None return False async def clear_idempotency_key(self, idempotency_key: str) -> None: """ Remove an idempotency header value from the set. """ - self.keys.remove(idempotency_key) + del self.idempotency_keys[idempotency_key] + + async def expire_idempotency_keys(self) -> None: + """ + Remove any expired idempotency keys to avoid returning 409s + after the response expires. 
+ """ + if not self.expiry: + return + + now = time.time() + for idempotency_key in list(self.idempotency_keys): + if expiry := self.idempotency_keys.get(idempotency_key): + if expiry <= now: + del self.idempotency_keys[idempotency_key] diff --git a/idempotency_header_middleware/backends/redis.py b/idempotency_header_middleware/backends/redis.py index 46bba0a..f2c7e71 100644 --- a/idempotency_header_middleware/backends/redis.py +++ b/idempotency_header_middleware/backends/redis.py @@ -1,3 +1,4 @@ +import time import json from dataclasses import dataclass from typing import Optional, Tuple @@ -55,12 +56,20 @@ async def store_response_data(self, idempotency_key: str, payload: dict, status_ async def store_idempotency_key(self, idempotency_key: str) -> bool: """ - Store an idempotency key header value in a set. + Store an idempotency key header value in a sortedset. """ - return not bool(await self.redis.sadd(self.KEYS_KEY, idempotency_key)) + return not bool(await self.redis.zadd(self.KEYS_KEY, {idempotency_key: time.time() + self.expiry},)) async def clear_idempotency_key(self, idempotency_key: str) -> None: """ Remove an idempotency header value from the set. """ - await self.redis.srem(self.KEYS_KEY, idempotency_key) + await self.redis.zrem(self.KEYS_KEY, idempotency_key) + + async def expire_idempotency_keys(self) -> None: + """ + Remove any expired idempotency keys to avoid returning 409s + after the response expires. 
+ """ + if self.expiry: + await self.redis.zremrangebyscore(self.KEYS_KEY, '-inf', time.time()) diff --git a/idempotency_header_middleware/middleware.py b/idempotency_header_middleware/middleware.py index f4d2f52..dcd2bf2 100644 --- a/idempotency_header_middleware/middleware.py +++ b/idempotency_header_middleware/middleware.py @@ -49,6 +49,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Union[JS response = JSONResponse(payload, 422) return await response(scope, receive, send) + await self.backend.expire_idempotency_keys() if stored_response := await self.backend.get_stored_response(idempotency_key): stored_response.headers[self.replay_header_key] = 'true' return await stored_response(scope, receive, send) diff --git a/tests/conftest.py b/tests/conftest.py index e97847b..5462bad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,8 @@ logger = logging.getLogger(__name__) +pytestmark = pytest.mark.asyncio + @pytest.fixture(autouse=True, scope='session') def _configure_logging(): diff --git a/tests/test_backends.py b/tests/test_backends.py index 0c98fdc..2473e7b 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -16,6 +16,7 @@ 'store_response_data', 'store_idempotency_key', 'clear_idempotency_key', + 'expire_idempotency_keys', ] @@ -26,11 +27,13 @@ def test_base_backend(): redis = fakeredis.aioredis.FakeRedis(decode_responses=True) +backends = {"redis": RedisBackend(redis), "memory": MemoryBackend()} - -@pytest.mark.parametrize('backend', [RedisBackend(redis, expiry=1), MemoryBackend(expiry=1)]) -async def test_backend(backend: Backend): +@pytest.mark.parametrize('backend', backends.values(), ids=backends.keys()) +@pytest.mark.parametrize('expiry', [0, 1]) +async def test_backend(backend: Backend, expiry: int): assert issubclass(backend.__class__, Backend) + backend.expiry = expiry # Test setting and clearing key id_ = str(uuid4()) @@ -52,4 +55,22 @@ async def test_backend(backend: Backend): # Test fetching 
data after expiry await backend.store_response_data(id_, dummy_response, 201) await asyncio.sleep(1) - assert (await backend.get_stored_response(id_)) is None + stored_response = await backend.get_stored_response(id_) + if expiry: + assert stored_response is None + else: + assert stored_response is not None + + # Test storing idempotency key after expiry + id_ = str(uuid4()) + already_existed = await backend.store_idempotency_key(id_) + assert already_existed is False + already_existed = await backend.store_idempotency_key(id_) + assert already_existed is True + await asyncio.sleep(1) + await backend.expire_idempotency_keys() + already_existed = await backend.store_idempotency_key(id_) + if expiry: + assert already_existed is False + else: + assert already_existed is True diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 646165a..68d2519 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -5,9 +5,7 @@ import pytest from httpx import AsyncClient, Response -from tests.conftest import app, dummy_response - -pytestmark = pytest.mark.asyncio +from tests.conftest import app, dummy_response, pytestmark http_call = Callable[..., Awaitable[Response]] From 3cb5445cc9016deb6f50dabbe2aa98d8bff5ee93 Mon Sep 17 00:00:00 2001 From: Jordan Sanders Date: Wed, 4 Jan 2023 10:45:05 -0600 Subject: [PATCH 4/4] Fix lint errors --- idempotency_header_middleware/backends/memory.py | 13 ++++++------- idempotency_header_middleware/backends/redis.py | 2 +- tests/test_backends.py | 4 +--- tests/test_middleware.py | 3 +-- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/idempotency_header_middleware/backends/memory.py b/idempotency_header_middleware/backends/memory.py index 0c3f52b..6c29bc6 100644 --- a/idempotency_header_middleware/backends/memory.py +++ b/idempotency_header_middleware/backends/memory.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Set +from typing import 
Any, Dict, Optional from starlette.responses import JSONResponse @@ -22,7 +22,7 @@ class MemoryBackend(Backend): expiry: Optional[int] = 60 * 60 * 24 response_store: Dict[str, Dict[str, Any]] = field(default_factory=dict) - idempotency_keys: Dict[str, Optional[int]] = field(default_factory=dict) + idempotency_keys: Dict[str, Optional[float]] = field(default_factory=dict) async def get_stored_response(self, idempotency_key: str) -> Optional[JSONResponse]: """ @@ -54,10 +54,10 @@ async def store_idempotency_key(self, idempotency_key: str) -> bool: """ Store an idempotency key header value in a set. """ - if idempotency_key in self.idempotency_keys.keys(): + if idempotency_key in self.idempotency_keys: return True - self.idempotency_keys[idempotency_key] = time.time() + self.expiry if self.expiry else None + self.idempotency_keys[idempotency_key] = time.time() + float(self.expiry or 0) if self.expiry else None return False async def clear_idempotency_key(self, idempotency_key: str) -> None: @@ -76,6 +76,5 @@ async def expire_idempotency_keys(self) -> None: now = time.time() for idempotency_key in list(self.idempotency_keys): - if expiry := self.idempotency_keys.get(idempotency_key): - if expiry <= now: - del self.idempotency_keys[idempotency_key] + if (expiry := self.idempotency_keys.get(idempotency_key)) and expiry <= now: + del self.idempotency_keys[idempotency_key] diff --git a/idempotency_header_middleware/backends/redis.py b/idempotency_header_middleware/backends/redis.py index f2c7e71..bbd89d4 100644 --- a/idempotency_header_middleware/backends/redis.py +++ b/idempotency_header_middleware/backends/redis.py @@ -58,7 +58,7 @@ async def store_idempotency_key(self, idempotency_key: str) -> bool: """ Store an idempotency key header value in a sortedset. 
""" - return not bool(await self.redis.zadd(self.KEYS_KEY, {idempotency_key: time.time() + self.expiry},)) + return not bool(await self.redis.zadd(self.KEYS_KEY, {idempotency_key: time.time() + float(self.expiry or 0)},)) async def clear_idempotency_key(self, idempotency_key: str) -> None: """ diff --git a/tests/test_backends.py b/tests/test_backends.py index 2473e7b..162e7ec 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -7,9 +7,7 @@ from idempotency_header_middleware.backends.base import Backend from idempotency_header_middleware.backends.memory import MemoryBackend from idempotency_header_middleware.backends.redis import RedisBackend -from tests.conftest import dummy_response - -pytestmark = pytest.mark.asyncio +from tests.conftest import dummy_response, pytestmark # noqa: F401 base_methods = [ 'get_stored_response', diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 68d2519..4ddbe80 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -5,8 +5,7 @@ import pytest from httpx import AsyncClient, Response -from tests.conftest import app, dummy_response, pytestmark - +from tests.conftest import app, dummy_response, pytestmark # noqa: F401 http_call = Callable[..., Awaitable[Response]]