From 2ae57db6164e347da08b88b35bda932839bfb8af Mon Sep 17 00:00:00 2001 From: MeditationDuck Date: Tue, 10 Sep 2024 14:41:23 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20KeyboardInterrupt=20in=20w?= =?UTF-8?q?ake=20test=20multiprocessing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- wake/testing/pytest_plugin_multiprocess.py | 21 +++++++++++++++---- .../pytest_plugin_multiprocess_server.py | 20 +++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/wake/testing/pytest_plugin_multiprocess.py b/wake/testing/pytest_plugin_multiprocess.py index 73a41fcd5..f84ed6a7a 100644 --- a/wake/testing/pytest_plugin_multiprocess.py +++ b/wake/testing/pytest_plugin_multiprocess.py @@ -40,6 +40,7 @@ class PytestWakePluginMultiprocess: _exception_handled: bool _ctx_managers: List + _keyboard_interrupt: bool def __init__( self, @@ -62,6 +63,7 @@ def __init__( self._debug = debug self._exception_handled = False + self._keyboard_interrupt = False self._ctx_managers = [] def _setup_stdio(self): @@ -86,6 +88,11 @@ def _exception_handler( e: Optional[BaseException], tb: Optional[TracebackType], ) -> None: + + # After the keyboard interrupt, we do not interested in debugging. + if self._keyboard_interrupt: + return + self._cleanup_stdio() self._exception_handled = True @@ -177,11 +184,14 @@ def coverage_callback() -> None: except queue.Full: pass - def signal_handler(sig, frame): - raise KeyboardInterrupt() - pickling_support.install() - signal.signal(signal.SIGTERM, signal_handler) + + def sigint_handler(signum, frame): + self._keyboard_interrupt = True + self._queue.put(("keyboard_interrupt", self._index)) + pytest.exit("Keyboard interrupt", returncode=0) + + signal.signal(signal.SIGINT, sigint_handler) if self._debug: set_exception_handler(self._exception_handler) @@ -226,6 +236,9 @@ def pytest_runtest_protocol(self, item, nextitem): # do not forward pytest_runtest_logstart and pytest_runtest_logfinish as they write item location to stdout which may be different for each process def pytest_runtest_logreport(self, report: pytest.TestReport): + # not sending exception report since the reason of exception is keyboard interrupt or at least triggered by keyboard interrupt + if self._keyboard_interrupt: + return self._queue.put(("pytest_runtest_logreport", self._index, report)) def pytest_warning_recorded(self, warning_message, when, nodeid, location): diff --git a/wake/testing/pytest_plugin_multiprocess_server.py b/wake/testing/pytest_plugin_multiprocess_server.py index 2f84a47e1..6fb74cec7 100644 --- a/wake/testing/pytest_plugin_multiprocess_server.py +++ b/wake/testing/pytest_plugin_multiprocess_server.py @@ -1,7 +1,9 @@ import multiprocessing import multiprocessing.connection +import os import pickle import shutil +import signal from contextlib import nullcontext from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -101,14 +103,15 @@ def pytest_sessionstart(self, session: pytest.Session): parent_conn, ) p.start() + signal.signal(signal.SIGINT, signal.SIG_IGN) def pytest_sessionfinish(self, session: pytest.Session): self._queue.cancel_join_thread() for p, conn in self._processes.values(): - p.terminate() + if p.pid is not None: + os.kill(p.pid, signal.SIGINT) p.join() conn.close() - self._queue.close() # flush coverage @@ -177,6 +180,7 @@ def pytest_runtestloop(self, session: pytest.Session): ) try: + keyboard_interrupt = [False for _ in range(self._proc_count)] with ctx as progress: if progress is not None: tasks = [ @@ -268,7 +272,12 @@ def pytest_runtestloop(self, session: pytest.Session): ) elif msg[0] == "pytest_sessionfinish": if progress is not None: - text = f"#{index} finished [green]✓[/green]" if msg[2] == 0 else f"#{index} failed [red]✗[/red]" + if keyboard_interrupt[index]: + text = f"#{index} interrupted [yellow]⚠[/yellow]" + elif msg[2] == 0: + text = f"#{index} finished [green]✓[/green]" + else: + text = f"#{index} failed [red]✗[/red]" progress.update(tasks[index], description=text) self._processes.pop(index) @@ -280,6 +289,11 @@ def pytest_runtestloop(self, session: pytest.Session): session.config.hook.pytest_internalerror( excrepr=exc_info.getrepr(style="short"), excinfo=exc_info ) + elif msg[0] == "keyboard_interrupt": + keyboard_interrupt[index] = True + + if True in keyboard_interrupt: + raise KeyboardInterrupt finally: print("") for report in reports: