From d530dc7a90bc3047a9e6899f10742a15aa6ef9ca Mon Sep 17 00:00:00 2001 From: Lin Guo Date: Mon, 20 Jan 2025 23:07:33 -0800 Subject: [PATCH] Take in user-defined `batch_submit` This allows user to define their own `batch_submit`, and the final script will tack on the part of storing the job id. To preserve the user-defined variable, it stores a `var` into `_old_var`, before overriding its value as part of the `register_template`. --- lib/ramble/ramble/application.py | 4 ++++ .../slurm_workflow_manager.py | 19 ++++++++++++------- .../workflow_managers/slurm/batch_submit.tpl | 2 +- .../slurm/workflow_manager.py | 17 +++++++++++++++++ 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/lib/ramble/ramble/application.py b/lib/ramble/ramble/application.py index 050156f42..0d259e15f 100644 --- a/lib/ramble/ramble/application.py +++ b/lib/ramble/ramble/application.py @@ -2389,6 +2389,10 @@ def _define_object_template_vars(self, workspace): for obj, tpl_config in self._object_templates(workspace): var_name = tpl_config["var_name"] if var_name is not None: + if var_name in self.variables: + old_var = f"_old_{var_name}" + self.variables[old_var] = self.variables[var_name] + self.keywords.update_keys({old_var: var_attr}) self.variables[var_name] = tpl_config["dest_path"] self.keywords.update_keys({var_name: var_attr}) if callable(getattr(obj, "template_render_vars", None)): diff --git a/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py b/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py index d5840cf06..e9facc10d 100644 --- a/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py +++ b/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py @@ -29,8 +29,7 @@ def test_slurm_workflow(): variants: workflow_manager: '{wm_name}' variables: - # This batch_submit is overridden with slurm workflow manager - batch_submit: echo {wm_name} + batch_submit: sbatch {execute_experiment} mpi_command: mpirun -n {n_ranks} -hostfile hostfile processes_per_node: 1 n_nodes: 1 @@ -60,13 +59,16 @@ def test_slurm_workflow(): ws._re_read() workspace("setup", "--dry-run", global_args=["-D", ws.root]) - # assert the batch_submit is overridden, pointing to the generated script + # Assert on the all_experiments script all_exec_file = os.path.join(ws.root, "all_experiments") with open(all_exec_file) as f: content = f.read() - assert "echo None" in content - assert "echo slurm" not in content - assert os.path.join("hostname", "local", "test_slurm", "batch_submit") in content + batch_submit_path = os.path.join( + ws.experiment_dir, "hostname", "local", "test_slurm", "batch_submit" + ) + assert batch_submit_path in content + # The sbatch is embedded in the batch_submit_path script instead + assert f"sbatch {batch_submit_path}" not in content # Assert on no workflow manager path = os.path.join(ws.experiment_dir, "hostname", "local", "test_None") @@ -86,8 +88,11 @@ def test_slurm_workflow(): assert "batch_wait" in files with open(os.path.join(path, "batch_submit")) as f: content = f.read() - assert "slurm_experiment_sbatch" in content + # Assert the user-defined `batch_submit` is included + assert "slurm_experiment_sbatch" not in content + assert "execute_experiment" in content assert ".slurm_job" in content + assert "sbatch" in content with open(os.path.join(path, "slurm_experiment_sbatch")) as f: content = f.read() assert "scontrol show hostnames" in content diff --git a/var/ramble/repos/builtin/workflow_managers/slurm/batch_submit.tpl b/var/ramble/repos/builtin/workflow_managers/slurm/batch_submit.tpl index 3bd9598a9..695d3d429 100644 --- a/var/ramble/repos/builtin/workflow_managers/slurm/batch_submit.tpl +++ b/var/ramble/repos/builtin/workflow_managers/slurm/batch_submit.tpl @@ -1,2 +1,2 @@ #!/bin/bash -sbatch {slurm_experiment_sbatch} | tee >(awk '{print $NF}' > {experiment_run_dir}/.slurm_job) +{batch_submit_cmd} | tee >(awk '{print $NF}' > {experiment_run_dir}/.slurm_job) diff --git a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py index 0350436a2..a83c733e8 100644 --- a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py +++ b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py @@ -95,8 +95,25 @@ def __init__(self, file_path): name="batch_submit", src_path="batch_submit.tpl", dest_path="batch_submit", + extra_vars_func="batch_submit_vars", ) + def _batch_submit_vars(self): + vars = self.app_inst.variables + old_var_name = "_old_batch_submit" + if old_var_name in vars: + batch_submit_cmd = vars[old_var_name] + if "sbatch" not in batch_submit_cmd: + logger.warn( + "`sbatch` is missing in the given `batch_submit` command" + ) + else: + batch_submit_script = vars["batch_submit"] + batch_submit_cmd = f"sbatch {batch_submit_script}" + return { + "batch_submit_cmd": batch_submit_cmd, + } + register_template( name="batch_query", src_path="batch_query.tpl",