Skip to content

Commit

Permalink
setup: Upgrade thriftpy2 and replace aioredis with redis-py (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Sep 6, 2023
1 parent 277a0af commit c54ca51
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 129 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ repos:
- id: mypy
additional_dependencies: [
'types-python-dateutil',
'types-redis',
'attrs>=21.3',
]
11 changes: 4 additions & 7 deletions examples/simple-client-thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
import sys
import textwrap

from async_timeout import timeout

from callosum.rpc import Peer, RPCUserError
from callosum.serialize import noop_serializer, noop_deserializer
import thriftpy2 as thriftpy
from callosum.lower.zeromq import ZeroMQAddress, ZeroMQRPCTransport
from callosum.rpc import Peer, RPCUserError
from callosum.serialize import noop_deserializer, noop_serializer
from callosum.upper.thrift import ThriftClientAdaptor
import thriftpy2 as thriftpy


simple_thrift = thriftpy.load(
str(pathlib.Path(__file__).parent / "simple.thrift"), module_name="simple_thrift"
Expand Down Expand Up @@ -47,7 +44,7 @@ async def call() -> None:
print(textwrap.indent(e.traceback, prefix="| "))

try:
with timeout(0.5):
async with asyncio.timeout(0.5):
await peer.invoke("simple", adaptor.long_delay())
except asyncio.TimeoutError:
print(
Expand Down
8 changes: 3 additions & 5 deletions examples/simple-client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import asyncio
import json
import random
import os
import random
import secrets
import sys
import textwrap
import traceback

from async_timeout import timeout

from callosum.rpc import Peer, RPCUserError
from callosum.lower.zeromq import ZeroMQAddress, ZeroMQRPCTransport
from callosum.rpc import Peer, RPCUserError


async def test_simple(peer, initial_delay: float = 0):
Expand Down Expand Up @@ -41,7 +39,7 @@ async def test_simple(peer, initial_delay: float = 0):

async def test_timeout(peer):
try:
with timeout(0.5):
async with asyncio.timeout(0.5):
await peer.invoke(
"long_delay",
{
Expand Down
18 changes: 8 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def read_src_version():


install_requires = [
"aiotools>=1.5.9",
"async_timeout>=3.0.1",
"aiotools>=1.7.0",
"attrs>=21.3.0",
"python-dateutil>=2.8.2",
"msgpack>=1.0.4",
"temporenc>=0.1",
"yarl>=1.8.2",
]

build_requires = [
Expand All @@ -45,13 +45,13 @@ def read_src_version():
]

lint_requires = [
"flake8>=5.0.4",
"isort>=5.10.1",
"black>=22.10.0",
"black>=23.7.0",
"ruff>=0.0.287",
"ruff-lsp",
]

typecheck_requires = [
"mypy>=0.991",
"mypy>=1.5.1",
"types-python-dateutil",
]

Expand All @@ -68,9 +68,7 @@ def read_src_version():
"pyzmq>=23.0.0",
]

redis_requires = [
"aioredis>=1.3.0,<2.0",
]
redis_requires = ["redis>=4.6.0"]

snappy_requires = [
"python-snappy>=0.6.1",
Expand Down Expand Up @@ -113,7 +111,7 @@ def read_src_version():
"callosum": ["py.typed"],
},
include_package_data=True,
python_requires=">=3.8",
python_requires=">=3.11",
setup_requires=["setuptools>=61.0"],
install_requires=install_requires,
extras_require={
Expand Down
98 changes: 51 additions & 47 deletions src/callosum/lower/dispatch_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import asyncio
from typing import Any, AsyncGenerator, Mapping, Optional, Tuple, Union

import aioredis
import attrs
import redis
import redis.asyncio

from ..abc import RawHeaderBody
from ..exceptions import InvalidAddressError
Expand All @@ -15,6 +16,7 @@
AbstractConnector,
BaseTransport,
)
from .redis_common import redis_addr_to_url


@attrs.define(auto_attribs=True, slots=True)
Expand Down Expand Up @@ -48,55 +50,54 @@ def __init__(
self.direction_keys = direction_keys

async def recv_message(self) -> AsyncGenerator[Optional[RawHeaderBody], None]:
# assert not self.transport._redis.closed
assert self.transport._redis is not None
assert self.addr.group is not None
assert self.addr.consumer is not None
if not self.direction_keys:
stream_key = self.addr.stream_key
else:
stream_key = f"{self.addr.stream_key}.{self.direction_keys[0]}"

# _s = asyncio.shield
def _s(x):
return x

async def _xack(raw_msg):
await self.transport._redis.xack(raw_msg[0], self.addr.group, raw_msg[1])

try:
raw_msgs = await _s(
self.transport._redis.xread_group(
per_key_fetch_list: list[Any] = []
while not per_key_fetch_list:
per_key_fetch_list = await self.transport._redis.xreadgroup(
self.addr.group,
self.addr.consumer,
[stream_key],
latest_ids=[">"],
{stream_key: ">"},
block=1000,
)
)
for raw_msg in raw_msgs:
# [0]: stream key, [1]: item ID
if b"meta" in raw_msg[2]:
await _s(_xack(raw_msg))
for fetch_info in per_key_fetch_list:
if fetch_info[0].decode() != stream_key:
continue
yield RawHeaderBody(raw_msg[2][b"hdr"], raw_msg[2][b"msg"], None)
await _s(_xack(raw_msg))
for item in fetch_info[1]:
item_id: bytes = item[0]
item_data: dict[bytes, bytes] = item[1]
try:
if b"meta" in item_data:
continue
yield RawHeaderBody(
item_data[b"hdr"],
item_data[b"msg"],
None,
)
finally:
await self.transport._redis.xack(
stream_key, self.addr.group, item_id
)
except asyncio.CancelledError:
raise
except aioredis.errors.ConnectionForcedCloseError:
except redis.asyncio.ConnectionError:
yield None

async def send_message(self, raw_msg: RawHeaderBody) -> None:
# assert not self.transport._redis.closed
assert self.transport._redis is not None
if not self.direction_keys:
stream_key = self.addr.stream_key
else:
stream_key = f"{self.addr.stream_key}.{self.direction_keys[1]}"

# _s = asyncio.shield
def _s(x):
return x

await _s(
self.transport._redis.xadd(
stream_key, {b"hdr": raw_msg[0], b"msg": raw_msg[1]}
)
await self.transport._redis.xadd(
stream_key, {b"hdr": raw_msg[0], b"msg": raw_msg[1]}
)


Expand All @@ -109,14 +110,15 @@ class DispatchRedisBinder(AbstractBinder):
who read messages from the stream (Consumers).
"""

__slots__ = ("transport", "addr")
__slots__ = ("transport", "addr", "_addr_url")

transport: DispatchRedisTransport
addr: RedisStreamAddress

async def __aenter__(self):
self.transport._redis = await aioredis.create_redis(
self.addr.redis_server, **self.transport._redis_opts
self._addr_url = redis_addr_to_url(self.addr.redis_server)
self.transport._redis = await redis.asyncio.from_url(
self._addr_url, **self.transport._redis_opts
)
key = self.addr.stream_key
# If there were no stream with the specified key before,
Expand All @@ -143,31 +145,35 @@ class DispatchRedisConnector(AbstractConnector):
that each consumer from the group gets distinct set of messages.
"""

__slots__ = ("transport", "addr")
__slots__ = ("transport", "addr", "_addr_url")

transport: DispatchRedisTransport
addr: RedisStreamAddress

async def __aenter__(self):
pool = await aioredis.create_connection(
self.addr.redis_server, **self.transport._redis_opts
assert self.addr.group is not None
assert self.addr.consumer is not None
self._addr_url = redis_addr_to_url(self.addr.redis_server)
self.transport._redis = await redis.asyncio.from_url(
self._addr_url, **self.transport._redis_opts
)
self.transport._redis = aioredis.Redis(pool)
key = self.addr.stream_key
# If there were no stream with the specified key before,
# it is created as a side effect of adding the message.
await self.transport._redis.xadd(key, {b"meta": b"create-or-join-to-stream"})
groups = await self.transport._redis.xinfo_groups(key)
if not any(map(lambda g: g[b"name"] == self.addr.group.encode(), groups)):
if not any(map(lambda g: g["name"] == self.addr.group.encode(), groups)):
await self.transport._redis.xgroup_create(key, self.addr.group)
return DispatchRedisConnection(self.transport, self.addr)

async def __aexit__(self, exc_type, exc_obj, exc_tb):
assert self.addr.group is not None
assert self.addr.consumer is not None
# we need to create a new Redis connection for cleanup
# because self.transport._redis gets corrupted upon
# cancellation of Peer._recv_loop() task.
_redis = await aioredis.create_redis(
self.addr.redis_server, **self.transport._redis_opts
_redis = await redis.asyncio.from_url(
self._addr_url, **self.transport._redis_opts
)
try:
await asyncio.shield(
Expand All @@ -176,8 +182,7 @@ async def __aexit__(self, exc_type, exc_obj, exc_tb):
)
)
finally:
_redis.close()
await _redis.wait_closed()
await _redis.close()


class DispatchRedisTransport(BaseTransport):
Expand All @@ -192,7 +197,7 @@ class DispatchRedisTransport(BaseTransport):
)

_redis_opts: Mapping[str, Any]
_redis: aioredis.RedisConnection
_redis: Optional[redis.asyncio.Redis]

binder_cls = DispatchRedisBinder
connector_cls = DispatchRedisConnector
Expand All @@ -207,6 +212,5 @@ def __init__(
self._redis = None

async def close(self) -> None:
if self._redis is not None and not self._redis.closed:
self._redis.close()
await self._redis.wait_closed()
if self._redis is not None:
await self._redis.close()
18 changes: 18 additions & 0 deletions src/callosum/lower/redis_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import yarl


def redis_addr_to_url(
value: str | tuple[str, int],
*,
scheme: str = "redis",
) -> str:
match value:
case str():
url = yarl.URL(value)
if url.scheme is None:
return str(yarl.URL(value).with_scheme(scheme))
return value
case (host, port):
return f"{scheme}://{host}:{port}"
case _:
raise ValueError("unrecognized address format", value)
Loading

0 comments on commit c54ca51

Please sign in to comment.