diff --git a/docs/sections/user_guide/api/driver.rst b/docs/sections/user_guide/api/driver.rst new file mode 100644 index 000000000..46f0709ce --- /dev/null +++ b/docs/sections/user_guide/api/driver.rst @@ -0,0 +1,6 @@ +``uwtools.api.driver`` +====================== + +.. automodule:: uwtools.api.driver + :inherited-members: + :members: diff --git a/docs/sections/user_guide/api/index.rst b/docs/sections/user_guide/api/index.rst index 62669c885..412153c03 100644 --- a/docs/sections/user_guide/api/index.rst +++ b/docs/sections/user_guide/api/index.rst @@ -5,6 +5,7 @@ API chgres_cube config esg_grid + driver file filter_topo fv3 diff --git a/src/uwtools/api/driver.py b/src/uwtools/api/driver.py new file mode 100644 index 000000000..f724ebc7e --- /dev/null +++ b/src/uwtools/api/driver.py @@ -0,0 +1,28 @@ +""" +API access to the ``uwtools`` driver base classes. +""" + +import importlib +import sys + +_CLASSNAMES = [ + "Assets", + "AssetsCycleBased", + "AssetsCycleAndLeadtimeBased", + "AssetsTimeInvariant", + "Driver", + "DriverCycleBased", + "DriverCycleAndLeadtimeBased", + "DriverTimeInvariant", +] + + +def _add_classes(): + m = importlib.import_module("uwtools.drivers.driver") + for classname in _CLASSNAMES: + setattr(sys.modules[__name__], classname, getattr(m, classname)) + __all__.append(classname) + + +__all__: list[str] = [] +_add_classes() diff --git a/src/uwtools/drivers/ww3.py b/src/uwtools/drivers/ww3.py index 864ec2caf..2faf6d64c 100644 --- a/src/uwtools/drivers/ww3.py +++ b/src/uwtools/drivers/ww3.py @@ -14,7 +14,7 @@ class WaveWatchIII(AssetsCycleBased): """ - A library driver for ww3. + An assets driver for ww3. """ # Workflow tasks diff --git a/src/uwtools/tests/api/test_driver.py b/src/uwtools/tests/api/test_driver.py new file mode 100644 index 000000000..9b1115451 --- /dev/null +++ b/src/uwtools/tests/api/test_driver.py @@ -0,0 +1,25 @@ +# pylint: disable=missing-function-docstring,protected-access + +from inspect import isclass, ismodule + +from pytest import mark + +from uwtools.api import driver as driver_api +from uwtools.drivers import driver as driver_lib + + +@mark.parametrize("classname", driver_api._CLASSNAMES) +def test_driver(classname): + assert getattr(driver_api, classname) is getattr(driver_lib, classname) + + +def test_public_attributes(): + # Check that the module is not accidentally exposing unexpected public attributes. Ignore + # private attributes and imported modules and assert that what remains is an intentionally + # exposed (driver) class. + for name in dir(driver_api): + obj = getattr(driver_api, name) + if name.startswith("_") or ismodule(obj): + continue + assert isclass(obj) + assert name in driver_api._CLASSNAMES