Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Oct 30, 2024
1 parent 7c4545b commit b797616
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 144 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ classifiers = [
]
dependencies = [
"cffi; implementation_name == 'pypy'",
"anyioutils >=0.4.1"
"anyioutils >=0.4.2"
]
description = "Python bindings for 0MQ"
readme = "README.md"
Expand Down
188 changes: 95 additions & 93 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import zmq
import zmq.asyncio as zaio
from anyio import create_task_group, move_on_after, sleep
from anyioutils import CancelledError, create_task

pytestmark = pytest.mark.anyio

Expand Down Expand Up @@ -104,40 +106,34 @@ async def test_recv_json(push_pull):


async def test_recv_json_cancelled(push_pull):
a, b = push_pull
f = b.recv_json()
f.cancel()
# cycle eventloop to allow cancel events to fire
await asyncio.sleep(0)
obj = dict(a=5)
await a.send_json(obj)
# CancelledError change in 3.8 https://bugs.python.org/issue32528
if sys.version_info < (3, 8):
with pytest.raises(CancelledError):
recvd = await f
else:
with pytest.raises(asyncio.exceptions.CancelledError):
async with create_task_group() as tg:
a, b = push_pull
f = create_task(b.recv_json(), tg)
f.cancel(raise_exception=False)
# cycle eventloop to allow cancel events to fire
await sleep(0)
obj = dict(a=5)
await a.send_json(obj)
recvd = await f.wait()
assert f.cancelled()
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
with move_on_after(5):
recvd = await f
assert f.done()
# give it a chance to incorrectly consume the event
events = await b.poll(timeout=5)
assert events
await asyncio.sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == obj
assert recvd == obj


async def test_recv_pyobj(push_pull):
a, b = push_pull
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
await a.send_pyobj(obj)
recvd = await f
assert f.done()
assert f.result() == obj
assert recvd == obj


Expand Down Expand Up @@ -196,85 +192,90 @@ async def test_custom_serialize_error(dealer_router):
async def test_recv_dontwait(push_pull):
push, pull = push_pull
f = pull.recv(zmq.DONTWAIT)
with pytest.raises(zmq.Again):
with pytest.raises(BaseExceptionGroup) as excinfo:
await f
assert excinfo.group_contains(zmq.Again)
await push.send(b"ping")
await pull.poll() # ensure message will be waiting
f = pull.recv(zmq.DONTWAIT)
assert f.done()
msg = await f
msg = await pull.recv(zmq.DONTWAIT)
assert msg == b"ping"


async def test_recv_cancel(push_pull):
a, b = push_pull
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]
async with create_task_group() as tg:
a, b = push_pull
f1 = create_task(b.recv(), tg)
f2 = create_task(b.recv_multipart(), tg)
f1.cancel(raise_exception=False)
assert f1.done()
assert not f2.done()
await a.send_multipart([b"hi", b"there"])
recvd = await f2.wait()
assert f1.cancelled()
assert f2.done()
assert recvd == [b"hi", b"there"]


async def test_poll(push_pull):
a, b = push_pull
f = b.poll(timeout=0)
await asyncio.sleep(0)
assert f.result() == 0
async with create_task_group() as tg:
a, b = push_pull
f = create_task(b.poll(timeout=0), tg)
await sleep(0.01)
assert f.result() == 0

f = b.poll(timeout=1)
assert not f.done()
evt = await f
f = create_task(b.poll(timeout=1), tg)
assert not f.done()
evt = await f.wait()

assert evt == 0
assert evt == 0

f = b.poll(timeout=1000)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]
f = create_task(b.poll(timeout=1000), tg)
assert not f.done()
await a.send_multipart([b"hi", b"there"])
evt = await f.wait()
assert evt == zmq.POLLIN
recvd = await b.recv_multipart()
assert recvd == [b"hi", b"there"]


async def test_poll_base_socket(sockets):
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
sockets.extend([a, b])
a.bind(url)
b.connect(url)

poller = zaio.Poller()
poller.register(b, zmq.POLLIN)

f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]
async with create_task_group() as tg:
ctx = zmq.Context()
url = "inproc://test"
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
sockets.extend([a, b])
a.bind(url)
b.connect(url)

poller = zaio.Poller()
poller.register(b, zmq.POLLIN)

f = create_task(poller.poll(timeout=1000), tg)
assert not f.done()
a.send_multipart([b"hi", b"there"])
evt = await f.wait()
assert evt == [(b, zmq.POLLIN)]
recvd = b.recv_multipart()
assert recvd == [b"hi", b"there"]


async def test_poll_on_closed_socket(push_pull):
a, b = push_pull
with pytest.raises(BaseExceptionGroup) as excinfo:
async with create_task_group() as tg:
a, b = push_pull

f = b.poll(timeout=1)
b.close()
f = create_task(b.poll(timeout=1), tg)
b.close()

# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await asyncio.sleep(0)
if f.cancelled():
break
assert f.cancelled()
# The test might stall if we try to await f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
await sleep(0)
if f.cancelled():
break
assert f.done()
assert excinfo.group_contains(zmq.error.ZMQError)


