Skip to content

Commit

Permalink
Base driver kwargs on constructor signature (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
maddenp-noaa authored Jul 22, 2024
1 parent 33b448d commit 3eef7cf
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 32 deletions.
8 changes: 4 additions & 4 deletions src/uwtools/api/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from uwtools.utils.api import ensure_data_source


def execute( # pylint: disable=unused-argument
def execute(
module: Union[Path, str],
classname: str,
task: str,
schema_file: str,
config: Optional[Union[Path, str]] = None,
cycle: Optional[datetime] = None,
leadtime: Optional[timedelta] = None,
batch: Optional[bool] = False,
cycle: Optional[datetime] = None, # pylint: disable=unused-argument
leadtime: Optional[timedelta] = None, # pylint: disable=unused-argument
batch: Optional[bool] = False, # pylint: disable=unused-argument
dry_run: Optional[bool] = False,
graph_file: Optional[Union[Path, str]] = None,
key_path: Optional[list[str]] = None,
Expand Down
8 changes: 3 additions & 5 deletions src/uwtools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,16 +1073,14 @@ def _dispatch_to_driver(name: str, args: Args) -> bool:
kwargs = {
"task": args[STR.action],
"config": args[STR.cfgfile],
"batch": args[STR.batch],
"dry_run": args[STR.dryrun],
"graph_file": args[STR.graphfile],
"key_path": args[STR.keypath],
"stdin_ok": True,
}
if cycle := args.get(STR.cycle):
kwargs[STR.cycle] = cycle
if (leadtime := args.get(STR.leadtime)) is not None:
kwargs[STR.leadtime] = leadtime
for k in [STR.batch, STR.cycle, STR.leadtime]:
if k in args:
kwargs[k] = args.get(k)
return execute(**kwargs)


Expand Down
25 changes: 11 additions & 14 deletions src/uwtools/tests/utils/test_api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# pylint: disable=missing-function-docstring,protected-access,redefined-outer-name

import datetime as dt
import sys
from pathlib import Path
from unittest.mock import patch

from pytest import fixture, mark, raises

from uwtools.exceptions import UWError
from uwtools.tests.drivers.test_driver import ConcreteDriverCycleLeadtimeBased as TestDriverWCL
from uwtools.tests.drivers.test_driver import ConcreteDriverCycleLeadtimeBased as TestDriverCL
from uwtools.tests.drivers.test_driver import ConcreteDriverTimeInvariant as TestDriver
from uwtools.utils import api

Expand Down Expand Up @@ -105,16 +104,14 @@ def test_str2path_convert():
@mark.parametrize("hours", [0, 24, 168])
def test__execute(execute_kwargs, hours, tmp_path):
graph_file = tmp_path / "g.dot"
with patch.object(sys.modules[__name__], "TestDriverWCL", wraps=TestDriverWCL) as cd:
kwargs = {
**execute_kwargs,
"driver_class": cd,
"config": {"some": "config"},
"cycle": dt.datetime.now(),
"leadtime": dt.timedelta(hours=hours),
"graph_file": graph_file,
}
assert not graph_file.is_file()
assert api._execute(**kwargs) is True
assert cd.call_args.kwargs["leadtime"] == dt.timedelta(hours=hours)
kwargs = {
**execute_kwargs,
"driver_class": TestDriverCL,
"config": {"some": "config"},
"cycle": dt.datetime.now(),
"leadtime": dt.timedelta(hours=hours),
"graph_file": graph_file,
}
assert not graph_file.is_file()
assert api._execute(**kwargs) is True
assert graph_file.is_file()
18 changes: 9 additions & 9 deletions src/uwtools/utils/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import datetime as dt
import re
from inspect import getfullargspec
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar, Union

Expand Down Expand Up @@ -147,10 +148,10 @@ def str2path(val: Any) -> Any:
def _execute(
driver_class: DriverT,
task: str,
cycle: Optional[dt.datetime] = None,
leadtime: Optional[dt.timedelta] = None,
config: Optional[Union[Path, str]] = None,
batch: bool = False,
cycle: Optional[dt.datetime] = None, # pylint: disable=unused-argument
leadtime: Optional[dt.timedelta] = None, # pylint: disable=unused-argument
batch: bool = False, # pylint: disable=unused-argument
dry_run: bool = False,
graph_file: Optional[Union[Path, str]] = None,
key_path: Optional[list[str]] = None,
Expand All @@ -164,9 +165,9 @@ def _execute(
:param driver_class: Class of driver object to instantiate.
:param task: The task to execute.
:param config: Path to config file (read stdin if missing or None).
:param cycle: The cycle.
:param leadtime: The leadtime.
:param config: Path to config file (read stdin if missing or None).
:param batch: Submit run to the batch system?
:param dry_run: Do not run the executable, just report what would have been done.
:param graph_file: Write Graphviz DOT output here.
Expand All @@ -176,14 +177,13 @@ def _execute(
"""
kwargs = dict(
config=ensure_data_source(str2path(config), stdin_ok),
batch=batch,
dry_run=dry_run,
key_path=key_path,
)
if cycle:
kwargs["cycle"] = cycle
if leadtime is not None:
kwargs["leadtime"] = leadtime
accepted = set(getfullargspec(driver_class).args)
for arg in ["batch", "cycle", "leadtime"]:
if arg in accepted:
kwargs[arg] = locals()[arg]
obj = driver_class(**kwargs)
getattr(obj, task)()
if graph_file:
Expand Down

0 comments on commit 3eef7cf

Please sign in to comment.