Skip to content

Commit

Permalink
Moves logic from CanonicalWorkflow to core.Workflow and `ConfigWo…
Browse files Browse the repository at this point in the history
…rkflow`

The validation of `cycles` and `tasks` via `list_not_empty` is moved to
`ConfigWorkflow`. This in turn required modifying the tests that
initialize a `ConfigWorkflow` with empty `cycles` and `tasks`, as this is no
longer allowed. We switched to `BeforeValidator`,
as the check can happen before pydantic validation.

The `data_dict` and `task_dict` member variables are moved to the
constructor of `core.Workflow`. The constructor of `core.Workflow`
is split into two: the `from_config_workflow` constructor, which replicates
the behavior of the previous constructor, and the default constructor, which
now accepts as individual parameters all the required values that the previous
constructor extracted from the passed `ConfigWorkflow`.

The introduction of the `from_config_file` constructor in `ConfigWorkflow`, which
replicates the behavior of `load_workflow_config`, allows the `rootdir` to
become a part of `ConfigWorkflow`. This approach eliminates the need for an
external utility function for workflow creation from a file. Since the utility
function was essentially acting as a constructor for `ConfigWorkflow`, it
simplifies the interface by enabling direct access to the `rootdir` within
`core.Workflow`. In addition, we make `rootdir` and `name` non-optional, as these
parameters can be determined in the new `from_config_file` and passed to the
default constructor. Furthermore, we introduce the utility function
`validate_yaml_content`, which allows creating an instance of any `Config*`
class from a YAML string in a more generic way; it is used especially in the
docstring tests.
  • Loading branch information
agoscinski committed Jan 31, 2025
1 parent c57a890 commit f913fa4
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 137 deletions.
65 changes: 50 additions & 15 deletions src/sirocco/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,79 @@
from __future__ import annotations

from itertools import product
from itertools import chain, product
from typing import TYPE_CHECKING, Self

from sirocco.core.graph_items import Cycle, Data, Store, Task
from sirocco.parsing._yaml_data_models import (
CanonicalWorkflow,
load_workflow_config,
ConfigWorkflow,
)

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path

from sirocco.parsing._yaml_data_models import ConfigCycle
from sirocco.parsing._yaml_data_models import (
ConfigAvailableData,
ConfigCycle,
ConfigData,
ConfigGeneratedData,
ConfigTask,
)


class Workflow:
"""Internal representation of a workflow"""

def __init__(self, workflow_config: CanonicalWorkflow) -> None:
self.name: str = workflow_config.name
self.config_rootdir: Path = workflow_config.rootdir
def __init__(
self,
name: str,
config_rootdir: Path,
cycles: list[ConfigCycle],
tasks: list[ConfigTask],
data: ConfigData,
parameters: dict[str, list],
) -> None:
self.name: str = name
self.config_rootdir: Path = config_rootdir

self.tasks: Store = Store()
self.data: Store = Store()
self.cycles: Store = Store()

data_dict: dict[str, ConfigAvailableData | ConfigGeneratedData] = {
data.name: data for data in chain(data.available, data.generated)
}
task_dict: dict[str, ConfigTask] = {task.name: task for task in tasks}

# Function to iterate over date and parameter combinations
def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator[dict]:
space = ({} if date is None else {"date": [date]}) | {k: workflow_config.parameters[k] for k in param_refs}
space = ({} if date is None else {"date": [date]}) | {k: parameters[k] for k in param_refs}
yield from (dict(zip(space.keys(), x, strict=False)) for x in product(*space.values()))

# 1 - create availalbe data nodes
for data_config in workflow_config.data.available:
for data_config in data.available:
for coordinates in iter_coordinates(param_refs=data_config.parameters, date=None):
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))

# 2 - create output data nodes
for cycle_config in workflow_config.cycles:
for cycle_config in cycles:
for date in self.cycle_dates(cycle_config):
for task_ref in cycle_config.tasks:
for data_ref in task_ref.outputs:
data_name = data_ref.name
data_config = workflow_config.data_dict[data_name]
data_config = data_dict[data_name]
for coordinates in iter_coordinates(param_refs=data_config.parameters, date=date):
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))

