Skip to content

Commit

Permalink
Adapt SLURM engine for multi-node jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
NeoLegends committed Oct 24, 2024
1 parent c7de85e commit b575f3a
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions sisyphus/simple_linux_utility_for_resource_management_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def options(self, rqmt):
out.append("--time=%s" % task_time)
out.append("--export=all")

if rqmt.get("multi_node_slots", None):
if rqmt.get("multi_node_slots", 1) > 1:
out.append("--ntasks=%s" % rqmt["multi_node_slots"])
out.append("--nodes=%s" % rqmt["multi_node_slots"])

sbatch_args = rqmt.get("sbatch_args", [])
if isinstance(sbatch_args, str):
Expand Down Expand Up @@ -232,11 +233,18 @@ def submit_helper(self, call, logpath, rqmt, name, task_name, start_id, end_id,
:param int step_size:
"""
name = self.process_task_name(name)
sbatch_call = ["sbatch", "-J", name, "-o", logpath + "/%x.%A.%a", "--mail-type=None"]
sbatch_call = ["sbatch", "-J", name]
out_log_file = logpath + "/%x.%A.%t.%a"
if rqmt.get("multi_node_slots", 1) <= 1:
sbatch_call += ["-o", out_log_file]
sbatch_call += ["--mail-type=None"]
sbatch_call += self.options(rqmt)

sbatch_call += ["-a", "%i-%i:%i" % (start_id, end_id, step_size)]
sbatch_call += ["--wrap=%s" % " ".join(call)]
if rqmt.get("multi_node_slots", 1) > 1:
sbatch_call += [f"--wrap=srun -o {out_log_file} {' '.join(call)}"]
else:
sbatch_call += [f"--wrap={' '.join(call)}"]

while True:
try:
out, err, retval = self.system_call(sbatch_call)
Expand Down Expand Up @@ -393,19 +401,27 @@ def get_default_rqmt(self, task):

def init_worker(self, task):
# setup log file by linking to engine logfile
task_id = self.get_task_id(None)
logpath = os.path.relpath(task.path(gs.JOB_LOG, task_id))

# for now assume we only have one log file per SLURM task
if os.getenv("SLURM_PROCID", "0") != "0":
return

array_task_id = self.get_task_id(None)
logpath = os.path.relpath(task.path(gs.JOB_LOG, array_task_id))
if os.path.isfile(logpath):
os.unlink(logpath)

job_ids = (os.getenv(name, None) for name in ["SLURM_JOB_ID", "SLURM_JOBID", "SLURM_ARRAY_JOB_ID"])
engine_logpath = (
os.path.dirname(logpath)
+ "/engine/"
+ os.getenv("SLURM_JOB_NAME")
+ "."
+ os.getenv("SLURM_ARRAY_JOB_ID")
+ next(filter(None, job_ids), "0")
+ "."
+ os.getenv("SLURM_PROCID", "0")
+ "."
+ os.getenv("SLURM_ARRAY_TASK_ID")
+ os.getenv("SLURM_ARRAY_TASK_ID", "1")
)
try:
if os.path.isfile(engine_logpath):
Expand Down

0 comments on commit b575f3a

Please sign in to comment.