Skip to content

Commit

Permalink
Upgrade to aiida-workgraph v0.4.10
Browse files Browse the repository at this point in the history
To include several bugfixes we upgrade to version v0.4.10. This is
required to specify the computer in the metadata of the workgraph task, as
this is not possible with the currently pinned version. The workgraph
API has changed, so we adapt the functions to the new API, which makes
part of the function calls cleaner. The workaround is still needed but
has been better documented.
  • Loading branch information
agoscinski committed Jan 9, 2025
1 parent 8690f48 commit 998d64a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 52 deletions.
105 changes: 54 additions & 51 deletions src/sirocco/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,51 @@
from sirocco.core import graph_items


# This is a hack for aiida-workgraph; merging this into aiida-workgraph properly would require
# a major refactor, see issue https://github.com/aiidateam/aiida-workgraph/issues/168
# It might be better to give up on the graph-like construction and just create the task
# directly with inputs, arguments and outputs.
def _prepare_for_shell_task(task: dict, kwargs: dict) -> dict:
"""Prepare the inputs for ShellTask"""
from aiida.common import lang
from aiida.orm import AbstractCode
from aiida_shell.launch import convert_nodes_single_file_data, prepare_code

command = kwargs.pop("command", None)
resolve_command = kwargs.pop("resolve_command", False)
metadata = kwargs.pop("metadata", {})
# setup code
if isinstance(command, str):
computer = (metadata or {}).get("options", {}).pop("computer", None)
code = prepare_code(command, computer, resolve_command)
else:
lang.type_check(command, AbstractCode)
code = command
# update the tasks with links
nodes = convert_nodes_single_file_data(kwargs.pop("nodes", {}))
# find all keys in kwargs start with "nodes."
for key in list(kwargs.keys()):
if key.startswith("nodes."):
nodes[key[6:]] = kwargs.pop(key)
metadata.update({"call_link_label": task["name"]})

# This is a workaround required when splitting the initialization of the task and its linked nodes. Merging this into
# aiida-workgraph properly would require significant changes; see
# https://github.com/aiidateam/aiida-workgraph/issues/168 for details. The function is a copy of the original
# function in aiida-workgraph. The modifications are marked by comments.
def _prepare_for_shell_task(task: dict, inputs: dict) -> dict:
"""Prepare the inputs for ShellJob"""
import inspect

from aiida_shell.launch import prepare_shell_job_inputs

# Retrieve the signature of `prepare_shell_job_inputs` to determine expected input parameters.
signature = inspect.signature(prepare_shell_job_inputs)
aiida_shell_input_keys = signature.parameters.keys()

# Iterate over all WorkGraph `inputs`, and extract the ones which are expected by `prepare_shell_job_inputs`
inputs_aiida_shell_subset = {key: inputs[key] for key in inputs if key in aiida_shell_input_keys}

try:
aiida_shell_inputs = prepare_shell_job_inputs(**inputs_aiida_shell_subset)
except ValueError: # noqa: TRY302
raise

# We need to remove the original input-keys, as they might be offending for the call to `launch_shell_job`
# E.g., `inputs` originally can contain `command`, which gets, however, transformed to #
# `code` by `prepare_shell_job_inputs`
for key in inputs_aiida_shell_subset:
inputs.pop(key)

# Finally, we update the original `inputs` with the modified ones from the call to `prepare_shell_job_inputs`
inputs = {**inputs, **aiida_shell_inputs}

inputs.setdefault("metadata", {})
inputs["metadata"].update({"call_link_label": task["name"]})

# Workaround starts here
# This part is part of the workaround. We need to manually add the outputs from the task.
# Because kwargs are not populated with outputs
default_outputs = {"remote_folder", "remote_stash", "retrieved", "_outputs", "_wait", "stdout", "stderr"}
task_outputs = {task["outputs"][i]["name"] for i in range(len(task["outputs"]))}
task_outputs = task_outputs.union(set(kwargs.pop("outputs", [])))
task_outputs = set(task["outputs"].keys())
task_outputs = task_outputs.union(set(inputs.pop("outputs", [])))
missing_outputs = task_outputs.difference(default_outputs)
return {
"code": code,
"nodes": nodes,
"filenames": kwargs.pop("filenames", {}),
"arguments": kwargs.pop("arguments", []),
"outputs": list(missing_outputs),
"parser": kwargs.pop("parser", None),
"metadata": metadata or {},
}
inputs["outputs"] = list(missing_outputs)
# Workaround ends here

