diff --git a/src/sirocco/workgraph.py b/src/sirocco/workgraph.py index af748137..bd374b5e 100644 --- a/src/sirocco/workgraph.py +++ b/src/sirocco/workgraph.py @@ -18,47 +18,51 @@ from sirocco.core import graph_items -# This is hack to aiida-workgraph, merging this into aiida-workgraph properly would require -# some major refactor see issue https://github.com/aiidateam/aiida-workgraph/issues/168 -# It might be better to give up on the graph like construction and just create the task -# directly with inputs, arguments and outputs -def _prepare_for_shell_task(task: dict, kwargs: dict) -> dict: - """Prepare the inputs for ShellTask""" - from aiida.common import lang - from aiida.orm import AbstractCode - from aiida_shell.launch import convert_nodes_single_file_data, prepare_code - - command = kwargs.pop("command", None) - resolve_command = kwargs.pop("resolve_command", False) - metadata = kwargs.pop("metadata", {}) - # setup code - if isinstance(command, str): - computer = (metadata or {}).get("options", {}).pop("computer", None) - code = prepare_code(command, computer, resolve_command) - else: - lang.type_check(command, AbstractCode) - code = command - # update the tasks with links - nodes = convert_nodes_single_file_data(kwargs.pop("nodes", {})) - # find all keys in kwargs start with "nodes." - for key in list(kwargs.keys()): - if key.startswith("nodes."): - nodes[key[6:]] = kwargs.pop(key) - metadata.update({"call_link_label": task["name"]}) - +# This is a workaround required when splitting the initialization of the task and its linked nodes Merging this into +# aiida-workgraph properly would require significant changes see issues +# https://github.com/aiidateam/aiida-workgraph/issues/168 The function is a copy of the original function in +# aiida-workgraph. The modifications are marked by comments. 
def _prepare_for_shell_task(task: dict, inputs: dict) -> dict:
    """Prepare the inputs for a ``ShellJob`` task.

    The subset of ``inputs`` whose keys match parameters of
    ``aiida_shell.launch.prepare_shell_job_inputs`` is passed through that
    function, and its result replaces those original entries (e.g. ``command``
    is expected to be turned into ``code``). All other entries of ``inputs``
    are forwarded untouched.

    :param task: Task description; ``task["name"]`` is used as the call link
        label and the keys of ``task["outputs"]`` as the task's output names.
    :param inputs: Keyword inputs collected by the WorkGraph for this task.
        This mapping is not modified.
    :return: A new dictionary with the prepared inputs, including
        ``metadata["call_link_label"]`` and the list of non-default ``outputs``.
    :raises ValueError: Propagated unchanged from ``prepare_shell_job_inputs``.
    """
    import inspect

    from aiida_shell.launch import prepare_shell_job_inputs

    # The signature of `prepare_shell_job_inputs` defines which inputs it accepts.
    accepted_keys = inspect.signature(prepare_shell_job_inputs).parameters

    # Split the inputs without mutating the caller's dictionary: the accepted ones are handed to
    # `prepare_shell_job_inputs`, the rest is passed through as is. The original accepted keys must not survive,
    # as they might be offending for the subsequent launch of the shell job.
    shell_kwargs = {key: value for key, value in inputs.items() if key in accepted_keys}
    passthrough = {key: value for key, value in inputs.items() if key not in accepted_keys}

    prepared = {**passthrough, **prepare_shell_job_inputs(**shell_kwargs)}

    # Tolerate both a missing and an explicit `None` metadata entry.
    metadata = prepared.get("metadata")
    if metadata is None:
        metadata = prepared["metadata"] = {}
    metadata.update({"call_link_label": task["name"]})

    # Workaround starts here
    # The outputs of the task have to be added manually, because the inputs are not populated with them.
    default_outputs = {"remote_folder", "remote_stash", "retrieved", "_outputs", "_wait", "stdout", "stderr"}
    requested_outputs = set(task["outputs"].keys()) | set(prepared.pop("outputs", None) or [])
    # Sorted so that the resulting list does not depend on set iteration order.
    prepared["outputs"] = sorted(requested_outputs - default_outputs)
    # Workaround ends here

    return prepared


# Replace the implementation used by aiida-workgraph with the workaround defined above.
aiida_workgraph.engine.utils.prepare_for_shell_task = _prepare_for_shell_task
task_label = AiidaWorkGraph.get_aiida_label_from_graph_item(task) input_label = AiidaWorkGraph.get_aiida_label_from_graph_item(input_) workgraph_task = self._aiida_task_nodes[task_label] - workgraph_task.inputs.new("Any", f"nodes.{input_label}") - workgraph_task.kwargs.append(f"nodes.{input_label}") + workgraph_task.add_input("workgraph.any", f"nodes.{input_label}") # resolve data if (data_node := self._aiida_data_nodes.get(input_label)) is not None: - if (nodes := workgraph_task.inputs.get("nodes")) is None: - msg = ( - f"Workgraph task {workgraph_task.name!r} did not initialize input nodes in the workgraph " - f"before linking. This is a bug in the code, please contact the developers by making an issue." - ) + if not hasattr(workgraph_task.inputs.nodes, f"{input_label}"): + msg = f"Socket {input_label!r} was not found in workgraph. Please contact a developer." raise ValueError(msg) - nodes.value.update({f"{input_label}": data_node}) + socket = getattr(workgraph_task.inputs.nodes, f"{input_label}") + socket.value = data_node elif (output_socket := self._aiida_socket_nodes.get(input_label)) is not None: - self._workgraph.links.new(output_socket, workgraph_task.inputs[f"nodes.{input_label}"]) + self._workgraph.add_link(output_socket, workgraph_task.inputs[f"nodes.{input_label}"]) else: msg = ( f"Input data node {input_label!r} was neither found in socket nodes nor in data nodes. The task " @@ -247,7 +249,7 @@ def _link_arguments_to_task(self, task: graph_items.Task): """ task_label = AiidaWorkGraph.get_aiida_label_from_graph_item(task) workgraph_task = self._aiida_task_nodes[task_label] - if (workgraph_task_arguments := workgraph_task.inputs.get("arguments")) is None: + if (workgraph_task_arguments := workgraph_task.inputs.arguments) is None: msg = ( f"Workgraph task {workgraph_task.name!r} did not initialize arguments nodes in the workgraph " f"before linking. This is a bug in the code, please contact developers." 
def _link_output_nodes_to_task(self, task: graph_items.Task, output: graph_items.Data):
    """Add ``output`` as an output socket of the workgraph task belonging to ``task``.

    The created socket is stored in ``self._aiida_socket_nodes`` under the AiiDA label of ``output``.
    """
    task_label = AiidaWorkGraph.get_aiida_label_from_graph_item(task)
    owner_task = self._aiida_task_nodes[task_label]
    socket_key = AiidaWorkGraph.get_aiida_label_from_graph_item(output)
    # The socket is named after the output's `src`, but registered under the output's AiiDA label.
    self._aiida_socket_nodes[socket_key] = owner_task.add_output("workgraph.any", output.src)