Skip to content

Commit

Permalink
allow passing a custom ArqRedis class to create_pool
Browse files Browse the repository at this point in the history
  • Loading branch information
jvllmr committed Jan 19, 2025
1 parent 7a911f3 commit c1ce72c
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions arq/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, TypeVar, Union, cast
from urllib.parse import parse_qs, urlparse
from uuid import uuid4

Expand Down Expand Up @@ -217,6 +217,9 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])


TArqRedis = TypeVar('TArqRedis', bound=ArqRedis)


async def create_pool(
settings_: Optional[RedisSettings] = None,
*,
Expand All @@ -225,7 +228,8 @@ async def create_pool(
job_deserializer: Optional[Deserializer] = None,
default_queue_name: str = default_queue_name,
expires_extra_ms: int = expires_extra_ms,
) -> ArqRedis:
arq_redis_cls: Type[TArqRedis] = ArqRedis, # type: ignore[assignment]
) -> TArqRedis:
"""
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
Expand All @@ -238,19 +242,19 @@ async def create_pool(

if settings.sentinel:

def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
def pool_factory(*args: Any, **kwargs: Any) -> TArqRedis:
client = Sentinel( # type: ignore[misc]
*args,
sentinels=settings.host,
ssl=settings.ssl,
**kwargs,
)
redis = client.master_for(settings.sentinel_master, redis_class=ArqRedis)
return cast(ArqRedis, redis)
redis = client.master_for(settings.sentinel_master, redis_class=arq_redis_cls)
return cast(TArqRedis, redis)

else:
pool_factory = functools.partial(
ArqRedis,
arq_redis_cls,
host=settings.host,
port=settings.port,
unix_socket_path=settings.unix_socket_path,
Expand Down Expand Up @@ -312,8 +316,5 @@ async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any])
clients_connected = info_clients.get('connected_clients', '?')

log_func(
f'redis_version={redis_version} '
f'mem_usage={mem_usage} '
f'clients_connected={clients_connected} '
f'db_keys={key_count}'
f'redis_version={redis_version} mem_usage={mem_usage} clients_connected={clients_connected} db_keys={key_count}'
)

0 comments on commit c1ce72c

Please sign in to comment.