return inputs


aiida_workgraph.engine.utils.prepare_for_shell_task = _prepare_for_shell_task
Expand Down Expand Up @@ -187,11 +191,12 @@ def _create_task_node(self, task: graph_items.Task):
# NOTE: We don't pass the `nodes` dictionary here, as then we would need to have the sockets available when
# we create the task. Instead, they are being updated via the WG internals when linking inputs/outputs to
# tasks
workgraph_task = self._workgraph.tasks.new(
workgraph_task = self._workgraph.add_task(
"ShellJob",
name=label,
command=command,
arguments=[],
outputs=[],
metadata={"options": {"prepend_text": prepend_text}},
)

Expand All @@ -218,20 +223,17 @@ def _link_input_nodes_to_task(self, task: graph_items.Task, input_: graph_items.
task_label = AiidaWorkGraph.get_aiida_label_from_graph_item(task)
input_label = AiidaWorkGraph.get_aiida_label_from_graph_item(input_)
workgraph_task = self._aiida_task_nodes[task_label]
workgraph_task.inputs.new("Any", f"nodes.{input_label}")
workgraph_task.kwargs.append(f"nodes.{input_label}")
workgraph_task.add_input("workgraph.any", f"nodes.{input_label}")

# resolve data
if (data_node := self._aiida_data_nodes.get(input_label)) is not None:
if (nodes := workgraph_task.inputs.get("nodes")) is None:
msg = (
f"Workgraph task {workgraph_task.name!r} did not initialize input nodes in the workgraph "
f"before linking. This is a bug in the code, please contact the developers by making an issue."
)
if not hasattr(workgraph_task.inputs.nodes, f"{input_label}"):
msg = f"Socket {input_label!r} was not found in workgraph. Please contact a developer."
raise ValueError(msg)
nodes.value.update({f"{input_label}": data_node})
socket = getattr(workgraph_task.inputs.nodes, f"{input_label}")
socket.value = data_node
elif (output_socket := self._aiida_socket_nodes.get(input_label)) is not None:
self._workgraph.links.new(output_socket, workgraph_task.inputs[f"nodes.{input_label}"])
self._workgraph.add_link(output_socket, workgraph_task.inputs[f"nodes.{input_label}"])
else:
msg = (
f"Input data node {input_label!r} was neither found in socket nodes nor in data nodes. The task "
Expand All @@ -247,7 +249,7 @@ def _link_arguments_to_task(self, task: graph_items.Task):
"""
task_label = AiidaWorkGraph.get_aiida_label_from_graph_item(task)
workgraph_task = self._aiida_task_nodes[task_label]
if (workgraph_task_arguments := workgraph_task.inputs.get("arguments")) is None:
if (workgraph_task_arguments := workgraph_task.inputs.arguments) is None:
msg = (
f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph "
f"before linking. This is a bug in the code, please contact developers."
Expand Down Expand Up @@ -280,9 +282,10 @@ def _link_arguments_to_task(self, task: graph_items.Task):

def _link_output_nodes_to_task(self, task: graph_items.Task, output: graph_items.Data):
"""Links the output to the workgraph task."""

workgraph_task = self._aiida_task_nodes[AiidaWorkGraph.get_aiida_label_from_graph_item(task)]
output_label = AiidaWorkGraph.get_aiida_label_from_graph_item(output)
output_socket = workgraph_task.outputs.new("Any", output.src)
output_socket = workgraph_task.add_output("workgraph.any", output.src)
self._aiida_socket_nodes[output_label] = output_socket

def run(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_wc_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ def test_run_workgraph(config_path):
core_workflow = Workflow.from_yaml(config_path)
aiida_workflow = AiidaWorkGraph(core_workflow)
out = aiida_workflow.run()
assert out.get("execution_count", None).value == 0 # TODO: should be 1 but we need to update workgraph for this
assert out.get("execution_count", None).value == 1

0 comments on commit 998d64a

Please sign in to comment.