From 04fe59d1dbe98570cf2ae51506023619c4824d59 Mon Sep 17 00:00:00 2001 From: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> Date: Fri, 2 Aug 2024 22:59:44 +0800 Subject: [PATCH] =?UTF-8?q?`RateLimiter`=20=E9=87=8D=E6=9E=84=20(#707)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/qianfan/resources/rate_limiter.py | 135 ++++++++++++++++------- 1 file changed, 98 insertions(+), 37 deletions(-) diff --git a/python/qianfan/resources/rate_limiter.py b/python/qianfan/resources/rate_limiter.py index d91b933d..da5499b0 100644 --- a/python/qianfan/resources/rate_limiter.py +++ b/python/qianfan/resources/rate_limiter.py @@ -18,6 +18,7 @@ import asyncio import threading import time +from queue import Queue from types import TracebackType from typing import Any, Optional, Type @@ -138,11 +139,12 @@ async def async_reset_once(self, rpm: float) -> None: self._async_reset_once_lock.release() return + self._has_been_reset = True + og_rpm = self._get_og_rpm() # 如果新旧值一致则不需要操作 if og_rpm == rpm: - self._has_been_reset = True self._async_reset_once_lock.release() return @@ -153,14 +155,12 @@ async def async_reset_once(self, rpm: float) -> None: # 如果重置为 0 则直接关闭 if rpm == 0: self.is_closed = True - self._has_been_reset = True self._async_reset_once_lock.release() return # 重置 - self._reset_internal_rate_limiter(rpm) + await self._async_reset_internal_rate_limiter(rpm) - self._has_been_reset = True self._async_reset_once_lock.release() def reset_once(self, rpm: float) -> None: @@ -175,12 +175,13 @@ def reset_once(self, rpm: float) -> None: self._reset_once_lock.release() return + self._has_been_reset = True + self._reset_once_lock.release() + og_rpm = self._get_og_rpm() # 如果新旧值一致则不需要操作 if og_rpm == rpm: - self._has_been_reset = True - self._reset_once_lock.release() return # 取最小的那个,如果是关闭的则直接取重置的 @@ -190,16 +191,11 @@ def reset_once(self, rpm: float) -> None: # 如果重置为 0 则直接关闭 if rpm == 0: self.is_closed = True - self._has_been_reset = True - self._reset_once_lock.release() return # 重置 self._reset_internal_rate_limiter(rpm) - self._has_been_reset = True - self._reset_once_lock.release() - def _reset_internal_rate_limiter(self, rpm: float) -> None: # 记录一下新值 if self.is_closed or self._is_rpm: @@ -209,26 +205,55 @@ def _reset_internal_rate_limiter(self, rpm: float) -> None: self._is_rpm = False self._new_query_per_second = rpm / 60 + # 重置 + self._sync_reset(rpm * (1 - self._buffer_ratio)) self.is_closed = False + async def _async_reset_internal_rate_limiter(self, rpm: float) -> None: + # 记录一下新值 + if self.is_closed or self._is_rpm: + self._is_rpm = True + self._new_request_per_minute = rpm + else: + self._is_rpm = False + self._new_query_per_second = rpm / 60 + # 重置 - rpm *= 1 - self._buffer_ratio + await self._async_reset(rpm * (1 - self._buffer_ratio)) + self.is_closed = False + def _sync_reset(self, rpm: float) -> None: if self._is_rpm: - self._internal_qp10s_rate_limiter = RateLimiter(rpm / 6, 10) - self._internal_rpm_rate_limiter = RateLimiter(rpm, 60) + if hasattr(self, "_internal_qp10s_rate_limiter"): + self._internal_qp10s_rate_limiter.reset(rpm / 6, 10) + self._internal_rpm_rate_limiter.reset(rpm, 60) + else: + self._internal_qp10s_rate_limiter = RateLimiter(rpm / 6, 10) + self._internal_rpm_rate_limiter = RateLimiter(rpm, 60) else: - self._internal_qps_rate_limiter = RateLimiter(rpm / 60) + if hasattr(self, "_internal_qps_rate_limiter"): + self._internal_qps_rate_limiter.reset(rpm / 60) + else: + self._internal_qps_rate_limiter = RateLimiter(rpm / 60) + + async def _async_reset(self, rpm: float) -> None: + if self._is_rpm: + if hasattr(self, "_internal_qp10s_rate_limiter"): + await self._internal_qp10s_rate_limiter.async_reset(rpm / 6, 10) + await self._internal_rpm_rate_limiter.async_reset(rpm, 60) + else: + self._internal_qp10s_rate_limiter = RateLimiter(rpm / 6, 10) + self._internal_rpm_rate_limiter = RateLimiter(rpm, 60) + else: + if hasattr(self, "_internal_qps_rate_limiter"): + await self._internal_qps_rate_limiter.async_reset(rpm / 60) + else: + self._internal_qps_rate_limiter = RateLimiter(rpm / 60) def __enter__(self) -> None: if self.is_closed: return - if not self._has_been_reset: - self._reset_once_lock.acquire() - if self._has_been_reset: - self._reset_once_lock.release() - if self._is_rpm: with self._internal_rpm_rate_limiter: with self._internal_qp10s_rate_limiter: @@ -237,9 +262,6 @@ def __enter__(self) -> None: with self._internal_qps_rate_limiter: ... - if not self._has_been_reset: - self._reset_once_lock.release() - def __exit__( self, exc_type: Optional[Type[BaseException]], @@ -252,11 +274,6 @@ async def __aenter__(self) -> None: if self.is_closed: return - if not self._has_been_reset: - await self._async_reset_once_lock.acquire() - if self._has_been_reset: - self._async_reset_once_lock.release() - if self._is_rpm: async with self._internal_rpm_rate_limiter: async with self._internal_qp10s_rate_limiter: @@ -265,9 +282,6 @@ async def __aenter__(self) -> None: async with self._internal_qps_rate_limiter: ... - if not self._has_been_reset: - self._async_reset_once_lock.release() - async def __aexit__( self, exc_type: Optional[Type[BaseException]], @@ -284,6 +298,11 @@ class RateLimiter: we recommend only use one of two method within single rate limiter at same time """ + class _AcquireTask: + def __init__(self, condition: threading.Condition, amount: float): + self.condition = condition + self.amount = amount + class _SyncLimiter: def __init__( self, @@ -313,6 +332,9 @@ def __init__( self._token_count = 0.0 self._last_leak_timestamp = time.time() self._sync_lock = threading.Lock() + self._condition_queue: Queue[RateLimiter._AcquireTask] = Queue() + self._working_thread = threading.Thread(target=self._worker, daemon=True) + self._working_thread.start() def _leak(self) -> None: timestamp = time.time() @@ -323,17 +345,40 @@ def _leak(self) -> None: self._token_count + delta * self._query_per_second, ) + def _worker(self) -> None: + while True: + task = self._condition_queue.get(True) + amount = task.amount + while True: + with self._sync_lock: + self._leak() + if self._token_count >= amount: + self._token_count -= amount + break + time.sleep((amount - self._token_count) / self._query_per_second) + + with task.condition: + task.condition.notify() + def acquire(self, amount: float = 1) -> None: if amount > self._query_per_period: raise ValueError("Can't acquire more than the maximum capacity") + request_condition = threading.Condition() + self._condition_queue.put( + RateLimiter._AcquireTask(request_condition, amount) + ) + with request_condition: + request_condition.wait() + + def reset( + self, query_per_period: float = 1, period_in_second: float = 1 + ) -> None: with self._sync_lock: - while True: - self._leak() - if self._token_count >= amount: - self._token_count -= amount - return - time.sleep((amount - self._token_count) / self._query_per_second) + self._query_per_period = query_per_period + self._period_in_second = period_in_second + self._query_per_second = query_per_period / period_in_second + self._token_count = min(self._query_per_period, self._token_count) def __enter__(self) -> None: """ @@ -396,6 +441,22 @@ def _warmup_procedure() -> None: warmup_thread.start() warmup_thread.join() + def reset( + self, + query_per_period: float = 1, + period_in_second: float = 1, + ) -> None: + self._sync_limiter.reset(query_per_period, period_in_second) + + async def async_reset( + self, + query_per_period: float = 1, + period_in_second: float = 1, + ) -> None: + self._async_limiter.max_rate = query_per_period + self._async_limiter.time_period = period_in_second + self._async_limiter._rate_per_sec = query_per_period / period_in_second + def acquire(self, amount: float) -> None: if self._check_is_closed(): return