# 3 - create cycles and tasks
for cycle_config in workflow_config.cycles:
for cycle_config in cycles:
cycle_name = cycle_config.name
for date in self.cycle_dates(cycle_config):
cycle_tasks = []
for task_graph_spec in cycle_config.tasks:
task_name = task_graph_spec.name
task_config = workflow_config.task_dict[task_name]
task_config = task_dict[task_name]

for coordinates in iter_coordinates(param_refs=task_config.parameters, date=date):
task = Task.from_config(
Expand Down Expand Up @@ -88,5 +107,21 @@ def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime]:
yield date

@classmethod
def from_yaml(cls: type[Self], config_path: str) -> Self:
return cls(load_workflow_config(config_path))
def from_config_file(cls: type[Self], config_path: str) -> Self:
"""
Loads a python representation of a workflow config file.
:param config_path: the string to the config yaml file containing the workflow definition
"""
return cls.from_config_workflow(ConfigWorkflow.from_config_file(config_path))

@classmethod
def from_config_workflow(cls: type[Self], config_workflow: ConfigWorkflow) -> Self:
    """Create a workflow from an existing ``ConfigWorkflow`` instance.

    The fields of ``config_workflow`` are unpacked and forwarded as
    individual keyword arguments to the default constructor.

    :param config_workflow: the parsed workflow configuration to convert
    :return: an instance of ``cls`` (annotated as ``Self`` so that
        subclasses calling this constructor are typed correctly)
    """
    return cls(
        name=config_workflow.name,
        config_rootdir=config_workflow.rootdir,
        cycles=config_workflow.cycles,
        tasks=config_workflow.tasks,
        data=config_workflow.data,
        parameters=config_workflow.parameters,
    )
4 changes: 2 additions & 2 deletions src/sirocco/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ._yaml_data_models import (
load_workflow_config,
ConfigWorkflow,
)

__all__ = [
"load_workflow_config",
"ConfigWorkflow",
]
126 changes: 56 additions & 70 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal
from typing import Annotated, Any, ClassVar, Literal, Self

from isoduration import parse_duration
from isoduration.types import Duration # pydantic needs type # noqa: TCH002
Expand All @@ -27,6 +27,15 @@

from sirocco.parsing._utils import TimeUtils

ITEM_T = typing.TypeVar("ITEM_T")


def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]:
    """Validate that ``value`` contains at least one element.

    Intended as a validator callable: the input is returned unchanged when it
    is non-empty so it can be chained.

    :param value: the list to check
    :return: the very same ``value`` object
    :raises ValueError: if ``value`` is empty (or otherwise falsy, e.g. ``None``)
    """
    # Truthiness check instead of ``len(value) < 1``: same result for lists,
    # and a ``None`` input yields the intended ValueError rather than a
    # TypeError from ``len``.
    if not value:
        msg = "At least one element is required."
        raise ValueError(msg)
    return value


class _NamedBaseModel(BaseModel):
"""
Expand Down Expand Up @@ -625,8 +634,10 @@ class ConfigWorkflow(BaseModel):
minimal yaml to generate:
>>> import textwrap
>>> config = textwrap.dedent(
>>> content = textwrap.dedent(
... '''
... name: minimal
... rootdir: /location/of/config/file
... cycles:
... - minimal_cycle:
... tasks:
Expand All @@ -637,25 +648,40 @@ class ConfigWorkflow(BaseModel):
... data:
... available:
... - foo:
... type: "file"
... src: "foo.txt"
... type: file
... src: foo.txt
... generated:
... - bar:
... type: "file"
... src: some_task_output
... type: dir
... src: bar
... '''
... )
>>> wf = validate_yaml_content(ConfigWorkflow, config)
>>> wf = validate_yaml_content(ConfigWorkflow, content)
minimum programmatically created instance
>>> empty_wf = ConfigWorkflow(cycles=[], tasks=[], data={})
>>> wf = ConfigWorkflow(
... name="minimal",
... rootdir=Path("/location/of/config/file"),
... cycles=[ConfigCycle(minimal_cycle={"tasks": [ConfigCycleTask(task_a={})]})],
... tasks=[ConfigShellTask(task_b={"plugin": "shell"})],
... data=ConfigData(
... available=[
... ConfigAvailableData(name="foo", type=DataType.FILE, src="foo.txt")
... ],
... generated=[
... ConfigGeneratedData(name="bar", type=DataType.DIR, src="bar")
... ],
... ),
... parameters={},
... )
"""

