Skip to content

Commit

Permalink
Expire idempotency keys
Browse files Browse the repository at this point in the history
Previously, idempotency keys were never being expired. This had two
consequences:

- once a response key expires, requests using the same idempotency
  key will return 409 forever
- the idempotency keys set grows without bound

There's a new `expire_idempotency_keys` abstract method that
implementations of Backend are responsible for implementing. The
middleware calls it on every request.

For the Memory backend, we now store idempotency keys in a dict intead
of a set. The value of each key is its expiration time.

We then iterate over all of the items in the dict and
remove the ones with an expiration earlier than the current time. This
operation isn't very efficient because it iterates over the entire dict,
but given that the Memory backend is only meant to be used in local
testing, I don't think it's worth optimizing further.

For the RedisBackend, we now store idempotency keys in a sortedset
instead of a set. The "score" of each key is its expiration time. This
matches the Redis project's official recommendation on how to expire
items in a set: redis/redis#135 (comment)

We then delete any items from the sorted set with a "score" lower than
the current time. The Redis docs describe this a potentially slow
operation: https://redis.io/commands/zremrangebyscore/

If we don't want to call this on every request, we could instead add
some sort of random fuzzing to only occassionally call it. Or we could
abandon using a sorted set entirely and just store the idempotency key
as its own key in Redis using the standard Redis expiration
functionality. This would mean each request would store potentially
three keys (idempotency key, response key, status code key) instead of
the current two (response key, status code key).

Users who upgrade to a version of the library that includes this change
will see a failure in Redis when it tries to use a sortedset operation
on an existing set. At a minimum, this should be called out in a
changelog. Additionally, we could change the default idempotency key
name from `idempotency-key-keys` to something else to avoid a collision.
  • Loading branch information
jmsanders committed Jan 4, 2023
1 parent e63463e commit 44eee09
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 14 deletions.
8 changes: 8 additions & 0 deletions idempotency_header_middleware/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,11 @@ async def clear_idempotency_key(self, idempotency_key: str) -> None:
key stored in 'store_idempotency_key'.
"""
...

@abstractmethod
async def expire_idempotency_keys(self) -> None:
"""
Remove any expired idempotency keys to avoid returning 409s
after the response expires.
"""
...
22 changes: 18 additions & 4 deletions idempotency_header_middleware/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MemoryBackend(Backend):
expiry: Optional[int] = 60 * 60 * 24

response_store: Dict[str, Dict[str, Any]] = field(default_factory=dict)
keys: Set[str] = field(default_factory=set)
idempotency_keys: Dict[str, Optional[int]] = field(default_factory=dict)

async def get_stored_response(self, idempotency_key: str) -> Optional[JSONResponse]:
"""
Expand Down Expand Up @@ -54,14 +54,28 @@ async def store_idempotency_key(self, idempotency_key: str) -> bool:
"""
Store an idempotency key header value in a set.
"""
if idempotency_key in self.keys:
if idempotency_key in self.idempotency_keys.keys():
return True

self.keys.add(idempotency_key)
self.idempotency_keys[idempotency_key] = time.time() + self.expiry if self.expiry else None
return False

async def clear_idempotency_key(self, idempotency_key: str) -> None:
"""
Remove an idempotency header value from the set.
"""
self.keys.remove(idempotency_key)
del self.idempotency_keys[idempotency_key]

async def expire_idempotency_keys(self) -> None:
"""
Remove any expired idempotency keys to avoid returning 409s
after the response expires.
"""
if not self.expiry:
return

now = time.time()
for idempotency_key in list(self.idempotency_keys):
if expiry := self.idempotency_keys.get(idempotency_key):
if expiry <= now:
del self.idempotency_keys[idempotency_key]
15 changes: 12 additions & 3 deletions idempotency_header_middleware/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import json
from dataclasses import dataclass
from typing import Optional, Tuple
Expand Down Expand Up @@ -55,12 +56,20 @@ async def store_response_data(self, idempotency_key: str, payload: dict, status_

async def store_idempotency_key(self, idempotency_key: str) -> bool:
"""
Store an idempotency key header value in a set.
Store an idempotency key header value in a sortedset.
"""
return not bool(await self.redis.sadd(self.KEYS_KEY, idempotency_key))
return not bool(await self.redis.zadd(self.KEYS_KEY, {idempotency_key: time.time() + self.expiry},))

async def clear_idempotency_key(self, idempotency_key: str) -> None:
"""
Remove an idempotency header value from the set.
"""
await self.redis.srem(self.KEYS_KEY, idempotency_key)
await self.redis.zrem(self.KEYS_KEY, idempotency_key)

async def expire_idempotency_keys(self) -> None:
"""
Remove any expired idempotency keys to avoid returning 409s
after the response expires.
"""
if self.expiry:
await self.redis.zremrangebyscore(self.KEYS_KEY, '-inf', time.time())
1 change: 1 addition & 0 deletions idempotency_header_middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Union[JS
response = JSONResponse(payload, 422)
return await response(scope, receive, send)

await self.backend.expire_idempotency_keys()
if stored_response := await self.backend.get_stored_response(idempotency_key):
stored_response.headers[self.replay_header_key] = 'true'
return await stored_response(scope, receive, send)
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

logger = logging.getLogger(__name__)

pytestmark = pytest.mark.asyncio


@pytest.fixture(autouse=True, scope='session')
def _configure_logging():
Expand Down
29 changes: 25 additions & 4 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'store_response_data',
'store_idempotency_key',
'clear_idempotency_key',
'expire_idempotency_keys',
]


Expand All @@ -26,11 +27,13 @@ def test_base_backend():


redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
backends = {"redis": RedisBackend(redis), "memory": MemoryBackend()}


@pytest.mark.parametrize('backend', [RedisBackend(redis, expiry=1), MemoryBackend(expiry=1)])
async def test_backend(backend: Backend):
@pytest.mark.parametrize('backend', backends.values(), ids=backends.keys())
@pytest.mark.parametrize('expiry', [0, 1])
async def test_backend(backend: Backend, expiry: int):
assert issubclass(backend.__class__, Backend)
backend.expiry = expiry

# Test setting and clearing key
id_ = str(uuid4())
Expand All @@ -52,4 +55,22 @@ async def test_backend(backend: Backend):
# Test fetching data after expiry
await backend.store_response_data(id_, dummy_response, 201)
await asyncio.sleep(1)
assert (await backend.get_stored_response(id_)) is None
stored_response = await backend.get_stored_response(id_)
if expiry:
assert stored_response is None
else:
assert stored_response is not None

# Test storing idempotency key after expiry
id_ = str(uuid4())
already_existed = await backend.store_idempotency_key(id_)
assert already_existed is False
already_existed = await backend.store_idempotency_key(id_)
assert already_existed is True
await asyncio.sleep(1)
await backend.expire_idempotency_keys()
already_existed = await backend.store_idempotency_key(id_)
if expiry:
assert already_existed is False
else:
assert already_existed is True
4 changes: 1 addition & 3 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import pytest
from httpx import AsyncClient, Response

from tests.conftest import app, dummy_response

pytestmark = pytest.mark.asyncio
from tests.conftest import app, dummy_response, pytestmark

http_call = Callable[..., Awaitable[Response]]

Expand Down

0 comments on commit 44eee09

Please sign in to comment.