Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test expiration #11

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions idempotency_header_middleware/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,11 @@ async def clear_idempotency_key(self, idempotency_key: str) -> None:
key stored in 'store_idempotency_key'.
"""
...

@abstractmethod
async def expire_idempotency_keys(self) -> None:
"""
Remove any expired idempotency keys to avoid returning 409s
after the response expires.
"""
...
22 changes: 18 additions & 4 deletions idempotency_header_middleware/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MemoryBackend(Backend):
expiry: Optional[int] = 60 * 60 * 24

response_store: Dict[str, Dict[str, Any]] = field(default_factory=dict)
keys: Set[str] = field(default_factory=set)
idempotency_keys: Dict[str, Optional[int]] = field(default_factory=dict)

async def get_stored_response(self, idempotency_key: str) -> Optional[JSONResponse]:
"""
Expand Down Expand Up @@ -54,14 +54,28 @@ async def store_idempotency_key(self, idempotency_key: str) -> bool:
"""
Store an idempotency key header value in a set.
"""
if idempotency_key in self.keys:
if idempotency_key in self.idempotency_keys.keys():
return True

self.keys.add(idempotency_key)
self.idempotency_keys[idempotency_key] = time.time() + self.expiry if self.expiry else None
return False

async def clear_idempotency_key(self, idempotency_key: str) -> None:
"""
Remove an idempotency header value from the set.
"""
self.keys.remove(idempotency_key)
del self.idempotency_keys[idempotency_key]

async def expire_idempotency_keys(self) -> None:
"""
Remove any expired idempotency keys to avoid returning 409s
after the response expires.
"""
if not self.expiry:
return

now = time.time()
for idempotency_key in list(self.idempotency_keys):
if expiry := self.idempotency_keys.get(idempotency_key):
if expiry <= now:
del self.idempotency_keys[idempotency_key]
15 changes: 12 additions & 3 deletions idempotency_header_middleware/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import json
from dataclasses import dataclass
from typing import Optional, Tuple
Expand Down Expand Up @@ -55,12 +56,20 @@ async def store_response_data(self, idempotency_key: str, payload: dict, status_

async def store_idempotency_key(self, idempotency_key: str) -> bool:
"""
Store an idempotency key header value in a set.
Store an idempotency key header value in a sortedset.
"""
return not bool(await self.redis.sadd(self.KEYS_KEY, idempotency_key))
return not bool(await self.redis.zadd(self.KEYS_KEY, {idempotency_key: time.time() + self.expiry},))

async def clear_idempotency_key(self, idempotency_key: str) -> None:
"""
Remove an idempotency header value from the set.
"""
await self.redis.srem(self.KEYS_KEY, idempotency_key)
await self.redis.zrem(self.KEYS_KEY, idempotency_key)

async def expire_idempotency_keys(self) -> None:
"""
Remove any expired idempotency keys to avoid returning 409s
after the response expires.
"""
if self.expiry:
await self.redis.zremrangebyscore(self.KEYS_KEY, '-inf', time.time())
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned it in the main PR thread, but here's where it might make sense to instead store each idempotency key as its own key in Redis and avoid the potentially expensive ZREMRANGEBYSCORE call entirely.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very concerned with this being a slow operation, as a slow redis operation is probably much faster for this type of use-case than making just one additional round-trip to the cache. At least that's what I'd expect.

If we wanted to store them separately and save the round-trips, I suppose we could use pipelining or a Lua script to perform the operations in one go. I'll leave that up to you 🙂

1 change: 1 addition & 0 deletions idempotency_header_middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Union[JS
response = JSONResponse(payload, 422)
return await response(scope, receive, send)

await self.backend.expire_idempotency_keys()
if stored_response := await self.backend.get_stored_response(idempotency_key):
stored_response.headers[self.replay_header_key] = 'true'
return await stored_response(scope, receive, send)
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

logger = logging.getLogger(__name__)

pytestmark = pytest.mark.asyncio


@pytest.fixture(autouse=True, scope='session')
def _configure_logging():
Expand Down
29 changes: 25 additions & 4 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'store_response_data',
'store_idempotency_key',
'clear_idempotency_key',
'expire_idempotency_keys',
]


Expand All @@ -26,11 +27,13 @@ def test_base_backend():


redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
backends = {"redis": RedisBackend(redis), "memory": MemoryBackend()}


@pytest.mark.parametrize('backend', [RedisBackend(redis, expiry=1), MemoryBackend(expiry=1)])
async def test_backend(backend: Backend):
@pytest.mark.parametrize('backend', backends.values(), ids=backends.keys())
@pytest.mark.parametrize('expiry', [0, 1])
async def test_backend(backend: Backend, expiry: int):
assert issubclass(backend.__class__, Backend)
backend.expiry = expiry

# Test setting and clearing key
id_ = str(uuid4())
Expand All @@ -52,4 +55,22 @@ async def test_backend(backend: Backend):
# Test fetching data after expiry
await backend.store_response_data(id_, dummy_response, 201)
await asyncio.sleep(1)
assert (await backend.get_stored_response(id_)) is None
stored_response = await backend.get_stored_response(id_)
if expiry:
assert stored_response is None
else:
assert stored_response is not None

# Test storing idempotency key after expiry
id_ = str(uuid4())
already_existed = await backend.store_idempotency_key(id_)
assert already_existed is False
already_existed = await backend.store_idempotency_key(id_)
assert already_existed is True
await asyncio.sleep(1)
await backend.expire_idempotency_keys()
already_existed = await backend.store_idempotency_key(id_)
if expiry:
assert already_existed is False
else:
assert already_existed is True
4 changes: 1 addition & 3 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import pytest
from httpx import AsyncClient, Response

from tests.conftest import app, dummy_response

pytestmark = pytest.mark.asyncio
from tests.conftest import app, dummy_response, pytestmark

http_call = Callable[..., Awaitable[Response]]

Expand Down