name: str | None = None
cycles: list[ConfigCycle]
tasks: list[ConfigTask]
rootdir: Path
name: str
cycles: Annotated[list[ConfigCycle], AfterValidator(list_not_empty)]
tasks: Annotated[list[ConfigTask], AfterValidator(list_not_empty)]
data: ConfigData
parameters: dict[str, list] = {}

Expand All @@ -682,67 +708,27 @@ def check_parameters(self) -> ConfigWorkflow:
raise ValueError(msg)
return self

@classmethod
def from_config_file(cls, config_path: str) -> Self:
"""Creates a ConfigWorkflow instance from a config file, a yaml with the workflow definition.
ITEM_T = typing.TypeVar("ITEM_T")


def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]:
if len(value) < 1:
msg = "At least one element is required."
raise ValueError(msg)
return value


class CanonicalWorkflow(BaseModel):
name: str
rootdir: Path
cycles: Annotated[list[ConfigCycle], AfterValidator(list_not_empty)]
tasks: Annotated[list[ConfigTask], AfterValidator(list_not_empty)]
data: ConfigData
parameters: dict[str, list[Any]]

@property
def data_dict(self) -> dict[str, ConfigAvailableData | ConfigGeneratedData]:
return {data.name: data for data in itertools.chain(self.data.available, self.data.generated)}

@property
def task_dict(self) -> dict[str, ConfigTask]:
return {task.name: task for task in self.tasks}


def canonicalize_workflow(config_workflow: ConfigWorkflow, rootdir: Path) -> CanonicalWorkflow:
if not config_workflow.name:
msg = "Workflow name required for canonicalization."
raise ValueError(msg)
return CanonicalWorkflow(
name=config_workflow.name,
rootdir=rootdir,
cycles=config_workflow.cycles,
tasks=config_workflow.tasks,
data=config_workflow.data,
parameters=config_workflow.parameters,
)


def load_workflow_config(workflow_config: str) -> CanonicalWorkflow:
"""
Loads a python representation of a workflow config file.
:param workflow_config: the string to the config yaml file containing the workflow definition
"""
config_path = Path(workflow_config)

content = config_path.read_text()

parsed_workflow = validate_yaml_content(ConfigWorkflow, content)

# If name was not specified, then we use filename without file extension
if parsed_workflow.name is None:
parsed_workflow.name = config_path.stem

rootdir = config_path.resolve().parent
Args:
config_path (str): The path of the config file to load from.
return canonicalize_workflow(config_workflow=parsed_workflow, rootdir=rootdir)
Returns:
OBJECT_T: An instance of the specified class type with data parsed and
validated from the YAML content.
"""
config_path_ = Path(config_path)
content = config_path_.read_text()
reader = YAML(typ="safe", pure=True)
object_ = reader.load(StringIO(content))
# If name was not specified, then we use filename without file extension
if "name" not in object_:
object_["name"] = config_path_.stem
object_["rootdir"] = config_path_.resolve().parent
adapter = TypeAdapter(cls)
return adapter.validate_python(object_)


OBJECT_T = typing.TypeVar("OBJECT_T")
Expand Down
4 changes: 2 additions & 2 deletions src/sirocco/vizgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,5 @@ def from_core_workflow(cls, workflow: Workflow):
return cls(workflow.name, workflow.cycles, workflow.data)

