Skip to content

Commit

Permalink
Merge pull request #833 from linsword13/render-dict
Browse files Browse the repository at this point in the history
Add in a hook for supplying variables for template render
  • Loading branch information
douglasjacobsen authored Jan 17, 2025
2 parents ed45eb8 + eba4dae commit 73cbec1
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
9 changes: 6 additions & 3 deletions lib/ramble/ramble/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,12 +2324,15 @@ def _get_template_config(obj, tpl_config, obj_type):

def _render_object_templates(self, extra_vars, workspace):
for obj, tpl_config in self._object_templates(workspace):
extra_vars = extra_vars.copy()
if callable(getattr(obj, "template_render_vars", None)):
extra_vars.update(obj.template_render_vars())
src_path = tpl_config["src_path"]
with open(src_path) as f_in:
content = f_in.read()
extra_vars_wm = tpl_config.get("extra_vars")
if extra_vars_wm is not None:
extra_vars.update(extra_vars_wm)
extra_vars_dict = tpl_config.get("extra_vars")
if extra_vars_dict is not None:
extra_vars.update(extra_vars_dict)
extra_vars_func_name = tpl_config.get("extra_vars_func_name")
if extra_vars_func_name is not None:
extra_vars_func = getattr(obj, extra_vars_func_name)
Expand Down
4 changes: 4 additions & 0 deletions lib/ramble/ramble/workflow_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def conditional_expand(self, templates):
continue
return expanded

def template_render_vars(self):
"""Define variables to be used in template rendering"""
return {"workflow_pragmas": "", "workflow_hostfile_cmd": ""}

def copy(self):
"""Deep copy a workflow manager instance"""
new_copy = type(self)(self._file_path)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/bin/bash
{sbatch_headers_str}
{workflow_pragmas}

cd {experiment_run_dir}

scontrol show hostnames > {experiment_run_dir}/hostfile
{workflow_hostfile_cmd}

{command}
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ def __init__(self, file_path):
name="slurm_execute_experiment",
src_name="slurm_execute_experiment.tpl",
dest_path="slurm_execute_experiment",
extra_vars_func="execute_vars",
)

def _execute_vars(self):
def template_render_vars(self):
vars = super().template_render_vars()
expander = self.app_inst.expander
# Adding pre-defined and custom headers
pragmas = [
Expand All @@ -135,7 +135,11 @@ def _execute_vars(self):
)
pragmas = pragmas + extra_headers
header_str = "\n".join(self.conditional_expand(pragmas))
return {"sbatch_headers_str": header_str}
return {
**vars,
"workflow_pragmas": header_str,
"workflow_hostfile_cmd": self.runner.get_hostfile_cmd(),
}

def _check_partition(self, partition):
"""Warns about potential issues of the slurm_partition config
Expand Down Expand Up @@ -246,3 +250,6 @@ def get_partitions(self):
"default_partition": default_partition,
"partitions": partitions,
}

def get_hostfile_cmd(self):
return "scontrol show hostnames > {experiment_run_dir}/hostfile"

0 comments on commit 73cbec1

Please sign in to comment.