Skip to content

Commit

Permalink
Adapt SLURM engine for multi-node jobs (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
NeoLegends authored Nov 4, 2024
1 parent 3181d72 commit d256f43
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
2 changes: 1 addition & 1 deletion sisyphus/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _sis_path(self, path_type=None, task_id=None, abspath=False):

# Add task id as suffix
if task_id is not None:
path += ".%i" % task_id
path += f".{task_id}"

if abspath and not os.path.isabs(path):
path = os.path.join(gs.BASE_DIR, path)
Expand Down
39 changes: 31 additions & 8 deletions sisyphus/simple_linux_utility_for_resource_management_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def options(self, rqmt):
out.append("--time=%s" % task_time)
out.append("--export=all")

if rqmt.get("multi_node_slots", None):
if rqmt.get("multi_node_slots", 1) > 1:
out.append("--ntasks=%s" % rqmt["multi_node_slots"])
out.append("--nodes=%s" % rqmt["multi_node_slots"])

sbatch_args = rqmt.get("sbatch_args", [])
if isinstance(sbatch_args, str):
Expand Down Expand Up @@ -232,11 +233,15 @@ def submit_helper(self, call, logpath, rqmt, name, task_name, start_id, end_id,
:param int step_size:
"""
name = self.process_task_name(name)
sbatch_call = ["sbatch", "-J", name, "-o", logpath + "/%x.%A.%a", "--mail-type=None"]
out_log_file = f"{logpath}/%x.%A.%a"
if rqmt.get("multi_node_slots", 1) > 1:
out_log_file += ".%t"
sbatch_call = ["sbatch", "-J", name, "--mail-type=None"]
sbatch_call += self.options(rqmt)
sbatch_call += ["-o", f"{out_log_file}.batch"]
sbatch_call += ["-a", f"{start_id}-{end_id}:{step_size}"]
sbatch_call += [f"--wrap=srun -o {out_log_file} {' '.join(call)}"]

sbatch_call += ["-a", "%i-%i:%i" % (start_id, end_id, step_size)]
sbatch_call += ["--wrap=%s" % " ".join(call)]
while True:
try:
out, err, retval = self.system_call(sbatch_call)
Expand Down Expand Up @@ -393,20 +398,38 @@ def get_default_rqmt(self, task):

def init_worker(self, task):
# setup log file by linking to engine logfile
task_id = self.get_task_id(None)
logpath = os.path.relpath(task.path(gs.JOB_LOG, task_id))

# Naming ambiguity: what sisyphus calls "tasks" are what SLURM calls array jobs.
#
# SLURM tasks, in contrast, are the parallel processes that belong to one job and
# may be spread across multiple nodes at the same time
# (e.g. a multi-node multi-GPU training consists of one SLURM task per node).
slurm_num_tasks = int(
next(filter(None, (os.getenv(var, None) for var in ["SLURM_NTASKS", "SLURM_NPROCS"])), "1")
)
slurm_task_id = int(os.getenv("SLURM_PROCID", "0"))
array_task_id = self.get_task_id(None)

# keep backwards compatibility: only change output file name for multi-SLURM-task jobs
log_suffix = str(array_task_id) + (f".{slurm_task_id}" if slurm_num_tasks > 1 else "")
logpath = os.path.relpath(task.path(gs.JOB_LOG, log_suffix))
if os.path.isfile(logpath):
os.unlink(logpath)

job_id = next(
filter(None, (os.getenv(name, None) for name in ["SLURM_JOB_ID", "SLURM_JOBID", "SLURM_ARRAY_JOB_ID"])), "0"
)
engine_logpath = (
os.path.dirname(logpath)
+ "/engine/"
+ os.getenv("SLURM_JOB_NAME")
+ "."
+ os.getenv("SLURM_ARRAY_JOB_ID")
+ job_id
+ "."
+ os.getenv("SLURM_ARRAY_TASK_ID")
+ os.getenv("SLURM_ARRAY_TASK_ID", "1")
)
if slurm_num_tasks > 1:
engine_logpath += f".{slurm_task_id}"

try:
if os.path.isfile(engine_logpath):
os.link(engine_logpath, logpath)
Expand Down

0 comments on commit d256f43

Please sign in to comment.