From cc2796837285fdb48f5b1ac99c11b58c4aadf2ad Mon Sep 17 00:00:00 2001 From: Diorcet Yann Date: Fri, 13 Dec 2024 13:02:55 +0100 Subject: [PATCH] Don't break internal loops on CancelledError when the cancel is not triggered internally (avoid CTRL-C issue on python < 3.11) --- nats/aio/client.py | 66 +++++++++++++++++++++++++--------------- nats/aio/subscription.py | 7 +++-- nats/js/client.py | 9 +++++- tests/test_js.py | 27 ++++++++++++++++ 4 files changed, 81 insertions(+), 28 deletions(-) diff --git a/nats/aio/client.py b/nats/aio/client.py index 81e65f50..968ec440 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -251,6 +251,7 @@ def __init__(self) -> None: # New style request/response self._resp_map: Dict[str, asyncio.Future] = {} self._resp_sub_prefix: Optional[bytearray] = None + self._sub_prefix_subscription: Optional[Subscription] = None self._nuid = NUID() self._inbox_prefix = bytearray(DEFAULT_INBOX_PREFIX) self._auth_configured: bool = False @@ -680,11 +681,17 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: if self.is_closed: self._status = status return - self._status = Client.CLOSED + + if self._sub_prefix_subscription is not None: + subscription = self._sub_prefix_subscription + self._sub_prefix_subscription = None + await subscription.unsubscribe() # Kick the flusher once again so that Task breaks and avoid pending futures. await self._flush_pending() + self._status = Client.CLOSED + if self._reading_task is not None and not self._reading_task.cancelled( ): self._reading_task.cancel() @@ -726,11 +733,7 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: # Cleanup subscriptions since not reconnecting so no need # to replay the subscriptions anymore. for sub in self._subs.values(): - # Async subs use join when draining already so just cancel here. - if sub._wait_for_msgs_task and not sub._wait_for_msgs_task.done(): - sub._wait_for_msgs_task.cancel() - if sub._message_iterator: - sub._message_iterator._cancel() + sub._stop_processing() # Sync subs may have some inflight next_msg calls that could be blocking # so cancel them here to unblock them. if sub._pending_next_msgs_calls: @@ -985,7 +988,7 @@ async def _init_request_sub(self) -> None: self._resp_sub_prefix.extend(b".") resp_mux_subject = self._resp_sub_prefix[:] resp_mux_subject.extend(b"*") - await self.subscribe( + self._sub_prefix_subscription = await self.subscribe( resp_mux_subject.decode(), cb=self._request_sub_callback ) @@ -2068,23 +2071,28 @@ async def _flusher(self) -> None: if not self.is_connected or self.is_connecting: break - future: asyncio.Future = await self._flush_queue.get() - try: - if self._pending_data_size > 0: - self._transport.writelines(self._pending[:]) - self._pending = [] - self._pending_data_size = 0 - await self._transport.drain() - except OSError as e: - await self._error_cb(e) - await self._process_op_err(e) - break - except (asyncio.CancelledError, RuntimeError, AttributeError): - # RuntimeError in case the event loop is closed - break - finally: - future.set_result(None) + future: asyncio.Future = await self._flush_queue.get() + try: + if self._pending_data_size > 0: + self._transport.writelines(self._pending[:]) + self._pending = [] + self._pending_data_size = 0 + await self._transport.drain() + except OSError as e: + await self._error_cb(e) + await self._process_op_err(e) + break + except (RuntimeError, AttributeError): + # RuntimeError in case the event loop is closed + break + finally: + future.set_result(None) + except asyncio.CancelledError: + if self._status == Client.CLOSED: + break + else: + continue async def _ping_interval(self) -> None: while True: @@ -2098,8 +2106,13 @@ async def _ping_interval(self) -> None: await self._process_op_err(ErrStaleConnection()) return await self._send_ping() - except (asyncio.CancelledError, RuntimeError, AttributeError): + except (RuntimeError, AttributeError): break + except asyncio.CancelledError: + if self._status == Client.CLOSED: + break + else: + continue # except asyncio.InvalidStateError: # pass @@ -2130,7 +2143,10 @@ async def _read_loop(self) -> None: await self._process_op_err(e) break except asyncio.CancelledError: - break + if self._status == Client.CLOSED: + break + else: + continue except Exception as ex: _logger.error("nats: encountered error", exc_info=ex) break diff --git a/nats/aio/subscription.py b/nats/aio/subscription.py index 31fbb887..ad348eab 100644 --- a/nats/aio/subscription.py +++ b/nats/aio/subscription.py @@ -284,7 +284,6 @@ async def unsubscribe(self, limit: int = 0): self._max_msgs = limit if limit == 0 or (self._received >= limit and self._pending_queue.empty()): - self._closed = True self._stop_processing() self._conn._remove_sub(self._id) @@ -295,6 +294,7 @@ def _stop_processing(self) -> None: """ Stops the subscription from processing new messages. """ + self._closed = True if self._wait_for_msgs_task and not self._wait_for_msgs_task.done(): self._wait_for_msgs_task.cancel() if self._message_iterator: @@ -333,7 +333,10 @@ async def _wait_for_msgs(self, error_cb) -> None: and self._pending_queue.empty): self._stop_processing() except asyncio.CancelledError: - break + if self._closed: + break + else: + continue class _SubscriptionMessageIterator: diff --git a/nats/js/client.py b/nats/js/client.py index d26413c0..9dc9eec9 100644 --- a/nats/js/client.py +++ b/nats/js/client.py @@ -968,12 +968,19 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg: self._sub._jsi._fcr = None return msg + async def drain(self): + await self._sub.drain() + self._closed = self._sub._closed + async def unsubscribe(self, limit: int = 0): """ Unsubscribes from a subscription, canceling any heartbeat and flow control tasks, and optionally limits the number of messages to process before unsubscribing. + Nothing is really subscribed from this object, call unsubscribe on underlying sub + and forward _closed flag. """ - await super().unsubscribe(limit) + await self._sub.unsubscribe(limit) + self._closed = self._sub._closed if self._sub._jsi._hbtask: self._sub._jsi._hbtask.cancel() diff --git a/tests/test_js.py b/tests/test_js.py index 85f8f556..2c0ae098 100644 --- a/tests/test_js.py +++ b/tests/test_js.py @@ -314,6 +314,8 @@ async def f(): await task assert received + await nc.close() + @async_test async def test_add_pull_consumer_via_jsm(self): nc = NATS() @@ -339,6 +341,8 @@ async def test_add_pull_consumer_via_jsm(self): info = await js.consumer_info("events", "a") assert 0 == info.num_pending + await nc.close() + @async_long_test async def test_fetch_n(self): nc = NATS() @@ -852,6 +856,7 @@ async def test_ephemeral_pull_subscribe(self): cinfo = await sub.consumer_info() self.assertTrue(cinfo.config.name != None) self.assertTrue(cinfo.config.durable_name == None) + await nc.close() @async_test @@ -896,6 +901,8 @@ async def test_consumer_with_multiple_filters(self): ok = await msgs[0].ack_sync() assert ok + await nc.close() + @async_long_test async def test_add_consumer_with_backoff(self): nc = NATS() @@ -953,6 +960,7 @@ async def cb(msg): # Confirm possible to unmarshal the consumer config. assert info.config.backoff == [1, 2] + await nc.close() @async_long_test @@ -1495,6 +1503,8 @@ async def test_jsm_stream_info_options(self): assert si.state.messages == 5 assert si.state.subjects == None + await nc.close() + class SubscribeTest(SingleJetStreamServerTestCase): @@ -1657,6 +1667,8 @@ async def test_ephemeral_subscribe(self): assert len(info2.name) > 0 assert info1.name != info2.name + await nc.close() + @async_test async def test_subscribe_bind(self): nc = await nats.connect() @@ -1702,6 +1714,8 @@ async def test_subscribe_bind(self): assert info.num_ack_pending == 0 assert info.num_pending == 0 + await nc.close() + @async_test async def test_subscribe_custom_limits(self): errors = [] @@ -1904,6 +1918,8 @@ async def test_ack_v2_tokens(self): tzinfo=datetime.timezone.utc ) + await nc.close() + @async_test async def test_double_acking_pull_subscribe(self): nc = await nats.connect() @@ -2031,6 +2047,8 @@ async def f(): assert task.done() assert received + await nc.close() + class DiscardPolicyTest(SingleJetStreamServerTestCase): @@ -2516,6 +2534,7 @@ async def cb(msg): await asyncio.wait_for(done, 10) await nc.close() + await nc2.close() @async_test async def test_recreate_consumer_on_failed_hbs(self): @@ -2548,6 +2567,8 @@ async def error_handler(e): self.assertTrue(orig_name != info.name) await js.delete_stream("MY_STREAM") + await nc.close() + class KVTest(SingleJetStreamServerTestCase): @@ -2667,6 +2688,8 @@ async def error_handler(e): with pytest.raises(BadBucketError): await js.key_value(bucket="TEST3") + await nc.close() + @async_test async def test_kv_basic(self): errors = [] @@ -2824,6 +2847,8 @@ async def error_handler(e): entry = await kv.get("age") assert entry.revision == 10 + await nc.close() + @async_test async def test_kv_direct_get_msg(self): errors = [] @@ -2879,6 +2904,8 @@ async def error_handler(e): ) assert msg.data == b"33" + await nc.close() + @async_test async def test_kv_direct(self): errors = []