From d256f43b69c0dd668b52822ec0016cb8479e5efd Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 4 Nov 2024 15:54:36 +0100 Subject: [PATCH] Adapt SLURM engine for multi-node jobs (#212) --- sisyphus/job.py | 2 +- ..._utility_for_resource_management_engine.py | 39 +++++++++++++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/sisyphus/job.py b/sisyphus/job.py index 5243bdb..516b9d9 100644 --- a/sisyphus/job.py +++ b/sisyphus/job.py @@ -398,7 +398,7 @@ def _sis_path(self, path_type=None, task_id=None, abspath=False): # Add task id as suffix if task_id is not None: - path += ".%i" % task_id + path += f".{task_id}" if abspath and not os.path.isabs(path): path = os.path.join(gs.BASE_DIR, path) diff --git a/sisyphus/simple_linux_utility_for_resource_management_engine.py b/sisyphus/simple_linux_utility_for_resource_management_engine.py index ff0e71a..7815666 100644 --- a/sisyphus/simple_linux_utility_for_resource_management_engine.py +++ b/sisyphus/simple_linux_utility_for_resource_management_engine.py @@ -173,8 +173,9 @@ def options(self, rqmt): out.append("--time=%s" % task_time) out.append("--export=all") - if rqmt.get("multi_node_slots", None): + if rqmt.get("multi_node_slots", 1) > 1: out.append("--ntasks=%s" % rqmt["multi_node_slots"]) + out.append("--nodes=%s" % rqmt["multi_node_slots"]) sbatch_args = rqmt.get("sbatch_args", []) if isinstance(sbatch_args, str): @@ -232,11 +233,15 @@ def submit_helper(self, call, logpath, rqmt, name, task_name, start_id, end_id, :param int step_size: """ name = self.process_task_name(name) - sbatch_call = ["sbatch", "-J", name, "-o", logpath + "/%x.%A.%a", "--mail-type=None"] + out_log_file = f"{logpath}/%x.%A.%a" + if rqmt.get("multi_node_slots", 1) > 1: + out_log_file += ".%t" + sbatch_call = ["sbatch", "-J", name, "--mail-type=None"] sbatch_call += self.options(rqmt) + sbatch_call += ["-o", f"{out_log_file}.batch"] + sbatch_call += ["-a", f"{start_id}-{end_id}:{step_size}"] + sbatch_call += [f"--wrap=srun -o {out_log_file} {' '.join(call)}"] - sbatch_call += ["-a", "%i-%i:%i" % (start_id, end_id, step_size)] - sbatch_call += ["--wrap=%s" % " ".join(call)] while True: try: out, err, retval = self.system_call(sbatch_call) @@ -393,20 +398,38 @@ def get_default_rqmt(self, task): def init_worker(self, task): # setup log file by linking to engine logfile - task_id = self.get_task_id(None) - logpath = os.path.relpath(task.path(gs.JOB_LOG, task_id)) + + # Naming ambiguity: sis "tasks" are what SLURM calls array jobs. + # + # SLURM tasks represent jobs that span multiple nodes at the same time + # (e.g. multi-node multi-GPU trainings consist of one SLURM task per node). + slurm_num_tasks = int( + next(filter(None, (os.getenv(var, None) for var in ["SLURM_NTASKS", "SLURM_NPROCS"])), "1") + ) + slurm_task_id = int(os.getenv("SLURM_PROCID", "0")) + array_task_id = self.get_task_id(None) + + # keep backwards compatibility: only change output file name for multi-SLURM-task jobs + log_suffix = str(array_task_id) + (f".{slurm_task_id}" if slurm_num_tasks > 1 else "") + logpath = os.path.relpath(task.path(gs.JOB_LOG, log_suffix)) if os.path.isfile(logpath): os.unlink(logpath) + job_id = next( + filter(None, (os.getenv(name, None) for name in ["SLURM_JOB_ID", "SLURM_JOBID", "SLURM_ARRAY_JOB_ID"])), "0" + ) engine_logpath = ( os.path.dirname(logpath) + "/engine/" + os.getenv("SLURM_JOB_NAME") + "." - + os.getenv("SLURM_ARRAY_JOB_ID") + + job_id + "." - + os.getenv("SLURM_ARRAY_TASK_ID") + + os.getenv("SLURM_ARRAY_TASK_ID", "1") ) + if slurm_num_tasks > 1: + engine_logpath += f".{slurm_task_id}" + try: if os.path.isfile(engine_logpath): os.link(engine_logpath, logpath)