Skip to content

Commit

Permalink
Add listener callback to SimpleTokenBucket (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
ecarrara authored Apr 25, 2024
1 parent 87ef310 commit 6808469
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ path = "simple_token_bucket/__init__.py"
dependencies = [
"pytest",
"pytest-cov",
"fakeredis",
"fakeredis<=2.21.3",
]

[[tool.hatch.envs.test.matrix]]
Expand Down
15 changes: 14 additions & 1 deletion simple_token_bucket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
__version__ = "0.1.2"


from typing import Callable, Optional

from .backends import Backend


class SimpleTokenBucket:
def __init__(self, name: str, bucket_size: int, refresh_interval: int, backend: Backend):
def __init__(
self,
name: str,
bucket_size: int,
refresh_interval: int,
backend: Backend,
listener: Optional[Callable[[int], None]] = None,
):
"""Create a new SimpleTokenBucket.
Every `refresh_interval` the token bucket is refresh and `bucket_size`
Expand All @@ -21,6 +30,7 @@ def __init__(self, name: str, bucket_size: int, refresh_interval: int, backend:
self.bucket_size = bucket_size
self.refresh_interval = refresh_interval
self._backend = backend
self._listener_callback = listener

def try_get_token(self, raises=True) -> bool:
available_tokens, ttl = self._backend.get_token(
Expand All @@ -29,6 +39,9 @@ def try_get_token(self, raises=True) -> bool:
refresh_interval=self.refresh_interval,
)

if self._listener_callback is not None:
self._listener_callback(available_tokens)

if available_tokens <= 0:
if raises:
raise NotEnoughTokens(remaining_seconds=ttl)
Expand Down
9 changes: 7 additions & 2 deletions tests/test_redis_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from unittest import mock

import fakeredis
import pytest

from simple_token_bucket import SimpleTokenBucket, NotEnoughTokens
from simple_token_bucket import NotEnoughTokens, SimpleTokenBucket
from simple_token_bucket.backends.redis import RedisBackend


Expand All @@ -11,13 +13,16 @@ def redis_client():


def test_redis_backend_ok(redis_client):
listener = mock.MagicMock()
token_bucket = SimpleTokenBucket(
name="test_ok",
bucket_size=3,
refresh_interval=60,
backend=RedisBackend(redis_client),
listener=listener,
)
assert token_bucket.try_get_token() is True
listener.assert_called_once_with(3)


def test_redis_backend_not_enough_tokens_raises(redis_client):
Expand Down

0 comments on commit 6808469

Please sign in to comment.