Skip to content

Commit

Permalink
Don't break internal loops on CancelledError when the cancel is not t…
Browse files Browse the repository at this point in the history
…riggered internally (avoid CTRL-C issue on python < 3.11)
  • Loading branch information
diorcety committed Dec 15, 2024
1 parent 7e7883e commit af1be5e
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 28 deletions.
66 changes: 41 additions & 25 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def __init__(self) -> None:
# New style request/response
self._resp_map: Dict[str, asyncio.Future] = {}
self._resp_sub_prefix: Optional[bytearray] = None
self._sub_prefix_subscription: Optional[Subscription] = None
self._nuid = NUID()
self._inbox_prefix = bytearray(DEFAULT_INBOX_PREFIX)
self._auth_configured: bool = False
Expand Down Expand Up @@ -680,11 +681,17 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
if self.is_closed:
self._status = status
return
self._status = Client.CLOSED

if self._sub_prefix_subscription is not None:
subscription = self._sub_prefix_subscription
self._sub_prefix_subscription = None
await subscription.unsubscribe()

# Kick the flusher once again so that Task breaks and avoid pending futures.
await self._flush_pending()

self._status = Client.CLOSED

if self._reading_task is not None and not self._reading_task.cancelled(
):
self._reading_task.cancel()
Expand Down Expand Up @@ -726,11 +733,7 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
# Cleanup subscriptions since not reconnecting so no need
# to replay the subscriptions anymore.
for sub in self._subs.values():
# Async subs use join when draining already so just cancel here.
if sub._wait_for_msgs_task and not sub._wait_for_msgs_task.done():
sub._wait_for_msgs_task.cancel()
if sub._message_iterator:
sub._message_iterator._cancel()
sub._stop_processing()
# Sync subs may have some inflight next_msg calls that could be blocking
# so cancel them here to unblock them.
if sub._pending_next_msgs_calls:
Expand Down Expand Up @@ -985,7 +988,7 @@ async def _init_request_sub(self) -> None:
self._resp_sub_prefix.extend(b".")
resp_mux_subject = self._resp_sub_prefix[:]
resp_mux_subject.extend(b"*")
await self.subscribe(
self._sub_prefix_subscription = await self.subscribe(
resp_mux_subject.decode(), cb=self._request_sub_callback
)

Expand Down Expand Up @@ -2068,23 +2071,28 @@ async def _flusher(self) -> None:
if not self.is_connected or self.is_connecting:
break

future: asyncio.Future = await self._flush_queue.get()

try:
if self._pending_data_size > 0:
self._transport.writelines(self._pending[:])
self._pending = []
self._pending_data_size = 0
await self._transport.drain()
except OSError as e:
await self._error_cb(e)
await self._process_op_err(e)
break
except (asyncio.CancelledError, RuntimeError, AttributeError):
# RuntimeError in case the event loop is closed
break
finally:
future.set_result(None)
future: asyncio.Future = await self._flush_queue.get()
try:
if self._pending_data_size > 0:
self._transport.writelines(self._pending[:])
self._pending = []
self._pending_data_size = 0
await self._transport.drain()
except OSError as e:
await self._error_cb(e)
await self._process_op_err(e)
break
except (RuntimeError, AttributeError):
# RuntimeError in case the event loop is closed
break
finally:
future.set_result(None)
except asyncio.CancelledError:
if self._status == Client.CLOSED:
break
else:
continue

async def _ping_interval(self) -> None:
while True:
Expand All @@ -2098,8 +2106,13 @@ async def _ping_interval(self) -> None:
await self._process_op_err(ErrStaleConnection())
return
await self._send_ping()
except (asyncio.CancelledError, RuntimeError, AttributeError):
except (RuntimeError, AttributeError):
break
except asyncio.CancelledError:
if self._status == Client.CLOSED:
break
else:
continue
# except asyncio.InvalidStateError:
# pass

Expand Down Expand Up @@ -2130,7 +2143,10 @@ async def _read_loop(self) -> None:
await self._process_op_err(e)
break
except asyncio.CancelledError:
break
if self._status == Client.CLOSED:
break
else:
continue
except Exception as ex:
_logger.error("nats: encountered error", exc_info=ex)
break
Expand Down
7 changes: 5 additions & 2 deletions nats/aio/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ async def unsubscribe(self, limit: int = 0):
self._max_msgs = limit
if limit == 0 or (self._received >= limit
and self._pending_queue.empty()):
self._closed = True
self._stop_processing()
self._conn._remove_sub(self._id)

Expand All @@ -295,6 +294,7 @@ def _stop_processing(self) -> None:
"""
Stops the subscription from processing new messages.
"""
self._closed = True
if self._wait_for_msgs_task and not self._wait_for_msgs_task.done():
self._wait_for_msgs_task.cancel()
if self._message_iterator:
Expand Down Expand Up @@ -333,7 +333,10 @@ async def _wait_for_msgs(self, error_cb) -> None:
and self._pending_queue.empty):
self._stop_processing()
except asyncio.CancelledError:
break
if self._closed:
break
else:
continue