@pytest.mark.skipif(
Expand Down Expand Up @@ -334,16 +335,17 @@ def test_shadow():


async def test_poll_leak():
ctx = zmq.asyncio.Context()
with ctx, ctx.socket(zmq.PULL) as s:
assert len(s._recv_futures) == 0
for i in range(10):
f = asyncio.ensure_future(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN))
f.cancel()
await asyncio.sleep(0)
# one more sleep allows further chained cleanup
await asyncio.sleep(0.1)
assert len(s._recv_futures) == 0
async with create_task_group() as tg:
ctx = zmq.asyncio.Context()
with ctx, ctx.socket(zmq.PULL) as s:
assert len(s._recv_futures) == 0
for i in range(10):
f = create_task(s.poll(timeout=1000, flags=zmq.PollEvent.POLLIN), tg)
f.cancel(raise_exception=False)
await sleep(0)
# one more sleep allows further chained cleanup
await sleep(0.1)
assert len(s._recv_futures) == 0


class ProcessForTeardownTest(Process):
Expand Down
85 changes: 43 additions & 42 deletions zmq/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)

from anyio import create_task_group, sleep
from anyioutils import Future, create_task
from anyioutils import Future, Task, create_task

import zmq as _zmq
from zmq import EVENTS, POLLIN, POLLOUT
Expand Down Expand Up @@ -105,9 +105,9 @@ def _clear_wrapper_io(f):
socket = self._socket_class.from_socket(socket)
wrapped_sockets.append(socket)
if mask & _zmq.POLLIN:
socket._add_recv_event(tg, 'poll', future=watcher)
create_task(socket._add_recv_event(tg, 'poll', future=watcher), tg)
if mask & _zmq.POLLOUT:
socket._add_send_event(tg, 'poll', future=watcher)
create_task(socket._add_send_event(tg, 'poll', future=watcher), tg)
else:
raw_sockets.append(socket)
evt = 0
Expand All @@ -122,7 +122,7 @@ def on_poll_ready(f):
return
if watcher.cancelled():
try:
future.cancel()
future.cancel(raise_exception=False)
except RuntimeError:
# RuntimeError may be called during teardown
pass
Expand Down Expand Up @@ -152,17 +152,17 @@ async def trigger_timeout():
timeout_handle = create_task(trigger_timeout(), tg)

def cancel_timeout(f):
timeout_handle.cancel()
timeout_handle.cancel(raise_exception=False)

future.add_done_callback(cancel_timeout)

def cancel_watcher(f):
if not watcher.done():
watcher.cancel()
watcher.cancel(raise_exception=False)

future.add_done_callback(cancel_watcher)

return future.wait()
return await future.wait()


class _NoTimer:
Expand Down Expand Up @@ -226,7 +226,7 @@ def close(self, linger: int | None = None) -> None:
for event in event_list:
if not event.future.done():
try:
event.future.cancel()
event.future.cancel(raise_exception=False)
except RuntimeError:
# RuntimeError may be called during teardown
pass
Expand Down Expand Up @@ -359,46 +359,47 @@ async def poll(self, timeout=None, flags=_zmq.POLLIN) -> int: # type: ignore
if self.closed:
raise _zmq.ZMQError(_zmq.ENOTSUP)

p = self._poller_class()
p.register(self, flags)
poll_future = cast(Future, p.poll(timeout))
async with create_task_group() as tg:
p = self._poller_class()
p.register(self, flags)
poll_future = cast(Task, create_task(p.poll(timeout), tg))

future = self._Future()
future = self._Future()

def unwrap_result(f):
if future.done():
return
if poll_future.cancelled():
try:
future.cancel()
except RuntimeError:
# RuntimeError may be called during teardown
pass
return
if f.exception():
future.set_exception(poll_future.exception())
else:
evts = dict(poll_future.result())
future.set_result(evts.get(self, 0))
def unwrap_result(f):
if future.done():
return
if poll_future.cancelled():
try:
future.cancel()
except RuntimeError:
# RuntimeError may be called during teardown
pass
return
if f.exception():
future.set_exception(poll_future.exception())
else:
evts = dict(poll_future.result())
future.set_result(evts.get(self, 0))

if poll_future.done():
# hook up result if already done
unwrap_result(poll_future)
else:
poll_future.add_done_callback(unwrap_result)
if poll_future.done():
# hook up result if already done
unwrap_result(poll_future)
else:
poll_future.add_done_callback(unwrap_result)

def cancel_poll(future):
"""Cancel underlying poll if request has been cancelled"""
if not poll_future.done():
try:
poll_future.cancel()
except RuntimeError:
# RuntimeError may be called during teardown
pass
def cancel_poll(future):
"""Cancel underlying poll if request has been cancelled"""
if not poll_future.done():
try:
poll_future.cancel()
except RuntimeError:
# RuntimeError may be called during teardown
pass

future.add_done_callback(cancel_poll)
future.add_done_callback(cancel_poll)

return await future.wait()
return await future.wait()

def _add_timeout(self, task_group, future, timeout):
"""Add a timeout for a send or recv Future"""
Expand Down
Loading

0 comments on commit b797616

Please sign in to comment.