From 6808469f9823739488a570c7789e8fbb9fa86237 Mon Sep 17 00:00:00 2001 From: Erle Carrara Date: Thu, 25 Apr 2024 11:53:11 -0300 Subject: [PATCH] Add listener callback to `SimpleTokenBucket` (#4) --- pyproject.toml | 2 +- simple_token_bucket/__init__.py | 15 ++++++++++++++- tests/test_redis_backend.py | 9 +++++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d981e2b..2e5b5ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ path = "simple_token_bucket/__init__.py" dependencies = [ "pytest", "pytest-cov", - "fakeredis", + "fakeredis<=2.21.3", ] [[tool.hatch.envs.test.matrix]] diff --git a/simple_token_bucket/__init__.py b/simple_token_bucket/__init__.py index 09a865f..c9584ad 100644 --- a/simple_token_bucket/__init__.py +++ b/simple_token_bucket/__init__.py @@ -1,11 +1,20 @@ __version__ = "0.1.2" +from typing import Callable, Optional + from .backends import Backend class SimpleTokenBucket: - def __init__(self, name: str, bucket_size: int, refresh_interval: int, backend: Backend): + def __init__( + self, + name: str, + bucket_size: int, + refresh_interval: int, + backend: Backend, + listener: Optional[Callable[[int], None]] = None, + ): """Create a new SimpleTokenBucket. Every `refresh_interval` the token bucket is refresh and `bucket_size` @@ -21,6 +30,7 @@ def __init__(self, name: str, bucket_size: int, refresh_interval: int, backend: self.bucket_size = bucket_size self.refresh_interval = refresh_interval self._backend = backend + self._listener_callback = listener def try_get_token(self, raises=True) -> bool: available_tokens, ttl = self._backend.get_token( @@ -29,6 +39,9 @@ def try_get_token(self, raises=True) -> bool: refresh_interval=self.refresh_interval, ) + if self._listener_callback is not None: + self._listener_callback(available_tokens) + if available_tokens <= 0: if raises: raise NotEnoughTokens(remaining_seconds=ttl) diff --git a/tests/test_redis_backend.py b/tests/test_redis_backend.py index 9da93f2..c01c1e4 100644 --- a/tests/test_redis_backend.py +++ b/tests/test_redis_backend.py @@ -1,7 +1,9 @@ -import pytest +from unittest import mock + import fakeredis +import pytest -from simple_token_bucket import SimpleTokenBucket, NotEnoughTokens +from simple_token_bucket import NotEnoughTokens, SimpleTokenBucket from simple_token_bucket.backends.redis import RedisBackend @@ -11,13 +13,16 @@ def redis_client(): def test_redis_backend_ok(redis_client): + listener = mock.MagicMock() token_bucket = SimpleTokenBucket( name="test_ok", bucket_size=3, refresh_interval=60, backend=RedisBackend(redis_client), + listener=listener, ) assert token_bucket.try_get_token() is True + listener.assert_called_once_with(3) def test_redis_backend_not_enough_tokens_raises(redis_client):