class _SubscriptionMessageIterator:
Expand Down
9 changes: 8 additions & 1 deletion nats/js/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,12 +968,19 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
self._sub._jsi._fcr = None
return msg

async def drain(self):
await self._sub.drain()
self._closed = self._sub._closed

async def unsubscribe(self, limit: int = 0):
"""
Unsubscribes from a subscription, canceling any heartbeat and flow control tasks,
and optionally limits the number of messages to process before unsubscribing.
Nothing is really subscribed from this object, call unsubscribe on underlying sub
and forward _closed flag.
"""
await super().unsubscribe(limit)
await self._sub.unsubscribe(limit)
self._closed = self._sub._closed

if self._sub._jsi._hbtask:
self._sub._jsi._hbtask.cancel()
Expand Down
14 changes: 14 additions & 0 deletions tests/test_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ async def f():
await task
assert received

await nc.close()

@async_test
async def test_add_pull_consumer_via_jsm(self):
nc = NATS()
Expand All @@ -339,6 +341,8 @@ async def test_add_pull_consumer_via_jsm(self):
info = await js.consumer_info("events", "a")
assert 0 == info.num_pending

await nc.close()

@async_long_test
async def test_fetch_n(self):
nc = NATS()
Expand Down Expand Up @@ -852,6 +856,7 @@ async def test_ephemeral_pull_subscribe(self):
cinfo = await sub.consumer_info()
self.assertTrue(cinfo.config.name != None)
self.assertTrue(cinfo.config.durable_name == None)

await nc.close()

@async_test
Expand Down Expand Up @@ -896,6 +901,8 @@ async def test_consumer_with_multiple_filters(self):
ok = await msgs[0].ack_sync()
assert ok

await nc.close()

@async_long_test
async def test_add_consumer_with_backoff(self):
nc = NATS()
Expand Down Expand Up @@ -953,6 +960,7 @@ async def cb(msg):

# Confirm possible to unmarshal the consumer config.
assert info.config.backoff == [1, 2]

await nc.close()

@async_long_test
Expand Down Expand Up @@ -1495,6 +1503,8 @@ async def test_jsm_stream_info_options(self):
assert si.state.messages == 5
assert si.state.subjects == None

await nc.close()


class SubscribeTest(SingleJetStreamServerTestCase):

Expand Down Expand Up @@ -1657,6 +1667,8 @@ async def test_ephemeral_subscribe(self):
assert len(info2.name) > 0
assert info1.name != info2.name

await nc.close()

@async_test
async def test_subscribe_bind(self):
nc = await nats.connect()
Expand Down Expand Up @@ -1702,6 +1714,8 @@ async def test_subscribe_bind(self):
assert info.num_ack_pending == 0
assert info.num_pending == 0

await nc.close()

@async_test
async def test_subscribe_custom_limits(self):
errors = []
Expand Down

0 comments on commit af1be5e

Please sign in to comment.