Skip to content

Commit

Permalink
RateLimiter 重构 (#707)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiichi-Origami authored Aug 2, 2024
1 parent 3a80c11 commit 04fe59d
Showing 1 changed file with 98 additions and 37 deletions.
135 changes: 98 additions & 37 deletions python/qianfan/resources/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
import threading
import time
from queue import Queue
from types import TracebackType
from typing import Any, Optional, Type

Expand Down Expand Up @@ -138,11 +139,12 @@ async def async_reset_once(self, rpm: float) -> None:
self._async_reset_once_lock.release()
return

self._has_been_reset = True

og_rpm = self._get_og_rpm()

# 如果新旧值一致则不需要操作
if og_rpm == rpm:
self._has_been_reset = True
self._async_reset_once_lock.release()
return

Expand All @@ -153,14 +155,12 @@ async def async_reset_once(self, rpm: float) -> None:
# 如果重置为 0 则直接关闭
if rpm == 0:
self.is_closed = True
self._has_been_reset = True
self._async_reset_once_lock.release()
return

# 重置
self._reset_internal_rate_limiter(rpm)
await self._async_reset_internal_rate_limiter(rpm)

self._has_been_reset = True
self._async_reset_once_lock.release()

def reset_once(self, rpm: float) -> None:
Expand All @@ -175,12 +175,13 @@ def reset_once(self, rpm: float) -> None:
self._reset_once_lock.release()
return

self._has_been_reset = True
self._reset_once_lock.release()

og_rpm = self._get_og_rpm()

# 如果新旧值一致则不需要操作
if og_rpm == rpm:
self._has_been_reset = True
self._reset_once_lock.release()
return

# 取最小的那个,如果是关闭的则直接取重置的
Expand All @@ -190,16 +191,11 @@ def reset_once(self, rpm: float) -> None:
# 如果重置为 0 则直接关闭
if rpm == 0:
self.is_closed = True
self._has_been_reset = True
self._reset_once_lock.release()
return

# 重置
self._reset_internal_rate_limiter(rpm)

self._has_been_reset = True
self._reset_once_lock.release()

def _reset_internal_rate_limiter(self, rpm: float) -> None:
# 记录一下新值
if self.is_closed or self._is_rpm:
Expand All @@ -209,26 +205,55 @@ def _reset_internal_rate_limiter(self, rpm: float) -> None:
self._is_rpm = False
self._new_query_per_second = rpm / 60

# 重置
self._sync_reset(rpm * (1 - self._buffer_ratio))
self.is_closed = False

async def _async_reset_internal_rate_limiter(self, rpm: float) -> None:
# 记录一下新值
if self.is_closed or self._is_rpm:
self._is_rpm = True
self._new_request_per_minute = rpm
else:
self._is_rpm = False
self._new_query_per_second = rpm / 60

# 重置
rpm *= 1 - self._buffer_ratio
await self._async_reset(rpm * (1 - self._buffer_ratio))
self.is_closed = False

def _sync_reset(self, rpm: float) -> None:
if self._is_rpm:
self._internal_qp10s_rate_limiter = RateLimiter(rpm / 6, 10)
self._internal_rpm_rate_limiter = RateLimiter(rpm, 60)
if hasattr(self, "_internal_qp10s_rate_limiter"):
self._internal_qp10s_rate_limiter.reset(rpm / 6, 10)
self._internal_rpm_rate_limiter.reset(rpm, 60)
else:
self._internal_qp10s_rate_limiter = RateLimiter(rpm / 6, 10)
self._internal_rpm_rate_limiter = RateLimiter(rpm, 60)
else:
self._internal_qps_rate_limiter = RateLimiter(rpm / 60)
if hasattr(self, "_internal_qps_rate_limiter"):
self._internal_qps_rate_limiter.reset(rpm / 60)
else:
self._internal_qps_rate_limiter = RateLimiter(rpm / 60)

async def _async_reset(self, rpm: float) -> None:
if self._is_rpm:
if hasattr(self, "_internal_qp10s_rate_limiter"):
await self._internal_qp10s_rate_limiter.async_reset(rpm / 6, 10)
await self._internal_rpm_rate_limiter.async_reset(rpm, 60)
else:
self._internal_qp10s_rate_limiter = RateLimiter(rpm / 6, 10)
self._internal_rpm_rate_limiter = RateLimiter(rpm, 60)
else:
if hasattr(self, "_internal_qps_rate_limiter"):
await self._internal_qps_rate_limiter.async_reset(rpm / 60)
else:
self._internal_qps_rate_limiter = RateLimiter(rpm / 60)

def __enter__(self) -> None:
if self.is_closed:
return

if not self._has_been_reset:
self._reset_once_lock.acquire()
if self._has_been_reset:
self._reset_once_lock.release()

if self._is_rpm:
with self._internal_rpm_rate_limiter:
with self._internal_qp10s_rate_limiter:
Expand All @@ -237,9 +262,6 @@ def __enter__(self) -> None:
with self._internal_qps_rate_limiter:
...

if not self._has_been_reset:
self._reset_once_lock.release()

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
Expand All @@ -252,11 +274,6 @@ async def __aenter__(self) -> None:
if self.is_closed:
return

if not self._has_been_reset:
await self._async_reset_once_lock.acquire()
if self._has_been_reset:
self._async_reset_once_lock.release()

if self._is_rpm:
async with self._internal_rpm_rate_limiter:
async with self._internal_qp10s_rate_limiter:
Expand All @@ -265,9 +282,6 @@ async def __aenter__(self) -> None:
async with self._internal_qps_rate_limiter:
...

if not self._has_been_reset:
self._async_reset_once_lock.release()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
Expand All @@ -284,6 +298,11 @@ class RateLimiter:
we recommend only use one of two method within single rate limiter at same time
"""

class _AcquireTask:
def __init__(self, condition: threading.Condition, amount: float):
self.condition = condition
self.amount = amount

class _SyncLimiter:
def __init__(
self,
Expand Down Expand Up @@ -313,6 +332,9 @@ def __init__(
self._token_count = 0.0
self._last_leak_timestamp = time.time()
self._sync_lock = threading.Lock()
self._condition_queue: Queue[RateLimiter._AcquireTask] = Queue()
self._working_thread = threading.Thread(target=self._worker, daemon=True)
self._working_thread.start()

def _leak(self) -> None:
timestamp = time.time()
Expand All @@ -323,17 +345,40 @@ def _leak(self) -> None:
self._token_count + delta * self._query_per_second,
)

def _worker(self) -> None:
while True:
task = self._condition_queue.get(True)
amount = task.amount
while True:
with self._sync_lock:
self._leak()
if self._token_count >= amount:
self._token_count -= amount
break
time.sleep((amount - self._token_count) / self._query_per_second)

with task.condition:
task.condition.notify()

def acquire(self, amount: float = 1) -> None:
if amount > self._query_per_period:
raise ValueError("Can't acquire more than the maximum capacity")

request_condition = threading.Condition()
self._condition_queue.put(
RateLimiter._AcquireTask(request_condition, amount)
)
with request_condition:
request_condition.wait()

def reset(
self, query_per_period: float = 1, period_in_second: float = 1
) -> None:
with self._sync_lock:
while True:
self._leak()
if self._token_count >= amount:
self._token_count -= amount
return
time.sleep((amount - self._token_count) / self._query_per_second)
self._query_per_period = query_per_period
self._period_in_second = period_in_second
self._query_per_second = query_per_period / period_in_second
self._token_count = min(self._query_per_period, self._token_count)

def __enter__(self) -> None:
"""
Expand Down Expand Up @@ -396,6 +441,22 @@ def _warmup_procedure() -> None:
warmup_thread.start()
warmup_thread.join()

def reset(
self,
query_per_period: float = 1,
period_in_second: float = 1,
) -> None:
self._sync_limiter.reset(query_per_period, period_in_second)

async def async_reset(
self,
query_per_period: float = 1,
period_in_second: float = 1,
) -> None:
self._async_limiter.max_rate = query_per_period
self._async_limiter.time_period = period_in_second
self._async_limiter._rate_per_sec = query_per_period / period_in_second

def acquire(self, amount: float) -> None:
if self._check_is_closed():
return
Expand Down

0 comments on commit 04fe59d

Please sign in to comment.