Skip to content

Commit

Permalink
query: interrupt query script on SIGTERM (#858)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Jan 28, 2025
1 parent 2520cae commit c0f23b4
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 6 deletions.
89 changes: 84 additions & 5 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import os.path
import posixpath
import signal
import subprocess
import sys
import time
Expand Down Expand Up @@ -97,6 +98,47 @@ def noop(_: str):
pass


class TerminationSignal(RuntimeError): # noqa: N818
def __init__(self, signal):
self.signal = signal
super().__init__("Received termination signal", signal)

def __repr__(self):
return f"{self.__class__.__name__}({self.signal})"


if sys.platform == "win32":
SIGINT = signal.CTRL_C_EVENT
else:
SIGINT = signal.SIGINT


def shutdown_process(
proc: subprocess.Popen,
interrupt_timeout: Optional[int] = None,
terminate_timeout: Optional[int] = None,
) -> int:
"""Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL."""

logger.info("sending interrupt signal to the process %s", proc.pid)
proc.send_signal(SIGINT)

logger.info("waiting for the process %s to finish", proc.pid)
try:
return proc.wait(interrupt_timeout)
except subprocess.TimeoutExpired:
logger.info(
"timed out waiting, sending terminate signal to the process %s", proc.pid
)
proc.terminate()
try:
return proc.wait(terminate_timeout)
except subprocess.TimeoutExpired:
logger.info("timed out waiting, killing the process %s", proc.pid)
proc.kill()
return proc.wait()


def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
buffer = b""
while byt := stream.read(1): # Read one byte at a time
Expand Down Expand Up @@ -1493,6 +1535,8 @@ def query(
output_hook: Callable[[str], None] = noop,
params: Optional[dict[str, str]] = None,
job_id: Optional[str] = None,
interrupt_timeout: Optional[int] = None,
terminate_timeout: Optional[int] = None,
) -> None:
cmd = [python_executable, "-c", query_script]
env = dict(env or os.environ)
Expand All @@ -1506,13 +1550,48 @@ def query(
if capture_output:
popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}

def raise_termination_signal(sig: int, _: Any) -> NoReturn:
raise TerminationSignal(sig)

thread: Optional[Thread] = None
with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
if capture_output:
args = (proc.stdout, output_hook)
thread = Thread(target=_process_stream, args=args, daemon=True)
thread.start()
thread.join() # wait for the reader thread
logger.info("Starting process %s", proc.pid)

orig_sigint_handler = signal.getsignal(signal.SIGINT)
# ignore SIGINT in the main process.
# In the terminal, SIGINTs are received by all the processes in
# the foreground process group, so the script will receive the signal too.
# (If we forward the signal to the child, it will receive it twice.)
signal.signal(signal.SIGINT, signal.SIG_IGN)

orig_sigterm_handler = signal.getsignal(signal.SIGTERM)
signal.signal(signal.SIGTERM, raise_termination_signal)
try:
if capture_output:
args = (proc.stdout, output_hook)
thread = Thread(target=_process_stream, args=args, daemon=True)
thread.start()

proc.wait()
except TerminationSignal as exc:
signal.signal(signal.SIGTERM, orig_sigterm_handler)
signal.signal(signal.SIGINT, orig_sigint_handler)
logging.info("Shutting down process %s, received %r", proc.pid, exc)
# Rather than forwarding the signal to the child, we try to shut it down
# gracefully. This is because we consider the script to be interactive
# and special, so we give it time to cleanup before exiting.
shutdown_process(proc, interrupt_timeout, terminate_timeout)
if proc.returncode:
raise QueryScriptCancelError(
"Query script was canceled by user", return_code=proc.returncode
) from exc
finally:
signal.signal(signal.SIGTERM, orig_sigterm_handler)
signal.signal(signal.SIGINT, orig_sigint_handler)
if thread:
thread.join() # wait for the reader thread

logging.info("Process %s exited with return code %s", proc.pid, proc.returncode)
if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
raise QueryScriptCancelError(
"Query script was canceled by user",
Expand Down
79 changes: 79 additions & 0 deletions tests/func/test_query.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import os.path
import signal
import sys
from multiprocessing.pool import ExceptionWithTraceback # type: ignore[attr-defined]
from textwrap import dedent

import cloudpickle
import multiprocess
import pytest

from datachain.catalog import Catalog
from datachain.cli import query
from datachain.data_storage import AbstractDBMetastore, JobQueryType, JobStatus
from datachain.error import QueryScriptCancelError
from datachain.job import Job
from tests.utils import wait_for_condition


@pytest.fixture
Expand Down Expand Up @@ -102,3 +109,75 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath,
assert job.params == {"url": src_uri}
assert job.metrics == {"count": 7}
assert job.python_version == f"{sys.version_info.major}.{sys.version_info.minor}"


if sys.platform == "win32":
SIGKILL = signal.SIGTERM
else:
SIGKILL = signal.SIGKILL


@pytest.mark.skipif(sys.platform == "win32", reason="Windows does not have SIGTERM")
@pytest.mark.parametrize(
"setup,expected_return_code",
[
("", -signal.SIGINT),
("signal.signal(signal.SIGINT, signal.SIG_IGN)", -signal.SIGTERM),
(
"""\
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)
""",
-SIGKILL,
),
],
)
def test_shutdown_on_sigterm(tmp_dir, request, catalog, setup, expected_return_code):
query = f"""\
import os, pathlib, signal, sys, time
pathlib.Path("ready").touch(exist_ok=False)
{setup}
time.sleep(10)
"""

def apply(f, args, kwargs):
return f(*args, **kwargs)

def func(ms_params, wh_params, init_params, q):
catalog = Catalog(apply(*ms_params), apply(*wh_params), **init_params)
try:
catalog.query(query, interrupt_timeout=0.5, terminate_timeout=0.5)
except Exception as e: # noqa: BLE001
q.put(ExceptionWithTraceback(e, e.__traceback__))
else:
q.put(None)

mp_ctx = multiprocess.get_context("spawn")
q = mp_ctx.Queue()
p = mp_ctx.Process(
target=func,
args=(
catalog.metastore.clone_params(),
catalog.warehouse.clone_params(),
catalog.get_init_params(),
q,
),
)
p.start()
request.addfinalizer(p.kill)

def is_ready():
assert p.is_alive(), "Process is dead"
return os.path.exists("ready")

# make sure the process is running before we send the signal
wait_for_condition(is_ready, "script to start", timeout=5)

os.kill(p.pid, signal.SIGTERM)
p.join(timeout=3) # might take as long as 1 second to complete shutdown_process
assert not p.exitcode

e = q.get_nowait()
assert isinstance(e, QueryScriptCancelError)
assert e.return_code == expected_return_code
17 changes: 16 additions & 1 deletion tests/unit/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

import pytest

from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
from datachain.catalog.catalog import (
QUERY_SCRIPT_CANCELED_EXIT_CODE,
TerminationSignal,
)
from datachain.error import QueryScriptCancelError, QueryScriptRunError


Expand Down Expand Up @@ -63,3 +66,15 @@ def test_non_zero_exitcode(catalog, mock_popen):
catalog.query("pass")
assert e.value.return_code == 1
assert "Query script exited with error code 1" in str(e.value)


def test_shutdown_process_on_sigterm(mocker, catalog, mock_popen):
mock_popen.returncode = -2
mock_popen.wait.side_effect = [TerminationSignal(15)]
m = mocker.patch("datachain.catalog.catalog.shutdown_process", return_value=-2)

with pytest.raises(QueryScriptCancelError) as e:
catalog.query("pass", interrupt_timeout=0.1, terminate_timeout=0.2)
assert e.value.return_code == -2
assert "Query script was canceled by user" in str(e.value)
m.assert_called_once_with(mock_popen, 0.1, 0.2)

0 comments on commit c0f23b4

Please sign in to comment.