From c54ca51bdd7844fbff7e1b6122f3ca06de2ed335 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Wed, 6 Sep 2023 13:06:07 +0200 Subject: [PATCH] setup: Upgrade thriftpy2 and replace aioredis with redis-py (#19) --- .pre-commit-config.yaml | 1 + examples/simple-client-thrift.py | 11 +-- examples/simple-client.py | 8 +- setup.py | 18 ++-- src/callosum/lower/dispatch_redis.py | 98 +++++++++++---------- src/callosum/lower/redis_common.py | 18 ++++ src/callosum/lower/rpc_redis.py | 125 +++++++++++++++------------ src/callosum/rpc/channel.py | 3 +- src/callosum/upper/thrift.py | 3 +- 9 files changed, 156 insertions(+), 129 deletions(-) create mode 100644 src/callosum/lower/redis_common.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f0a559..0f30287 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,5 +22,6 @@ repos: - id: mypy additional_dependencies: [ 'types-python-dateutil', + 'types-redis', 'attrs>=21.3', ] diff --git a/examples/simple-client-thrift.py b/examples/simple-client-thrift.py index 6070c79..0b3927d 100644 --- a/examples/simple-client-thrift.py +++ b/examples/simple-client-thrift.py @@ -5,14 +5,11 @@ import sys import textwrap -from async_timeout import timeout - -from callosum.rpc import Peer, RPCUserError -from callosum.serialize import noop_serializer, noop_deserializer +import thriftpy2 as thriftpy from callosum.lower.zeromq import ZeroMQAddress, ZeroMQRPCTransport +from callosum.rpc import Peer, RPCUserError +from callosum.serialize import noop_deserializer, noop_serializer from callosum.upper.thrift import ThriftClientAdaptor -import thriftpy2 as thriftpy - simple_thrift = thriftpy.load( str(pathlib.Path(__file__).parent / "simple.thrift"), module_name="simple_thrift" @@ -47,7 +44,7 @@ async def call() -> None: print(textwrap.indent(e.traceback, prefix="| ")) try: - with timeout(0.5): + async with asyncio.timeout(0.5): await peer.invoke("simple", adaptor.long_delay()) except asyncio.TimeoutError: print( diff --git a/examples/simple-client.py b/examples/simple-client.py index 084d92d..a6abd08 100644 --- a/examples/simple-client.py +++ b/examples/simple-client.py @@ -1,16 +1,14 @@ import asyncio import json -import random import os +import random import secrets import sys import textwrap import traceback -from async_timeout import timeout - -from callosum.rpc import Peer, RPCUserError from callosum.lower.zeromq import ZeroMQAddress, ZeroMQRPCTransport +from callosum.rpc import Peer, RPCUserError async def test_simple(peer, initial_delay: float = 0): @@ -41,7 +39,7 @@ async def test_simple(peer, initial_delay: float = 0): async def test_timeout(peer): try: - with timeout(0.5): + async with asyncio.timeout(0.5): await peer.invoke( "long_delay", { diff --git a/setup.py b/setup.py index 85002e7..d7af06a 100644 --- a/setup.py +++ b/setup.py @@ -16,12 +16,12 @@ def read_src_version(): install_requires = [ - "aiotools>=1.5.9", - "async_timeout>=3.0.1", + "aiotools>=1.7.0", "attrs>=21.3.0", "python-dateutil>=2.8.2", "msgpack>=1.0.4", "temporenc>=0.1", + "yarl>=1.8.2", ] build_requires = [ @@ -45,13 +45,13 @@ def read_src_version(): ] lint_requires = [ - "flake8>=5.0.4", - "isort>=5.10.1", - "black>=22.10.0", + "black>=23.7.0", + "ruff>=0.0.287", + "ruff-lsp", ] typecheck_requires = [ - "mypy>=0.991", + "mypy>=1.5.1", "types-python-dateutil", ] @@ -68,9 +68,7 @@ def read_src_version(): "pyzmq>=23.0.0", ] -redis_requires = [ - "aioredis>=1.3.0,<2.0", -] +redis_requires = ["redis>=4.6.0"] snappy_requires = [ "python-snappy>=0.6.1", @@ -113,7 +111,7 @@ def read_src_version(): "callosum": ["py.typed"], }, include_package_data=True, - python_requires=">=3.8", + python_requires=">=3.11", setup_requires=["setuptools>=61.0"], install_requires=install_requires, extras_require={ diff --git a/src/callosum/lower/dispatch_redis.py b/src/callosum/lower/dispatch_redis.py index b789487..71212ae 100644 --- a/src/callosum/lower/dispatch_redis.py +++ b/src/callosum/lower/dispatch_redis.py @@ -3,8 +3,9 @@ import asyncio from typing import Any, AsyncGenerator, Mapping, Optional, Tuple, Union -import aioredis import attrs +import redis +import redis.asyncio from ..abc import RawHeaderBody from ..exceptions import InvalidAddressError @@ -15,6 +16,7 @@ AbstractConnector, BaseTransport, ) +from .redis_common import redis_addr_to_url @attrs.define(auto_attribs=True, slots=True) @@ -48,55 +50,54 @@ def __init__( self.direction_keys = direction_keys async def recv_message(self) -> AsyncGenerator[Optional[RawHeaderBody], None]: - # assert not self.transport._redis.closed + assert self.transport._redis is not None + assert self.addr.group is not None + assert self.addr.consumer is not None if not self.direction_keys: stream_key = self.addr.stream_key else: stream_key = f"{self.addr.stream_key}.{self.direction_keys[0]}" - # _s = asyncio.shield - def _s(x): - return x - - async def _xack(raw_msg): - await self.transport._redis.xack(raw_msg[0], self.addr.group, raw_msg[1]) - try: - raw_msgs = await _s( - self.transport._redis.xread_group( + per_key_fetch_list: list[Any] = [] + while not per_key_fetch_list: + per_key_fetch_list = await self.transport._redis.xreadgroup( self.addr.group, self.addr.consumer, - [stream_key], - latest_ids=[">"], + {stream_key: ">"}, + block=1000, ) - ) - for raw_msg in raw_msgs: - # [0]: stream key, [1]: item ID - if b"meta" in raw_msg[2]: - await _s(_xack(raw_msg)) + for fetch_info in per_key_fetch_list: + if fetch_info[0].decode() != stream_key: continue - yield RawHeaderBody(raw_msg[2][b"hdr"], raw_msg[2][b"msg"], None) - await _s(_xack(raw_msg)) + for item in fetch_info[1]: + item_id: bytes = item[0] + item_data: dict[bytes, bytes] = item[1] + try: + if b"meta" in item_data: + continue + yield RawHeaderBody( + item_data[b"hdr"], + item_data[b"msg"], + None, + ) + finally: + await self.transport._redis.xack( + stream_key, self.addr.group, item_id + ) except asyncio.CancelledError: raise - except aioredis.errors.ConnectionForcedCloseError: + except redis.asyncio.ConnectionError: yield None async def send_message(self, raw_msg: RawHeaderBody) -> None: - # assert not self.transport._redis.closed + assert self.transport._redis is not None if not self.direction_keys: stream_key = self.addr.stream_key else: stream_key = f"{self.addr.stream_key}.{self.direction_keys[1]}" - - # _s = asyncio.shield - def _s(x): - return x - - await _s( - self.transport._redis.xadd( - stream_key, {b"hdr": raw_msg[0], b"msg": raw_msg[1]} - ) + await self.transport._redis.xadd( + stream_key, {b"hdr": raw_msg[0], b"msg": raw_msg[1]} ) @@ -109,14 +110,15 @@ class DispatchRedisBinder(AbstractBinder): who read messages from the stream (Consumers). """ - __slots__ = ("transport", "addr") + __slots__ = ("transport", "addr", "_addr_url") transport: DispatchRedisTransport addr: RedisStreamAddress async def __aenter__(self): - self.transport._redis = await aioredis.create_redis( - self.addr.redis_server, **self.transport._redis_opts + self._addr_url = redis_addr_to_url(self.addr.redis_server) + self.transport._redis = await redis.asyncio.from_url( + self._addr_url, **self.transport._redis_opts ) key = self.addr.stream_key # If there were no stream with the specified key before, @@ -143,31 +145,35 @@ class DispatchRedisConnector(AbstractConnector): that each consumer from the group gets distinct set of messages. """ - __slots__ = ("transport", "addr") + __slots__ = ("transport", "addr", "_addr_url") transport: DispatchRedisTransport addr: RedisStreamAddress async def __aenter__(self): - pool = await aioredis.create_connection( - self.addr.redis_server, **self.transport._redis_opts + assert self.addr.group is not None + assert self.addr.consumer is not None + self._addr_url = redis_addr_to_url(self.addr.redis_server) + self.transport._redis = await redis.asyncio.from_url( + self._addr_url, **self.transport._redis_opts ) - self.transport._redis = aioredis.Redis(pool) key = self.addr.stream_key # If there were no stream with the specified key before, # it is created as a side effect of adding the message. await self.transport._redis.xadd(key, {b"meta": b"create-or-join-to-stream"}) groups = await self.transport._redis.xinfo_groups(key) - if not any(map(lambda g: g[b"name"] == self.addr.group.encode(), groups)): + if not any(map(lambda g: g["name"] == self.addr.group.encode(), groups)): await self.transport._redis.xgroup_create(key, self.addr.group) return DispatchRedisConnection(self.transport, self.addr) async def __aexit__(self, exc_type, exc_obj, exc_tb): + assert self.addr.group is not None + assert self.addr.consumer is not None # we need to create a new Redis connection for cleanup # because self.transport._redis gets corrupted upon # cancellation of Peer._recv_loop() task. - _redis = await aioredis.create_redis( - self.addr.redis_server, **self.transport._redis_opts + _redis = await redis.asyncio.from_url( + self._addr_url, **self.transport._redis_opts ) try: await asyncio.shield( @@ -176,8 +182,7 @@ async def __aexit__(self, exc_type, exc_obj, exc_tb): ) ) finally: - _redis.close() - await _redis.wait_closed() + await _redis.close() class DispatchRedisTransport(BaseTransport): @@ -192,7 +197,7 @@ class DispatchRedisTransport(BaseTransport): ) _redis_opts: Mapping[str, Any] - _redis: aioredis.RedisConnection + _redis: Optional[redis.asyncio.Redis] binder_cls = DispatchRedisBinder connector_cls = DispatchRedisConnector @@ -207,6 +212,5 @@ def __init__( self._redis = None async def close(self) -> None: - if self._redis is not None and not self._redis.closed: - self._redis.close() - await self._redis.wait_closed() + if self._redis is not None: + await self._redis.close() diff --git a/src/callosum/lower/redis_common.py b/src/callosum/lower/redis_common.py new file mode 100644 index 0000000..504640f --- /dev/null +++ b/src/callosum/lower/redis_common.py @@ -0,0 +1,18 @@ +import yarl + + +def redis_addr_to_url( + value: str | tuple[str, int], + *, + scheme: str = "redis", +) -> str: + match value: + case str(): + url = yarl.URL(value) + if url.scheme is None: + return str(yarl.URL(value).with_scheme(scheme)) + return value + case (host, port): + return f"{scheme}://{host}:{port}" + case _: + raise ValueError("unrecognized address format", value) diff --git a/src/callosum/lower/rpc_redis.py b/src/callosum/lower/rpc_redis.py index d5979b0..8824c81 100644 --- a/src/callosum/lower/rpc_redis.py +++ b/src/callosum/lower/rpc_redis.py @@ -3,8 +3,9 @@ import asyncio from typing import Any, AsyncGenerator, Mapping, Optional, Tuple, Union -import aioredis import attrs +import redis +import redis.asyncio from ..abc import RawHeaderBody from . import ( @@ -14,6 +15,7 @@ AbstractConnector, BaseTransport, ) +from .redis_common import redis_addr_to_url @attrs.define(auto_attribs=True, slots=True) @@ -42,55 +44,54 @@ def __init__( self.direction_keys = direction_keys async def recv_message(self) -> AsyncGenerator[Optional[RawHeaderBody], None]: - # assert not self.transport._redis.closed + assert self.transport._redis is not None + assert self.addr.group is not None + assert self.addr.consumer is not None if not self.direction_keys: stream_key = self.addr.stream_key else: stream_key = f"{self.addr.stream_key}.{self.direction_keys[0]}" - # _s = asyncio.shield - def _s(x): - return x - - async def _xack(raw_msg): - await self.transport._redis.xack(raw_msg[0], self.addr.group, raw_msg[1]) - try: - raw_msgs = await _s( - self.transport._redis.xread_group( + per_key_fetch_list: list[Any] = [] + while not per_key_fetch_list: + per_key_fetch_list = await self.transport._redis.xreadgroup( self.addr.group, self.addr.consumer, - [stream_key], - latest_ids=[">"], + {stream_key: ">"}, + block=1000, ) - ) - for raw_msg in raw_msgs: - # [0]: stream key, [1]: item ID - if b"meta" in raw_msg[2]: - await _s(_xack(raw_msg)) + for fetch_info in per_key_fetch_list: + if fetch_info[0].decode() != stream_key: continue - yield RawHeaderBody(raw_msg[2][b"hdr"], raw_msg[2][b"msg"], None) - await _s(_xack(raw_msg)) + for item in fetch_info[1]: + item_id: bytes = item[0] + item_data: dict[bytes, bytes] = item[1] + try: + if b"meta" in item_data: + continue + yield RawHeaderBody( + item_data[b"hdr"], + item_data[b"msg"], + None, + ) + finally: + await self.transport._redis.xack( + stream_key, self.addr.group, item_id + ) except asyncio.CancelledError: raise - except aioredis.errors.ConnectionForcedCloseError: + except redis.asyncio.ConnectionError: yield None async def send_message(self, raw_msg: RawHeaderBody) -> None: - # assert not self.transport._redis.closed + assert self.transport._redis is not None if not self.direction_keys: stream_key = self.addr.stream_key else: stream_key = f"{self.addr.stream_key}.{self.direction_keys[1]}" - - # _s = asyncio.shield - def _s(x): - return x - - await _s( - self.transport._redis.xadd( - stream_key, {b"hdr": raw_msg[0], b"msg": raw_msg[1]} - ) + await self.transport._redis.xadd( + stream_key, {b"hdr": raw_msg[0], b"msg": raw_msg[1]} ) @@ -104,30 +105,39 @@ class RPCRedisBinder(AbstractBinder): the connection. """ - __slots__ = ("transport", "addr") + __slots__ = ("transport", "addr", "_addr_url") transport: RPCRedisTransport addr: RedisStreamAddress + _addr_url: str async def __aenter__(self): - self.transport._redis = await aioredis.create_redis( - self.addr.redis_server, **self.transport._redis_opts + assert self.addr.group is not None + assert self.addr.consumer is not None + self._addr_url = redis_addr_to_url(self.addr.redis_server) + self.transport._redis = await redis.asyncio.from_url( + self._addr_url, **self.transport._redis_opts ) key = f"{self.addr.stream_key}.bind" await self.transport._redis.xadd(key, {b"meta": b"create-stream"}) groups = await self.transport._redis.xinfo_groups(key) - if not any(map(lambda g: g[b"name"] == self.addr.group.encode(), groups)): + print(groups) + if not any(map(lambda g: g["name"] == self.addr.group.encode(), groups)): await self.transport._redis.xgroup_create( - key, self.addr.group - ) # TODO: mkstream=True in future aioredis + key, + self.addr.group, + mkstream=True, + ) return RPCRedisConnection(self.transport, self.addr, ("bind", "conn")) async def __aexit__(self, exc_type, exc_obj, exc_tb): + assert self.addr.group is not None + assert self.addr.consumer is not None # we need to create a new Redis connection for cleanup # because self.transport._redis gets corrupted upon # cancellation of Peer._recv_loop() task. - _redis = await aioredis.create_redis( - self.addr.redis_server, **self.transport._redis_opts + _redis = await redis.asyncio.from_url( + self._addr_url, **self.transport._redis_opts ) try: await asyncio.shield( @@ -138,8 +148,7 @@ async def __aexit__(self, exc_type, exc_obj, exc_tb): ) ) finally: - _redis.close() - await _redis.wait_closed() + await _redis.close() class RPCRedisConnector(AbstractConnector): @@ -152,31 +161,37 @@ class RPCRedisConnector(AbstractConnector): the connection. """ - __slots__ = ("transport", "addr") + __slots__ = ("transport", "addr", "_addr_url") transport: RPCRedisTransport addr: RedisStreamAddress async def __aenter__(self): - pool = await aioredis.create_connection( - self.addr.redis_server, **self.transport._redis_opts + assert self.addr.group is not None + assert self.addr.consumer is not None + self._addr_url = redis_addr_to_url(self.addr.redis_server) + self.transport._redis = redis.asyncio.from_url( + self._addr_url, **self.transport._redis_opts ) - self.transport._redis = aioredis.Redis(pool) key = f"{self.addr.stream_key}.conn" await self.transport._redis.xadd(key, {b"meta": b"create-stream"}) groups = await self.transport._redis.xinfo_groups(key) - if not any(map(lambda g: g[b"name"] == self.addr.group.encode(), groups)): + if not any(map(lambda g: g["name"] == self.addr.group.encode(), groups)): await self.transport._redis.xgroup_create( - key, self.addr.group - ) # TODO: mkstream=True in future aioredis + key, + self.addr.group, + mkstream=True, + ) return RPCRedisConnection(self.transport, self.addr, ("conn", "bind")) async def __aexit__(self, exc_type, exc_obj, exc_tb): + assert self.addr.group is not None + assert self.addr.consumer is not None # we need to create a new Redis connection for cleanup # because self.transport._redis gets corrupted upon # cancellation of Peer._recv_loop() task. - _redis = await aioredis.create_redis( - self.addr.redis_server, **self.transport._redis_opts + _redis = await redis.asyncio.from_url( + self._addr_url, **self.transport._redis_opts ) try: await asyncio.shield( @@ -187,8 +202,7 @@ async def __aexit__(self, exc_type, exc_obj, exc_tb): ) ) finally: - _redis.close() - await _redis.wait_closed() + await _redis.close() class RPCRedisTransport(BaseTransport): @@ -203,7 +217,7 @@ class RPCRedisTransport(BaseTransport): ) _redis_opts: Mapping[str, Any] - _redis: aioredis.RedisConnection + _redis: Optional[redis.asyncio.Redis] binder_cls = RPCRedisBinder connector_cls = RPCRedisConnector @@ -218,6 +232,5 @@ def __init__( self._redis = None async def close(self) -> None: - if self._redis is not None and not self._redis.closed: - self._redis.close() - await self._redis.wait_closed() + if self._redis is not None: + await self._redis.close() diff --git a/src/callosum/rpc/channel.py b/src/callosum/rpc/channel.py index f8c9a5e..5d2fbfd 100644 --- a/src/callosum/rpc/channel.py +++ b/src/callosum/rpc/channel.py @@ -18,7 +18,6 @@ import attrs from aiotools import aclosing -from async_timeout import timeout from ..abc import ( AbstractChannel, @@ -327,7 +326,7 @@ async def invoke( try: request: RPCMessage server_cancelled = False - with timeout(invoke_timeout): + async with asyncio.timeout(invoke_timeout): if callable(body): # The user is using an upper-layer adaptor. async with aclosing(body()) as agen: diff --git a/src/callosum/upper/thrift.py b/src/callosum/upper/thrift.py index 9ae23f5..9867c25 100644 --- a/src/callosum/upper/thrift.py +++ b/src/callosum/upper/thrift.py @@ -3,7 +3,6 @@ import logging from typing import Any, Optional, Sequence -import async_timeout from thriftpy2.contrib.aio.processor import TAsyncProcessor from thriftpy2.contrib.aio.protocol.binary import TAsyncBinaryProtocol from thriftpy2.thrift import TApplicationException, TMessageType, args_to_kwargs @@ -37,7 +36,7 @@ async def handle_function(self, request: RPCMessage) -> bytes: iproto = self._protocol_cls(reader_trans) oproto = self._protocol_cls(writer_trans) try: - with async_timeout.timeout(self._exec_timeout): + async with asyncio.timeout(self._exec_timeout): await self._processor.process(iproto, oproto) except (asyncio.IncompleteReadError, ConnectionError): logger.debug("client has closed the connection")