Skip to content

Commit

Permalink
fix(agent): fix cwd on remote job submission (#544) (#545)
Browse files Browse the repository at this point in the history
  • Loading branch information
fschuch authored Apr 19, 2024
1 parent 9b932b8 commit 4ec7646
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 35 deletions.
2 changes: 2 additions & 0 deletions jobbergate-agent/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ This file keeps track of all notable changes to jobbergate-agent

## Unreleased

- Fixed how the working directory (cwd) is set on remote job submission, to avoid permission-denied errors on the submission directory

## 5.0.0 -- 2024-04-18

- Added logic to update slurm job status at job submission time [PENG-2193]
Expand Down
37 changes: 22 additions & 15 deletions jobbergate-agent/jobbergate_agent/jobbergate/submit.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import asyncio
import functools
import os
import pwd
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory

from buzz import DoExceptParams
from jobbergate_core.tools.sbatch import InfoHandler, SubmissionHandler, inject_sbatch_params
from jobbergate_core.tools.sbatch import InfoHandler, SubmissionHandler, SubprocessHandler, inject_sbatch_params
from loguru import logger

from jobbergate_agent.clients.cluster_api import backend_client as jobbergate_api_client
Expand Down Expand Up @@ -162,19 +162,26 @@ async def mark_as_rejected(job_submission_id: int, report_message: str):
response.raise_for_status()


@functools.lru_cache(maxsize=64)
def run_as_user(username):
"""Provide to subprocess.run a way to run a command as the user."""
# Get the uid and gid from the username
pwan = pwd.getpwnam(username)
uid = pwan.pw_uid
gid = pwan.pw_gid
@dataclass
class SubprocessAsUserHandler(SubprocessHandler):
"""Subprocess handler that runs as a given user."""

def preexec(): # Function to be run in the child process before the subprocess call
os.setgid(gid)
os.setuid(uid)
username: str

return preexec
def __post_init__(self):
pwan = pwd.getpwnam(self.username)
self.uid = pwan.pw_uid
self.gid = pwan.pw_gid

def run(self, *args, **kwargs):
kwargs.update(user=self.uid, group=self.gid)
# Tests indicate that the subprocess changes its working directory before it switches user.
# The chdir therefore runs as the user running the agent, who can face permission-denied errors on cwd,
# depending on the filesystem settings and the permissions on the directory.
# To avoid this, we change the working directory after switching to the submitter user, using preexec_fn.
if cwd := kwargs.pop("cwd", None):
kwargs["preexec_fn"] = lambda: os.chdir(cwd)
return super().run(*args, **kwargs)


async def submit_job_script(
Expand Down Expand Up @@ -208,7 +215,7 @@ async def _reject_handler(params: DoExceptParams):
logger.debug(f"Fetching username for email '{email}' with mapper {mapper_class_name}")
username = user_mapper[email]
logger.debug(f"Using local slurm user '{username}' for job submission")
preexec_fn = run_as_user(username)
subprocess_handler = SubprocessAsUserHandler(username)

submit_dir = pending_job_submission.execution_directory or SETTINGS.DEFAULT_SLURM_WORK_DIR
if not submit_dir.exists() or not submit_dir.is_absolute():
Expand All @@ -219,7 +226,7 @@ async def _reject_handler(params: DoExceptParams):
sbatch_handler = SubmissionHandler(
sbatch_path=SETTINGS.SBATCH_PATH,
submission_directory=submit_dir,
preexec_fn=preexec_fn,
subprocess_handler=subprocess_handler,
)

with TemporaryDirectory(prefix=str(pending_job_submission.id), suffix=pending_job_submission.name) as tmp_dir:
Expand Down
2 changes: 2 additions & 0 deletions jobbergate-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ This file keeps track of all notable changes to jobbergate-core

## Unreleased

- Refactored the command handler that interfaces with sbatch and scontrol

## 5.0.0 -- 2024-04-18

- Dropped support for Python 3.8 and 3.9
Expand Down
26 changes: 14 additions & 12 deletions jobbergate-core/jobbergate_core/tools/sbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, ClassVar, Sequence
from typing import Any, ClassVar, Sequence

from buzz import check_expressions
from loguru import logger
Expand Down Expand Up @@ -51,15 +51,13 @@ def inject_sbatch_params(job_script_data_as_string: str, sbatch_params: list[str
return new_job_script_data_as_string


@dataclass(frozen=True)
class Command:
preexec_fn: Callable | None = None
@dataclass
class SubprocessHandler:

def run_command(self, cmd: Sequence[str], **kwargs) -> subprocess.CompletedProcess:
"""Runs a command as the user."""
def run(self, cmd: Sequence[str], **kwargs) -> subprocess.CompletedProcess:
logger.debug("Running command '{}' with kwargs: {}", " ".join(cmd), kwargs)
try:
result = subprocess.run(cmd, preexec_fn=self.preexec_fn, check=True, shell=False, **kwargs)
result = subprocess.run(cmd, check=True, shell=False, **kwargs)
logger.trace("Command returned code {} with result: {}", result.returncode, result.stdout)
return result
except subprocess.CalledProcessError as e:
Expand All @@ -69,10 +67,11 @@ def run_command(self, cmd: Sequence[str], **kwargs) -> subprocess.CompletedProce


@dataclass(frozen=True)
class InfoHandler(Command):
class InfoHandler:
"""Get info from jobs on the cluster."""

scontrol_path: Path = Path("/usr/bin/scontrol")
subprocess_handler: SubprocessHandler = field(default_factory=SubprocessHandler)

def __post_init__(self):
with check_expressions("Check paths", raise_exc_class=ValueError) as check:
Expand All @@ -88,7 +87,7 @@ def get_job_info(self, slurm_id: int) -> dict[str, Any]:
shlex.quote(str(slurm_id)),
"--json",
)
completed_process = self.run_command(command, capture_output=True, text=True)
completed_process = self.subprocess_handler.run(command, capture_output=True, text=True)
data = json.loads(completed_process.stdout)
try:
job_info = data["jobs"][0]
Expand All @@ -105,11 +104,12 @@ def get_job_info(self, slurm_id: int) -> dict[str, Any]:


@dataclass(frozen=True)
class SubmissionHandler(Command):
class SubmissionHandler:
"""Submits sbatch jobs to the cluster."""

sbatch_path: Path = Path("/usr/bin/sbatch")
submission_directory: Path = field(default_factory=Path.cwd)
subprocess_handler: SubprocessHandler = field(default_factory=SubprocessHandler)

sbatch_output_parser: ClassVar[re.Pattern] = re.compile(r"^(?P<id>\d+)(,(?P<cluster_name>.+))?$")

Expand All @@ -128,7 +128,9 @@ def submit_job(self, job_script_path: Path) -> int:
job_script_path.as_posix(),
)

completed_process = self.run_command(command, cwd=self.submission_directory, capture_output=True, text=True)
completed_process = self.subprocess_handler.run(
command, cwd=self.submission_directory, capture_output=True, text=True
)

if match := self.sbatch_output_parser.match(completed_process.stdout):
slurm_id = int(match.group("id"))
Expand All @@ -145,7 +147,7 @@ def copy_file_to_submission_directory(self, source_file: Path) -> Path:
command = ("tee", destination_file.as_posix())
try:
with source_file.open("rb") as source:
self.run_command(
self.subprocess_handler.run(
command,
stdin=source,
stdout=subprocess.DEVNULL,
Expand Down
8 changes: 0 additions & 8 deletions jobbergate-core/tests/tools/test_sbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,9 @@ def test_run__success(self, mocker, sbatch_path, tmp_path):
response = subprocess.CompletedProcess(args=[], stdout="123", returncode=0)
mocked_run = mocker.patch("jobbergate_core.tools.sbatch.subprocess.run", return_value=response)

def dummy_preexec_fn():
pass

sbatch_handler = SubmissionHandler(
sbatch_path=sbatch_path,
submission_directory=tmp_path,
preexec_fn=dummy_preexec_fn,
)

job_script_path = tmp_path / "file.sh"
Expand All @@ -57,7 +53,6 @@ def dummy_preexec_fn():
"--parsable",
job_script_path.as_posix(),
),
preexec_fn=dummy_preexec_fn,
check=True,
shell=False,
cwd=tmp_path,
Expand Down Expand Up @@ -128,7 +123,6 @@ def test_get_job_info__success(self, mocker, scontrol_path):
"123",
"--json",
),
preexec_fn=None,
check=True,
shell=False,
capture_output=True,
Expand All @@ -153,7 +147,6 @@ def test_get_job_info__failed_to_parse(self, mocker, scontrol_path):
"123",
"--json",
),
preexec_fn=None,
check=True,
shell=False,
capture_output=True,
Expand All @@ -176,7 +169,6 @@ def test_get_job_info__not_fount(self, mocker, sbatch_path, scontrol_path, tmp_p
"123",
"--json",
),
preexec_fn=None,
check=True,
shell=False,
capture_output=True,
Expand Down

0 comments on commit 4ec7646

Please sign in to comment.