@classmethod
def from_yaml(cls, config_path: str):
return cls.from_core_workflow(Workflow.from_yaml(config_path))
def from_config_file(cls, config_path: str):
return cls.from_core_workflow(Workflow.from_config_file(config_path))
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +1,22 @@
import pathlib

import pytest

from sirocco.parsing import _yaml_data_models as models

pytest_plugins = ["aiida.tools.pytest_fixtures"]


@pytest.fixture(scope="session")
def minimal_config() -> models.ConfigWorkflow:
    """Session-scoped fixture returning a small ``ConfigWorkflow``.

    The workflow is named "minimal" and consists of one cycle referencing
    ``some_task``, one shell task ``some_task``, one available data item
    ``foo`` (a file) and one generated data item ``bar`` (a directory).
    """
    cycle_task = models.ConfigCycleTask(some_task={})
    cycle = models.ConfigCycle(minimal={"tasks": [cycle_task]})
    shell_task = models.ConfigShellTask(some_task={"plugin": "shell"})
    available_item = models.ConfigAvailableData(name="foo", type=models.DataType.FILE, src="foo.txt")
    generated_item = models.ConfigGeneratedData(name="bar", type=models.DataType.DIR, src="bar")
    data_section = models.ConfigData(available=[available_item], generated=[generated_item])
    return models.ConfigWorkflow(
        name="minimal",
        rootdir=pathlib.Path("minimal"),
        cycles=[cycle],
        tasks=[shell_task],
        data=data_section,
        parameters={},
    )
12 changes: 6 additions & 6 deletions tests/test_wc_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def config_paths(request):

def test_parse_config_file(config_paths, pprinter):
reference_str = config_paths["txt"].read_text()
test_str = pprinter.format(Workflow.from_yaml(config_paths["yml"]))
test_str = pprinter.format(Workflow.from_config_file(config_paths["yml"]))
if test_str != reference_str:
new_path = Path(config_paths["txt"]).with_suffix(".new.txt")
new_path.write_text(test_str)
Expand All @@ -61,11 +61,11 @@ def test_parse_config_file(config_paths, pprinter):

@pytest.mark.skip(reason="don't run it each time, uncomment to regenerate serilaized data")
def test_serialize_workflow(config_paths, pprinter):
config_paths["txt"].write_text(pprinter.format(Workflow.from_yaml(config_paths["yml"])))
config_paths["txt"].write_text(pprinter.format(Workflow.from_config_file(config_paths["yml"])))


def test_vizgraph(config_paths):
VizGraph.from_yaml(config_paths["yml"]).draw(file_path=config_paths["svg"])
VizGraph.from_config_file(config_paths["yml"]).draw(file_path=config_paths["svg"])


# configs that are tested for running workgraph
Expand All @@ -85,7 +85,7 @@ def test_run_workgraph(config_path, aiida_computer):
# some configs reference computer "localhost" which we need to create beforehand
aiida_computer("localhost").store()

core_workflow = Workflow.from_yaml(config_path)
core_workflow = Workflow.from_config_file(config_path)
aiida_workflow = AiidaWorkGraph(core_workflow)
out = aiida_workflow.run()
assert out.get("execution_count", None).value == 1
Expand All @@ -98,7 +98,7 @@ def test_run_workgraph(config_path, aiida_computer):
)
def test_nml_mod(config_paths, tmp_path):
nml_refdir = config_paths["txt"].parent / "ICON_namelists"
wf = Workflow.from_yaml(config_paths["yml"])
wf = Workflow.from_config_file(config_paths["yml"])
# Create core mamelists
for task in wf.tasks:
if isinstance(task, IconTask):
Expand All @@ -121,7 +121,7 @@ def test_nml_mod(config_paths, tmp_path):
)
def test_serialize_nml(config_paths):
nml_refdir = config_paths["txt"].parent / "ICON_namelists"
wf = Workflow.from_yaml(config_paths["yml"])
wf = Workflow.from_config_file(config_paths["yml"])
for task in wf.tasks:
if isinstance(task, IconTask):
task.create_workflow_namelists(folder=nml_refdir)
Loading

0 comments on commit f913fa4

Please sign in to comment.