From de81d3f07d63f4abd14452f599c417dfa167bc7e Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Wed, 25 Sep 2024 09:49:11 +0200 Subject: [PATCH 1/7] add sanitization option to host outputs --- src/jobflow_remote/config/base.py | 5 ++ src/jobflow_remote/remote/host/base.py | 84 +++++++++++++++++++++++- src/jobflow_remote/remote/host/local.py | 12 +++- src/jobflow_remote/remote/host/remote.py | 9 ++- tests/integration/conftest.py | 19 ++++++ tests/integration/test_slurm.py | 27 +++++++- 6 files changed, 151 insertions(+), 5 deletions(-) diff --git a/src/jobflow_remote/config/base.py b/src/jobflow_remote/config/base.py index f228621d..ca915ec3 100644 --- a/src/jobflow_remote/config/base.py +++ b/src/jobflow_remote/config/base.py @@ -183,6 +183,11 @@ class WorkerBase(BaseModel): "username instead that from the list of job ids. May be necessary for some " "scheduler_type (e.g. SGE)", ) + sanitize_command: bool = Field( + default=False, + description="Sanitize the output of commands in case of failures due to spurious text produced" + "by the worker shell.", + ) model_config = ConfigDict(extra="forbid") @field_validator("scheduler_type") diff --git a/src/jobflow_remote/remote/host/base.py b/src/jobflow_remote/remote/host/base.py index ee10d574..ed62002c 100644 --- a/src/jobflow_remote/remote/host/base.py +++ b/src/jobflow_remote/remote/host/base.py @@ -1,6 +1,8 @@ from __future__ import annotations import abc +import logging +import re import traceback from typing import TYPE_CHECKING @@ -10,9 +12,26 @@ from pathlib import Path +logger = logging.getLogger(__name__) + +SANITIZE_KEY = r"_-_-_-_-_### JFREMOTE SANITIZE ###_-_-_-_-_" + + class BaseHost(MSONable): """Base Host class.""" + def __init__(self, sanitize: bool = False): + """ + Parameters + ---------- + sanitize + If True text a string will be prepended and appended to the output + of the commands, to ease the parsing and avoid failures due to spurious + text coming from the host shell. + """ + self.sanitize = sanitize + self._sanitize_regex: re.Pattern | None = None + @abc.abstractmethod def execute( self, @@ -28,7 +47,8 @@ def execute( Command to execute, as a str or list of str workdir: str or None path where the command will be executed. - + timeout + Timeout for the execution of the commands. """ raise NotImplementedError @@ -124,6 +144,68 @@ def interactive_login(self) -> bool: """ return False + @property + def sanitize_regex(self) -> re.Pattern: + """ + Regular expression to sanitize sensitive info in command outputs. + """ + if not self._sanitize_regex: + escaped_key = re.escape(SANITIZE_KEY) + self._sanitize_regex = re.compile( + f"{escaped_key}(.*?)(?:{escaped_key}|$)", re.DOTALL + ) + + return self._sanitize_regex + + def sanitize_command(self, cmd: str) -> str: + """ + Sanitizes a command by adding a prefix and suffix to the command string if + sanitization is enabled. + The prefix and suffix are the same and are used to mark the parts of the output + that should be sanitized. The prefix and suffix are defined by `SANITIZE_KEY`. + + Parameters + ---------- + cmd + The command string to be sanitized + + Returns + ------- + str + The sanitized command string + """ + if self.sanitize: + echo_cmd = f'echo -n "{SANITIZE_KEY}" | tee >(cat >&2)' + cmd = f"{echo_cmd};{cmd};{echo_cmd}" + return cmd + + def sanitize_output(self, output: str) -> str: + """ + Sanitizes the output of a command by selecting the section between the + SANITIZE_KEY strings. + If the second instance of the key is not found, the part of the output after the key is returned. + If the key is not present, the entire output is returned. + + Parameters + ---------- + output + The output of the command to be sanitized + + Returns + ------- + str + The sanitized output + """ + if self.sanitize: + match = self.sanitize_regex.search(output) + if not match: + logger.warning( + f"Even if sanitization was required, there was no match for the output: {output}. Returning the complete output" + ) + return output + return match.group(1) + return output + class HostError(Exception): pass diff --git a/src/jobflow_remote/remote/host/local.py b/src/jobflow_remote/remote/host/local.py index f20cdc21..068d9169 100644 --- a/src/jobflow_remote/remote/host/local.py +++ b/src/jobflow_remote/remote/host/local.py @@ -12,8 +12,9 @@ class LocalHost(BaseHost): - def __init__(self, timeout_execute: int = None) -> None: + def __init__(self, timeout_execute: int = None, sanitize: bool = False) -> None: self.timeout_execute = timeout_execute + super().__init__(sanitize=sanitize) def __eq__(self, other): return isinstance(other, LocalHost) @@ -34,6 +35,10 @@ def execute( ---------- command: str or list of str Command to execute, as a str or list of str + workdir: str or None + path where the command will be executed. + timeout + Timeout for the execution of the commands. Returns ------- @@ -46,13 +51,16 @@ def execute( """ if isinstance(command, (list, tuple)): command = " ".join(command) + command = self.sanitize_command(command) workdir = str(workdir) if workdir else Path.cwd() timeout = timeout or self.timeout_execute with cd(workdir): proc = subprocess.run( command, capture_output=True, shell=True, timeout=timeout, check=False ) - return proc.stdout.decode(), proc.stderr.decode(), proc.returncode + stdout = self.sanitize_output(proc.stdout.decode()) + stderr = self.sanitize_output(proc.stderr.decode()) + return stdout, stderr, proc.returncode def mkdir( self, directory: str | Path, recursive: bool = True, exist_ok: bool = True diff --git a/src/jobflow_remote/remote/host/remote.py b/src/jobflow_remote/remote/host/remote.py index 48595ebc..ae797368 100644 --- a/src/jobflow_remote/remote/host/remote.py +++ b/src/jobflow_remote/remote/host/remote.py @@ -42,6 +42,7 @@ def __init__( login_shell=True, retry_on_closed_connection=True, interactive_login=False, + sanitize: bool = False, ) -> None: self.host = host self.user = user @@ -59,6 +60,7 @@ def __init__( self.retry_on_closed_connection = retry_on_closed_connection self._interactive_login = interactive_login self._create_connection() + super().__init__(sanitize=sanitize) def _create_connection(self) -> None: if self.interactive_login: @@ -175,6 +177,8 @@ def execute( if isinstance(command, (list, tuple)): command = " ".join(command) + command = self.sanitize_command(command) + # TODO: check if this works: if not workdir: workdir = "." @@ -201,7 +205,10 @@ def execute( timeout=timeout, ) - return out.stdout, out.stderr, out.exited + stdout = self.sanitize_output(out.stdout) + stderr = self.sanitize_output(out.stderr) + + return stdout, stderr, out.exited def mkdir( self, directory: str | Path, recursive: bool = True, exist_ok: bool = True diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 52137b85..a36c983f 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -223,6 +223,12 @@ def write_tmp_settings( work_dir=str(workdir), resources={}, ), + "test_sanitize_local_worker": dict( + type="local", + scheduler_type="shell", + work_dir=str(workdir), + resources={}, + ), "test_remote_worker": dict( type="remote", host="localhost", @@ -273,6 +279,19 @@ def write_tmp_settings( resources={}, max_jobs=2, ), + "test_sanitize_remote_worker": dict( + type="remote", + host="localhost", + port=slurm_ssh_port, + scheduler_type="slurm", + work_dir="/home/jobflow/jfr", + user="jobflow", + password="jobflow", + pre_run="source /home/jobflow/.venv/bin/activate", + resources={"partition": "debug", "ntasks": 1, "time": "00:01:00"}, + connect_kwargs={"allow_agent": False, "look_for_keys": False}, + sanitize_command=True, + ), }, exec_config={"test": {"export": {"TESTING_ENV_VAR": random_project_name}}}, runner=dict( diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index 9602f7d8..ec338211 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -15,7 +15,7 @@ def test_project_init(random_project_name) -> None: assert len(cm.projects) == 1 assert cm.projects[random_project_name] project = cm.get_project() - assert len(project.workers) == 5 + assert len(project.workers) == 7 def test_paramiko_ssh_connection(job_controller, slurm_ssh_port) -> None: @@ -39,7 +39,12 @@ def test_project_check(job_controller, capsys) -> None: expected = [ "✓ Worker test_local_worker", + "✓ Worker test_sanitize_local_worker", "✓ Worker test_remote_worker", + "✓ Worker test_remote_limited_worker", + "✓ Worker test_batch_remote_worker", + "✓ Worker test_max_jobs_worker", + "✓ Worker test_sanitize_remote_worker", "✓ Jobstore", "✓ Queue store", ] @@ -404,3 +409,23 @@ def test_priority(worker, job_controller) -> None: jobs_info = sorted(jobs_info, key=lambda x: x.priority, reverse=True) for i in range(len(jobs_info) - 1): assert jobs_info[i].end_time < jobs_info[i + 1].start_time + +@pytest.mark.parametrize( + "worker", + ["test_sanitize_local_worker", "test_sanitize_remote_worker"], +) +def test_sanitize(worker, job_controller): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add + + flow = Flow([add(1, 2)]) + submit_flow(flow, worker=worker) + + runner = Runner() + runner.run_one_job() + + assert job_controller.count_jobs(states=JobState.COMPLETED) == 1 From 5fcda6075460889a19a4fe11cab54cae0bd114af Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Wed, 25 Sep 2024 15:06:15 +0200 Subject: [PATCH 2/7] add mock tests --- tests/db/remote/host/test_local.py | 33 ++++++++++++++++++++++++ tests/db/remote/host/test_remote.py | 40 +++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 tests/db/remote/host/test_local.py create mode 100644 tests/db/remote/host/test_remote.py diff --git a/tests/db/remote/host/test_local.py b/tests/db/remote/host/test_local.py new file mode 100644 index 00000000..0157f565 --- /dev/null +++ b/tests/db/remote/host/test_local.py @@ -0,0 +1,33 @@ +from unittest.mock import patch + + +@patch("subprocess.run") +def test_sanitize(mock_run): + from jobflow_remote.remote.host.base import SANITIZE_KEY + from jobflow_remote.remote.host.local import LocalHost + + lh = LocalHost(sanitize=True) + + cmd = "echo 'test'" + + echo_cmd = f'echo -n "{SANITIZE_KEY}" | tee >(cat >&2)' + expected_cmd = f"{echo_cmd};{cmd};{echo_cmd}" + mock_stdout = f"SOME NOISE --{SANITIZE_KEY}test{SANITIZE_KEY}SOME appended TEXT" + + # Configure the mock + mock_run.return_value.returncode = 0 + mock_run.return_value.stdout = mock_stdout.encode() + mock_run.return_value.stderr = b"" + + stdout, stderr, _ = lh.execute(cmd) + + mock_run.assert_called_once_with( + expected_cmd, + capture_output=True, + shell=True, # noqa: S604 + timeout=None, + check=False, + ) + + assert stdout == "test" + assert stderr == "" diff --git a/tests/db/remote/host/test_remote.py b/tests/db/remote/host/test_remote.py new file mode 100644 index 00000000..d31741dd --- /dev/null +++ b/tests/db/remote/host/test_remote.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + + +@patch("fabric.Connection.run") +@patch("fabric.Connection.cd") +def test_sanitize(mock_cd, mock_run): + from jobflow_remote.remote.host.base import SANITIZE_KEY + from jobflow_remote.remote.host.remote import RemoteHost + + rh = RemoteHost( + host="localhost", + retry_on_closed_connection=False, + sanitize=True, + shell_cmd=None, + ) + rh._check_connected = lambda: True + + cmd = "echo 'test'" + + echo_cmd = f'echo -n "{SANITIZE_KEY}" | tee >(cat >&2)' + expected_cmd = f"{echo_cmd};{cmd};{echo_cmd}" + mock_stdout = f"SOME NOISE --{SANITIZE_KEY}test{SANITIZE_KEY}SOME appended TEXT" + + # Configure the mock + mock_cd.return_value.__enter__ = ( + MagicMock() + ) # This makes the context manager do nothing + mock_cd.return_value.__exit__ = MagicMock() + mock_run.return_value.stdout = mock_stdout + mock_run.return_value.stderr = "" + + # Call the function that uses subprocess.run + stdout, stderr, _ = rh.execute(cmd) + + # Assert that subprocess.run was called with the expected arguments + mock_run.assert_called_once_with(expected_cmd, timeout=None, hide=True, warn=True) + + # Assert on the result of your function + assert stdout == "test" + assert stderr == "" From 586e7cfb3ef7d2177df5035d30913f2265cdbe57 Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Fri, 27 Sep 2024 09:33:25 +0200 Subject: [PATCH 3/7] fix sanitization --- src/jobflow_remote/config/base.py | 5 ++++- tests/integration/conftest.py | 1 + tests/integration/test_slurm.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/jobflow_remote/config/base.py b/src/jobflow_remote/config/base.py index ca915ec3..0d4653a0 100644 --- a/src/jobflow_remote/config/base.py +++ b/src/jobflow_remote/config/base.py @@ -257,7 +257,9 @@ def get_host(self) -> BaseHost: ------- The LocalHost. """ - return LocalHost(timeout_execute=self.timeout_execute) + return LocalHost( + timeout_execute=self.timeout_execute, sanitize=self.sanitize_command + ) @property def cli_info(self) -> dict: @@ -407,6 +409,7 @@ def get_host(self) -> BaseHost: shell_cmd=self.shell_cmd, login_shell=self.login_shell, interactive_login=self.interactive_login, + sanitize=self.sanitize_command, ) @property diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a36c983f..634ffe2a 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -228,6 +228,7 @@ def write_tmp_settings( scheduler_type="shell", work_dir=str(workdir), resources={}, + sanitize_command=True, ), "test_remote_worker": dict( type="remote", diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index ec338211..af572cb6 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -422,6 +422,8 @@ def test_sanitize(worker, job_controller): from jobflow_remote.jobs.state import JobState from jobflow_remote.testing import add + assert job_controller.project.workers[worker].get_host().sanitize is True + flow = Flow([add(1, 2)]) submit_flow(flow, worker=worker) From a72ddd3b3e2dc0bb114de30a0f3bf0a8f77d3638 Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Fri, 27 Sep 2024 10:19:59 +0200 Subject: [PATCH 4/7] fix test --- tests/integration/test_slurm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index af572cb6..6581d8d1 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -15,7 +15,7 @@ def test_project_init(random_project_name) -> None: assert len(cm.projects) == 1 assert cm.projects[random_project_name] project = cm.get_project() - assert len(project.workers) == 7 + assert len(project.workers) == 6 def test_paramiko_ssh_connection(job_controller, slurm_ssh_port) -> None: From ef7e3a5b0f2b0f8446ceee8fe81cf9562562666e Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Mon, 30 Sep 2024 11:48:32 +0200 Subject: [PATCH 5/7] update sanitization to fit more cases --- src/jobflow_remote/remote/host/base.py | 22 ++++++++++++++++++---- tests/integration/test_slurm.py | 5 +++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/jobflow_remote/remote/host/base.py b/src/jobflow_remote/remote/host/base.py index ed62002c..7d7a17c9 100644 --- a/src/jobflow_remote/remote/host/base.py +++ b/src/jobflow_remote/remote/host/base.py @@ -94,8 +94,19 @@ def test(self) -> str | None: try: cmd = "echo 'test'" stdout, stderr, returncode = self.execute(cmd) - if returncode != 0 or stdout.strip() != "test": - msg = f"Command was executed but some error occurred.\nstdoud: {stdout}\nstderr: {stderr}" + if returncode != 0: + msg = f"Command was executed but return code was different from zero.\nstdoud: {stdout}\nstderr: {stderr}" + elif stdout.strip() != "test" or stderr.strip() != "": + msg = ( + "Command was executed but the output is not the expected one (i.e. a single 'test' " + f"string in both stdout and stderr).\nstdoud: {stdout}\nstderr: {stderr}" + ) + if not self.sanitize: + msg += ( + "\nIf the output contains additional text the problem may be solved by setting " + "the 'sanitize_command' option to True in the project configuration." + ) + except Exception: exc = traceback.format_exc() msg = f"Error while executing command:\n {exc}" @@ -151,8 +162,11 @@ def sanitize_regex(self) -> re.Pattern: """ if not self._sanitize_regex: escaped_key = re.escape(SANITIZE_KEY) + # Optionally match the newline that comes from the "echo" command. + # The -n option for echo to suppress the newline seems to not be + # supported on all systems self._sanitize_regex = re.compile( - f"{escaped_key}(.*?)(?:{escaped_key}|$)", re.DOTALL + f"{escaped_key}\r?\n?(.*?)(?:{escaped_key}\r?\n?|$)", re.DOTALL ) return self._sanitize_regex @@ -175,7 +189,7 @@ def sanitize_command(self, cmd: str) -> str: The sanitized command string """ if self.sanitize: - echo_cmd = f'echo -n "{SANITIZE_KEY}" | tee >(cat >&2)' + echo_cmd = f'echo "{SANITIZE_KEY}" | tee /dev/stderr' cmd = f"{echo_cmd};{cmd};{echo_cmd}" return cmd diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index 6581d8d1..9cd1afa1 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -15,7 +15,7 @@ def test_project_init(random_project_name) -> None: assert len(cm.projects) == 1 assert cm.projects[random_project_name] project = cm.get_project() - assert len(project.workers) == 6 + assert len(project.workers) == 7 def test_paramiko_ssh_connection(job_controller, slurm_ssh_port) -> None: @@ -48,7 +48,7 @@ def test_project_check(job_controller, capsys) -> None: "✓ Jobstore", "✓ Queue store", ] - run_check_cli(["project", "check"], required_out=expected) + run_check_cli(["project", "check", "-e"], required_out=expected) @pytest.mark.parametrize( @@ -410,6 +410,7 @@ def test_priority(worker, job_controller) -> None: for i in range(len(jobs_info) - 1): assert jobs_info[i].end_time < jobs_info[i + 1].start_time + @pytest.mark.parametrize( "worker", ["test_sanitize_local_worker", "test_sanitize_remote_worker"], From 3e9f4f834fc19e36e5927386b3d434fc35954174 Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Mon, 30 Sep 2024 13:03:45 +0200 Subject: [PATCH 6/7] fix tests --- tests/db/remote/host/test_local.py | 4 ++-- tests/db/remote/host/test_remote.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/db/remote/host/test_local.py b/tests/db/remote/host/test_local.py index 0157f565..254caf96 100644 --- a/tests/db/remote/host/test_local.py +++ b/tests/db/remote/host/test_local.py @@ -10,9 +10,9 @@ def test_sanitize(mock_run): cmd = "echo 'test'" - echo_cmd = f'echo -n "{SANITIZE_KEY}" | tee >(cat >&2)' + echo_cmd = f'echo "{SANITIZE_KEY}" | tee /dev/stderr' expected_cmd = f"{echo_cmd};{cmd};{echo_cmd}" - mock_stdout = f"SOME NOISE --{SANITIZE_KEY}test{SANITIZE_KEY}SOME appended TEXT" + mock_stdout = f"SOME NOISE --{SANITIZE_KEY}\ntest{SANITIZE_KEY}\nSOME appended TEXT" # Configure the mock mock_run.return_value.returncode = 0 diff --git a/tests/db/remote/host/test_remote.py b/tests/db/remote/host/test_remote.py index d31741dd..f8709cac 100644 --- a/tests/db/remote/host/test_remote.py +++ b/tests/db/remote/host/test_remote.py @@ -17,9 +17,9 @@ def test_sanitize(mock_cd, mock_run): cmd = "echo 'test'" - echo_cmd = f'echo -n "{SANITIZE_KEY}" | tee >(cat >&2)' + echo_cmd = f'echo "{SANITIZE_KEY}" | tee /dev/stderr' expected_cmd = f"{echo_cmd};{cmd};{echo_cmd}" - mock_stdout = f"SOME NOISE --{SANITIZE_KEY}test{SANITIZE_KEY}SOME appended TEXT" + mock_stdout = f"SOME NOISE --{SANITIZE_KEY}\ntest{SANITIZE_KEY}\nSOME appended TEXT" # Configure the mock mock_cd.return_value.__enter__ = ( From 6f1a24a55cc841207c37003d8b9c1e7e23329efd Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Tue, 1 Oct 2024 11:08:59 +0200 Subject: [PATCH 7/7] fix estimated run time --- src/jobflow_remote/jobs/data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/jobflow_remote/jobs/data.py b/src/jobflow_remote/jobs/data.py index fedb289a..3ffa1499 100644 --- a/src/jobflow_remote/jobs/data.py +++ b/src/jobflow_remote/jobs/data.py @@ -177,9 +177,7 @@ def estimated_run_time(self) -> Optional[float]: The estimated run time in seconds. """ if self.start_time: - return ( - datetime.now(tz=self.start_time.tzinfo) - self.start_time - ).total_seconds() + return (datetime.utcnow() - self.start_time).total_seconds() return None