Skip to content

Commit

Permalink
infra: Static typing on CI (#118)
Browse files Browse the repository at this point in the history
* fix all typing errors

* add mypy to CI

* set static typing env to always use python 3.12 for consistency
  • Loading branch information
DropD authored Feb 4, 2025
1 parent 2064daa commit 0dfaea5
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 116 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,21 @@ jobs:
verdi presto
- name: Run formatter and linter
run: hatch fmt --check

typechecking:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: 3.12
- name: Install hatch
run: |
pip install --upgrade pip
pip install hatch
- name: Install Graphviz
run: sudo apt-get install graphviz graphviz-dev
- name: Run formatter and linter
run: hatch run types:check
27 changes: 27 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,30 @@ serve = [
deploy = [
"mkdocs gh-deploy --no-history -f docs/mkdocs.yml"
]

[tool.hatch.envs.types]
python = "3.12"
extra-dependencies = [
"mypy>=1.0.0",
"pytest",
"lxml-stubs",
"types-setuptools",
"types-docutils",
"types-colorama",
"types-Pygments"
]

[tool.hatch.envs.types.scripts]
check = "mypy --no-incremental {args:.}"

[[tool.mypy.overrides]]
module = ["isoduration", "isoduration.*"]
follow_untyped_imports = true

[[tool.mypy.overrides]]
module = ["pygraphviz"]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["f90nml"]
ignore_missing_imports = true
3 changes: 2 additions & 1 deletion src/sirocco/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._tasks import IconTask, ShellTask
from .graph_items import Cycle, Data, GraphItem, Task
from .workflow import Workflow

__all__ = ["Workflow", "GraphItem", "Data", "Task", "Cycle"]
__all__ = ["Workflow", "GraphItem", "Data", "Task", "Cycle", "ShellTask", "IconTask"]
5 changes: 3 additions & 2 deletions src/sirocco/core/_tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import icon_task, shell_task
from .icon_task import IconTask
from .shell_task import ShellTask

__all__ = ["icon_task", "shell_task"]
__all__ = ["IconTask", "ShellTask"]
66 changes: 36 additions & 30 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass, field
from itertools import chain, product
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypeVar, cast

from sirocco.parsing._yaml_data_models import (
ConfigAvailableData,
Expand All @@ -15,9 +15,12 @@
from datetime import datetime
from pathlib import Path

from termcolor._types import Color

from sirocco.parsing._yaml_data_models import (
ConfigBaseData,
ConfigCycleTask,
ConfigCycleTaskWaitOn,
ConfigTask,
TargetNodesBaseModel,
)
Expand All @@ -27,17 +30,20 @@
class GraphItem:
"""base class for Data Tasks and Cycles"""

color: ClassVar[str]
color: ClassVar[Color]

name: str
coordinates: dict


GRAPH_ITEM_T = TypeVar("GRAPH_ITEM_T", bound=GraphItem)


@dataclass(kw_only=True)
class Data(ConfigBaseDataSpecs, GraphItem):
"""Internal representation of a data node"""

color: ClassVar[str] = field(default="light_blue", repr=False)
color: ClassVar[Color] = field(default="light_blue", repr=False)

available: bool

Expand All @@ -61,15 +67,17 @@ class Task(ConfigBaseTaskSpecs, GraphItem):
"""Internal representation of a task node"""

plugin_classes: ClassVar[dict[str, type]] = field(default={}, repr=False)
color: ClassVar[str] = field(default="light_red", repr=False)
color: ClassVar[Color] = field(default="light_red", repr=False)

inputs: list[BoundData] = field(default_factory=list)
outputs: list[Data] = field(default_factory=list)
wait_on: list[Task] = field(default_factory=list)
config_rootdir: Path | None = None
config_rootdir: Path
start_date: datetime | None = None
end_date: datetime | None = None

_wait_on_specs: list[ConfigCycleTaskWaitOn] = field(default_factory=list, repr=False)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.plugin in Task.plugin_classes:
Expand Down Expand Up @@ -118,7 +126,7 @@ def from_config(

return new

def link_wait_on_tasks(self, taskstore: Store):
def link_wait_on_tasks(self, taskstore: Store[Task]) -> None:
self.wait_on = list(
chain(
*(
Expand All @@ -133,24 +141,24 @@ def link_wait_on_tasks(self, taskstore: Store):
class Cycle(GraphItem):
"""Internal reprenstation of a cycle"""

color: ClassVar[str] = field(default="light_green", repr=False)
color: ClassVar[Color] = field(default="light_green", repr=False)

tasks: list[Task]


class Array:
"""Dictionnary of GraphItem objects accessed by arbitrary dimensions"""
class Array[GRAPH_ITEM_T]:
"""Dictionnary of GRAPH_ITEM_T objects accessed by arbitrary dimensions"""

def __init__(self, name: str) -> None:
self._name = name
self._dims: tuple[str] | None = None
self._axes: dict | None = None
self._dict: dict[tuple, GraphItem] | None = None
self._dims: tuple[str, ...] = ()
self._axes: dict[str, set] = {}
self._dict: dict[tuple, GRAPH_ITEM_T] = {}

def __setitem__(self, coordinates: dict, value: GraphItem) -> None:
def __setitem__(self, coordinates: dict, value: GRAPH_ITEM_T) -> None:
# First access: set axes and initialize dictionnary
input_dims = tuple(coordinates.keys())
if self._dims is None:
if self._dims == ():
self._dims = input_dims
self._axes = {k: set() for k in self._dims}
self._dict = {}
Expand All @@ -171,15 +179,15 @@ def __setitem__(self, coordinates: dict, value: GraphItem) -> None:
# Set item
self._dict[key] = value

def __getitem__(self, coordinates: dict) -> GraphItem:
def __getitem__(self, coordinates: dict) -> GRAPH_ITEM_T:
if self._dims != (input_dims := tuple(coordinates.keys())):
msg = f"Array {self._name}: coordinate names {input_dims} don't match Array dimensions {self._dims}"
raise KeyError(msg)
# use the order of self._dims instead of param_keys to ensure reproducibility
key = tuple(coordinates[dim] for dim in self._dims)
return self._dict[key]

def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GraphItem]:
def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GRAPH_ITEM_T]:
# Check date references
if "date" not in self._dims and (spec.lag or spec.date):
msg = f"Array {self._name} has no date dimension, cannot be referenced by dates"
Expand All @@ -205,26 +213,24 @@ def _resolve_target_dim(self, spec: TargetNodesBaseModel, dim: str, reference: A
else:
yield from self._axes[dim]

def __iter__(self) -> Iterator[GraphItem]:
def __iter__(self) -> Iterator[GRAPH_ITEM_T]:
yield from self._dict.values()


class Store:
"""Container for GraphItem Arrays"""
class Store[GRAPH_ITEM_T]:
"""Container for GRAPH_ITEM_T Arrays"""

def __init__(self):
self._dict: dict[str, Array] = {}
def __init__(self) -> None:
self._dict: dict[str, Array[GRAPH_ITEM_T]] = {}

def add(self, item) -> None:
if not isinstance(item, GraphItem):
msg = "items in a Store must be of instance GraphItem"
raise TypeError(msg)
name, coordinates = item.name, item.coordinates
def add(self, item: GRAPH_ITEM_T) -> None:
graph_item = cast(GraphItem, item) # mypy can somehow not deduce this
name, coordinates = graph_item.name, graph_item.coordinates
if name not in self._dict:
self._dict[name] = Array(name)
self._dict[name] = Array[GRAPH_ITEM_T](name)
self._dict[name][coordinates] = item

def __getitem__(self, key: tuple[str, dict]) -> GraphItem:
def __getitem__(self, key: tuple[str, dict]) -> GRAPH_ITEM_T:
name, coordinates = key
if "date" in coordinates and coordinates["date"] is None:
del coordinates["date"]
Expand All @@ -233,7 +239,7 @@ def __getitem__(self, key: tuple[str, dict]) -> GraphItem:
raise KeyError(msg)
return self._dict[name][coordinates]

def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GraphItem]:
def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GRAPH_ITEM_T]:
# Check if target items should be querried at all
if (when := spec.when) is not None:
if (ref_date := reference.get("date")) is None:
Expand All @@ -248,5 +254,5 @@ def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> I
# Yield items
yield from self._dict[spec.name].iter_from_cycle_spec(spec, reference)

def __iter__(self) -> Iterator[GraphItem]:
def __iter__(self) -> Iterator[GRAPH_ITEM_T]:
yield from chain(*(self._dict.values()))
23 changes: 10 additions & 13 deletions src/sirocco/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sirocco.core.graph_items import Cycle, Data, Store, Task
from sirocco.parsing._yaml_data_models import (
ConfigBaseData,
ConfigWorkflow,
)

Expand All @@ -14,10 +15,8 @@
from pathlib import Path

from sirocco.parsing._yaml_data_models import (
ConfigAvailableData,
ConfigCycle,
ConfigData,
ConfigGeneratedData,
ConfigTask,
)

Expand All @@ -37,13 +36,11 @@ def __init__(
self.name: str = name
self.config_rootdir: Path = config_rootdir

self.tasks: Store = Store()
self.data: Store = Store()
self.cycles: Store = Store()
self.tasks: Store[Task] = Store()
self.data: Store[Data] = Store()
self.cycles: Store[Cycle] = Store()

data_dict: dict[str, ConfigAvailableData | ConfigGeneratedData] = {
data.name: data for data in chain(data.available, data.generated)
}
data_dict: dict[str, ConfigBaseData] = {data.name: data for data in chain(data.available, data.generated)}
task_dict: dict[str, ConfigTask] = {task.name: task for task in tasks}

# Function to iterate over date and parameter combinations
Expand All @@ -52,9 +49,9 @@ def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator
yield from (dict(zip(space.keys(), x, strict=False)) for x in product(*space.values()))

# 1 - create availalbe data nodes
for data_config in data.available:
for coordinates in iter_coordinates(param_refs=data_config.parameters, date=None):
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))
for available_data_config in data.available:
for coordinates in iter_coordinates(param_refs=available_data_config.parameters, date=None):
self.data.add(Data.from_config(config=available_data_config, coordinates=coordinates))

# 2 - create output data nodes
for cycle_config in cycles:
Expand Down Expand Up @@ -100,9 +97,9 @@ def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator
task.link_wait_on_tasks(self.tasks)

@staticmethod
def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime]:
def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime | None]:
yield (date := cycle_config.start_date)
if cycle_config.period is not None:
if cycle_config.period is not None and date is not None and cycle_config.end_date is not None:
while (date := date + cycle_config.period) < cycle_config.end_date:
yield date

Expand Down
Loading

0 comments on commit 0dfaea5

Please sign in to comment.