From 4bb8f6c720c6ffdf60ba990778e8000ed31a9b87 Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 09:49:12 +0000 Subject: [PATCH 1/9] draft implementation of config. --- ifsbench/__init__.py | 1 + ifsbench/config_mixin.py | 110 +++++++++++ ifsbench/data/extracthandler.py | 19 +- ifsbench/data/namelisthandler.py | 67 +++++-- tests/data/test_extracthandler.py | 17 ++ tests/data/test_namelisthandler.py | 292 ++++++++++++++++++----------- tests/test_config_mixin.py | 214 +++++++++++++++++++++ 7 files changed, 599 insertions(+), 121 deletions(-) create mode 100644 ifsbench/config_mixin.py create mode 100644 tests/test_config_mixin.py diff --git a/ifsbench/__init__.py b/ifsbench/__init__.py index 05ca20a..87f0ccf 100644 --- a/ifsbench/__init__.py +++ b/ifsbench/__init__.py @@ -16,6 +16,7 @@ from .arch import * # noqa from .benchmark import * # noqa from .cli import * # noqa +from .config_mixin import * # noqa from .darshanreport import * # noqa from .drhook import * # noqa from .files import * # noqa diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py new file mode 100644 index 0000000..b484b04 --- /dev/null +++ b/ifsbench/config_mixin.py @@ -0,0 +1,110 @@ +from abc import ABC, abstractmethod +from typing import Any, get_args, get_origin, get_type_hints, Optional, TypeVar, Union + +__all__ = ['ConfigMixin', 'CONF'] + + +CONF = Union[ + int, float, str, bool, dict, list, None +] + + +def _config_from_locals(config: dict[str, Any]) -> None: + print(f'from locals: config={config}, type={type(config)}') + config = config.copy() + config.pop('self', None) + config.pop('cls', None) + return config + + +class ConfigMixin(ABC): + + _config = None + + @classmethod + @abstractmethod + def config_format(cls) -> dict[str, type | dict]: + raise NotImplementedError() + + @classmethod + def _format_from_init(cls) -> dict[str, type | dict]: + format = dict(get_type_hints(cls.__init__, include_extras=False)) + print(f'format initial={format}, type={type(format)}') + format = _config_from_locals(format) + print(f'format cleaned: {format}') + return format + + def set_config_from_init_locals(self, config: dict[str, Any]): + config = _config_from_locals(config) + self.set_config(config) + + def set_config(self, config: dict[str, CONF]) -> None: + if self._config: + raise ValueError('Config already set.') + self._config = config + + def get_config(self) -> dict[str, CONF]: + return self._config + + def update_config(self, field: str, value: CONF) -> None: + if field not in self._config: + raise ValueError(f'{field} not part of config {self._config}, not setting') + if type(value) != type(self._config[field]): + raise ValueError( + f'Cannot update config: wrong type {type(value)} for field {field}' + ) + self._config[field] = value + + @classmethod + def validate_config(cls, config: dict[str, CONF]): + format = cls.config_format() + cls._validate_config_from_format(config, format) + + @classmethod + def _validate_config_from_format( + cls, config: dict[str, CONF], format: dict[str, type | dict] + ): + print(f'config: {config}') + print(f'format: {format}') + + for key, value in config.items(): + if not isinstance(value, CONF): + # check that the given value is a valid config type + raise ValueError(f'Unsupported config value type for {value}') + if key not in format: + raise ValueError(f'unexpected key "{key}" in config, expected {format}') + + for key, value in format.items(): + + if (key not in config) and (type(None) not in get_args(value)): + # format key has to be in config unless it's optional + raise ValueError(f'"{key}" required but not in {config}') + if isinstance(value, dict): + # nested, check that field also nested in config, then recursively check dict. + if not isinstance(config[key], dict): + raise ValueError( + f'"{key}" has type {type(config[key])}, expected {value}' + ) + cls._validate_config_from_format(config[key], format[key]) + elif isinstance(value, list): + # For now, only check both are lists and first entry type is correct, don't check every entry. + if not isinstance(config[key], list): + raise ValueError( + f'"{key}" has type {type(config[key])}, expected {value}' + ) + if type(value[0]) != type(config[key][0]): + raise ValueError( + f'list entries for "{key}" have type {type(config[key][0])}, expected {type(value[0])}' + ) + elif get_origin(value) == Union and type(None) in get_args(value): + # Optional: check matching type or None + opt_type = get_args(value) + if key in config and type(config[key]) not in opt_type: + raise ValueError( + f'wrong type for optional {type(value)}: {config[key]}' + ) + elif type(config[key]) != value: + # types of format and config have to match + raise ValueError( + f'"{key}" has type {type(config[key])}, expected {value}' + ) diff --git a/ifsbench/data/extracthandler.py b/ifsbench/data/extracthandler.py index 046d45d..28ce86d 100644 --- a/ifsbench/data/extracthandler.py +++ b/ifsbench/data/extracthandler.py @@ -7,14 +7,16 @@ import pathlib import shutil +from typing import Optional, Self +from ifsbench.config_mixin import CONF,ConfigMixin from ifsbench.data.datahandler import DataHandler from ifsbench.logging import debug __all__ = ['ExtractHandler'] -class ExtractHandler(DataHandler): +class ExtractHandler(DataHandler,ConfigMixin): """ DataHandler that extracts a given archive to a specific directory. @@ -31,13 +33,26 @@ class ExtractHandler(DataHandler): :meth:`execute`. """ - def __init__(self, archive_path, target_dir=None): + def __init__(self, archive_path: str, target_dir: Optional[str]=None): + self.set_config_from_init_locals(locals()) self._archive_path = pathlib.Path(archive_path) if target_dir is None: self._target_dir = None else: self._target_dir = pathlib.Path(target_dir) + @classmethod + def config_format(cls): + return cls._format_from_init() + + @classmethod + def from_config(cls, config: dict[str,CONF]) -> Self: + cls.validate_config(config) + archive_path = config['archive_path'] + target_dir = config['target_dir'] if 'target_dir' in config else None + return cls(archive_path, target_dir) + + def execute(self, wdir, **kwargs): wdir = pathlib.Path(wdir) diff --git a/ifsbench/data/namelisthandler.py b/ifsbench/data/namelisthandler.py index e757192..5bee8ab 100644 --- a/ifsbench/data/namelisthandler.py +++ b/ifsbench/data/namelisthandler.py @@ -5,23 +5,26 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -from enum import auto, Enum +from enum import auto, StrEnum import pathlib +from typing import Optional, Self, Union + import f90nml +from ifsbench.config_mixin import CONF,ConfigMixin from ifsbench.data.datahandler import DataHandler from ifsbench.logging import debug, info __all__ = ['NamelistOverride', 'NamelistHandler', 'NamelistOperation'] -class NamelistOperation(Enum): +class NamelistOperation(StrEnum): SET = auto() APPEND = auto() DELETE = auto() -class NamelistOverride: +class NamelistOverride(ConfigMixin): """ Specify changes that will be applied to a namelist. @@ -43,15 +46,10 @@ class NamelistOverride: """ - def __init__(self, key, mode, value=None): - if isinstance(key, str): - self._keys = key.split('/') - else: - self._keys = tuple(key) - - if len(self._keys) != 2: - raise ValueError("The key object must be of length two.") + def __init__(self, namelist: str, entry: str, mode: NamelistOperation, value: CONF=None): + self.set_config_from_init_locals(locals()) + self._keys = (namelist, entry) self._mode = mode self._value = value @@ -59,6 +57,31 @@ def __init__(self, key, mode, value=None): if self._mode in (NamelistOperation.SET, NamelistOperation.APPEND): raise ValueError("The new value must not be None!") + @classmethod + def from_keytuple(cls, key: tuple[str,str], mode: NamelistOperation, value: CONF=None) -> Self: + if len(key) != 2: + raise ValueError(f"The key tuple must be of length two, found key {key}.") + return cls(key[0], key[1], mode, value) + + @classmethod + def from_keystring(cls, key: str, mode: NamelistOperation, value: CONF=None) -> Self: + keys = key.split('/') + if len(keys) != 2: + raise ValueError(f"The key string must contain single '/', found key {key}.") + return cls(keys[0], keys[1], mode, value) + + @classmethod + def from_config(cls, config: dict[str,CONF]): + cls.validate_config(config) + value = config['value'] if 'value' in config else None + return cls(config['namelist'], config['entry'], config['mode'], value) + + + @classmethod + def config_format(cls): + return cls._format_from_init() + + def apply(self, namelist): """ Apply the stored changes to a namelist. @@ -109,7 +132,7 @@ def apply(self, namelist): debug(f"Delete namelist entry {str(self._keys)}.") del namelist[key] -class NamelistHandler(DataHandler): +class NamelistHandler(DataHandler, ConfigMixin): """ DataHandler specialisation that can modify Fortran namelists. @@ -129,7 +152,10 @@ class NamelistHandler(DataHandler): The NamelistOverrides that will be applied. """ - def __init__(self, input_path, output_path, overrides): + def __init__(self, input_path: str, output_path: str, overrides: list[NamelistOverride]): + + override_confs = [no.get_config() for no in overrides] + self.set_config({'input_path': input_path, 'output_path': output_path, 'overrides': override_confs}) self._input_path = pathlib.Path(input_path) self._output_path = pathlib.Path(output_path) @@ -139,6 +165,21 @@ def __init__(self, input_path, output_path, overrides): if not isinstance(override, NamelistOverride): raise ValueError("Namelist overrides must be NamelistOverride objects!") + @classmethod + def config_format(cls) -> dict[str,type|dict]: + return {'input_path': str, 'output_path': str, 'overrides': [{str: CONF, }, ]} + + + @classmethod + def from_config(cls, config: dict[str,CONF]) -> Self: + cls.validate_config(config) + input_path = config['input_path'] + output_path = config['output_path'] + override_configs = config['overrides'] + overrides = [NamelistOverride.from_config(oc) for oc in override_configs] + return cls(input_path, output_path, overrides) + + def execute(self, wdir, **kwargs): wdir = pathlib.Path(wdir) diff --git a/tests/data/test_extracthandler.py b/tests/data/test_extracthandler.py index d3ad6bf..88a6c1b 100644 --- a/tests/data/test_extracthandler.py +++ b/tests/data/test_extracthandler.py @@ -146,3 +146,20 @@ def test_extracthandler_execute(tmp_path, archive, archive_path, archive_relativ for path in archive: assert (extract_path/path).exists() + + +def test_from_config_succeeds(): + in_conf = {'archive_path': 'arch/path', 'target_dir': 'target/dir'} + eh = ExtractHandler.from_config(in_conf) + out_conf = eh.get_config() + + assert out_conf == in_conf + + +def test_from_config_target_dir_null_succeeds(): + in_conf = {'archive_path': 'arch/path'} + eh = ExtractHandler.from_config(in_conf) + out_conf = eh.get_config() + + expected = {'archive_path': 'arch/path', 'target_dir': None} + assert out_conf == expected \ No newline at end of file diff --git a/tests/data/test_namelisthandler.py b/tests/data/test_namelisthandler.py index a2eec1a..159b0bc 100644 --- a/tests/data/test_namelisthandler.py +++ b/tests/data/test_namelisthandler.py @@ -15,38 +15,36 @@ from f90nml import Namelist import pytest -from ifsbench.data import ( - NamelistHandler, NamelistOverride, NamelistOperation -) +from ifsbench.data import NamelistHandler, NamelistOverride, NamelistOperation + -@pytest.fixture(name = 'initial_namelist') +@pytest.fixture(name='initial_namelist') def fixture_namelist(): namelist = Namelist() - namelist['namelist1'] = { - 'int': 2, - 'str': 'test', - 'list': [2, 3, 'entry'] - } + namelist['namelist1'] = {'int': 2, 'str': 'test', 'list': [2, 3, 'entry']} namelist['namelist2'] = {'int': 5} return namelist -@pytest.mark.parametrize('key,mode,value,success', [ - ('namelist1', NamelistOperation.APPEND, None, False), - ('namelist1', NamelistOperation.SET, None, False), - ('namelist1', NamelistOperation.DELETE, None, False), - ('namelist1/entry', NamelistOperation.DELETE, None, True), - ('namelist1/entry', NamelistOperation.SET, None, False), - ('namelist1/entry', NamelistOperation.APPEND, None, False), - ('namelist1/entry', NamelistOperation.SET, 2, True), - ('namelist1/entry', NamelistOperation.APPEND, 3, True), - (('namelist1', 'entry'), NamelistOperation.SET, 2, True), - (('namelist1', 'entry'), NamelistOperation.APPEND, 3, True), -]) -def test_extracthandler_init(key, mode, value, success): +@pytest.mark.parametrize( + 'key,mode,value,success', + [ + ('namelist1', NamelistOperation.APPEND, None, False), + ('namelist1', NamelistOperation.SET, None, False), + ('namelist1', NamelistOperation.DELETE, None, False), + ('namelist1/entry', NamelistOperation.DELETE, None, True), + ('namelist1/entry', NamelistOperation.SET, None, False), + ('namelist1/entry', NamelistOperation.APPEND, None, False), + ('namelist1/entry', NamelistOperation.SET, 2, True), + ('namelist1/entry', NamelistOperation.APPEND, 3, True), + (('namelist1', 'entry'), NamelistOperation.SET, 2, True), + (('namelist1', 'entry'), NamelistOperation.APPEND, 3, True), + ], +) +def test_namelistoverride_init(key, mode, value, success): """ Initialise the NamelistOverride and make sure that only correct values are accepted. @@ -58,24 +56,68 @@ def test_extracthandler_init(key, mode, value, success): context = pytest.raises(ValueError) with context: - NamelistOverride(key, mode, value) + if isinstance(key, tuple): + NamelistOverride.from_keytuple(key, mode, value) + else: + NamelistOverride.from_keystring(key, mode, value) + + +@pytest.mark.parametrize( + 'key,mode,value', + [ + (('namelist1', 'entry'), NamelistOperation.DELETE, None), + (('namelist1', 'entry'), NamelistOperation.SET, 2), + (('namelist1', 'entry'), NamelistOperation.APPEND, 3), + ], +) +def test_namelistoverride_from_config(key, mode, value): + in_conf = {'namelist': key[0], 'entry': key[1], 'mode': mode} + if value: + in_conf['value'] = value + + nov = NamelistOverride.from_config(in_conf) + out_conf = nov.get_config() + expected = in_conf.copy() + expected['value'] = value + assert out_conf == expected -@pytest.mark.parametrize('key,value', [ - (('namelist1', 'int'), 5), - (('namelist1', 'list'), [0, 2]), - (('namelist2', 'int'), 'not an int'), - (('namelist2', 'newvalue'), 5), - (('namelist3', 'anothervalue'), [2,3,4]), -]) -def test_extracthandler_apply_set(initial_namelist, key, value): + +@pytest.mark.parametrize( + 'key,mode,value', + [ + (('namelist1', 'entry'), NamelistOperation.DELETE, None), + (('namelist1', 'entry'), NamelistOperation.SET, 2), + (('namelist1', 'entry'), NamelistOperation.APPEND, 3), + ], +) +def test_namelistoverride_get_config(key, mode, value): + nov = NamelistOverride(key[0], key[1], mode, value) + + conf = nov.get_config() + + expected = {'namelist': key[0], 'entry': key[1], 'mode': mode, 'value': value} + assert conf == expected + + +@pytest.mark.parametrize( + 'key,value', + [ + (('namelist1', 'int'), 5), + (('namelist1', 'list'), [0, 2]), + (('namelist2', 'int'), 'not an int'), + (('namelist2', 'newvalue'), 5), + (('namelist3', 'anothervalue'), [2, 3, 4]), + ], +) +def test_namelistoverride_apply_set(initial_namelist, key, value): """ Initialise the NamelistOverride and make sure that only correct values are accepted. """ namelist = Namelist(initial_namelist) - override = NamelistOverride(key, NamelistOperation.SET, value) + override = NamelistOverride(key[0], key[1], NamelistOperation.SET, value) override.apply(namelist) @@ -86,23 +128,27 @@ def test_extracthandler_apply_set(initial_namelist, key, value): if (name, name2) != key: assert entry[name2] == initial_namelist[name][name2] -@pytest.mark.parametrize('key,value,success', [ - (('namelist1', 'int'), 5, False), - (('namelist1', 'list'), 3, True), - (('namelist1', 'list'), [2, 4], False), - (('namelist1', 'list'), 5, True), - (('namelist1', 'list'), 'Hello', False), - (('namelist2', 'int'), 'not an int', False), - (('namelist3', 'new_list'), 'not an int', True) -]) -def test_extracthandler_apply_append(initial_namelist, key, value, success): + +@pytest.mark.parametrize( + 'key,value,success', + [ + (('namelist1', 'int'), 5, False), + (('namelist1', 'list'), 3, True), + (('namelist1', 'list'), [2, 4], False), + (('namelist1', 'list'), 5, True), + (('namelist1', 'list'), 'Hello', False), + (('namelist2', 'int'), 'not an int', False), + (('namelist3', 'new_list'), 'not an int', True), + ], +) +def test_namelistoverride_apply_append(initial_namelist, key, value, success): """ Initialise the NamelistOverride and make sure that only correct values are accepted. """ namelist = Namelist(initial_namelist) - override = NamelistOverride(key, NamelistOperation.APPEND, value) + override = NamelistOverride(key[0], key[1], NamelistOperation.APPEND, value) if success: override.apply(namelist) @@ -116,29 +162,31 @@ def test_extracthandler_apply_append(initial_namelist, key, value, success): else: assert namelist[key[0]][key[1]] == [value] - for name, entry in namelist.items(): for name2 in entry.keys(): if (name, name2) != key: assert entry[name2] == initial_namelist[name][name2] -@pytest.mark.parametrize('key', [ - ('namelist1', 'int'), - ('namelist1', 'list'), - ('namelist1', 'list'), - ('namelist2', 'int'), - ('doesnot', 'exist'), - ('namelist1', 'missing'), -]) -def test_extracthandler_apply_delete(initial_namelist, key): +@pytest.mark.parametrize( + 'key', + [ + ('namelist1', 'int'), + ('namelist1', 'list'), + ('namelist1', 'list'), + ('namelist2', 'int'), + ('doesnot', 'exist'), + ('namelist1', 'missing'), + ], +) +def test_namelistoverride_apply_delete(initial_namelist, key): """ Initialise the NamelistOverride and make sure that only correct values are accepted. """ namelist = Namelist(initial_namelist) - override = NamelistOverride(key, NamelistOperation.DELETE) + override = NamelistOverride(key[0], key[1], NamelistOperation.DELETE) override.apply(namelist) @@ -151,31 +199,44 @@ def test_extracthandler_apply_delete(initial_namelist, key): assert namelist[name][name2] == initial_namelist[name][name2] -@pytest.mark.parametrize('input_path,input_valid', [ - (Path('somewhere/fort.4'), True), - ('somewhere/namelist', True), - (None, False), - (2, False) -]) -@pytest.mark.parametrize('output_path,output_valid', [ - (Path('somewhere/new_fort.4'), True), - ('somewhere/namelist', True), - (None, False), - (2, False) -]) -@pytest.mark.parametrize('overrides, overrides_valid', [ - ([], True), - ('Test', False), - (2, False), - ([NamelistOverride('namelist/entry', NamelistOperation.SET, 5)], True), - ([ - NamelistOverride('namelist/entry', NamelistOperation.SET, 5), - NamelistOverride('namelist/entry2', NamelistOperation.APPEND, 2), - NamelistOverride('namelist/entry', NamelistOperation.DELETE), - - ], True), -]) -def test_namelisthandler_init(input_path, input_valid, output_path, output_valid, overrides, overrides_valid): +@pytest.mark.parametrize( + 'input_path,input_valid', + [ + (Path('somewhere/fort.4'), True), + ('somewhere/namelist', True), + (None, False), + (2, False), + ], +) +@pytest.mark.parametrize( + 'output_path,output_valid', + [ + (Path('somewhere/new_fort.4'), True), + ('somewhere/namelist', True), + (None, False), + (2, False), + ], +) +@pytest.mark.parametrize( + 'overrides, overrides_valid', + [ + ([], True), + ('Test', False), + (2, False), + ([NamelistOverride('namelist', 'entry', NamelistOperation.SET, 5)], True), + ( + [ + NamelistOverride('namelist', 'entry', NamelistOperation.SET, 5), + NamelistOverride('namelist', 'entry2', NamelistOperation.APPEND, 2), + NamelistOverride('namelist', 'entry', NamelistOperation.DELETE), + ], + True, + ), + ], +) +def test_namelisthandler_init( + input_path, input_valid, output_path, output_valid, overrides, overrides_valid +): """ Initialise the NamelistHandler and make sure that only correct values are accepted. """ @@ -188,31 +249,51 @@ def test_namelisthandler_init(input_path, input_valid, output_path, output_valid NamelistHandler(input_path, output_path, overrides) +def test_namelisthandler_from_config_get_config(): + config = { + 'input_path': 'in_path', + 'output_path': 'out_path', + 'overrides': [ + {'namelist': 'nl1', 'entry': 'e1', 'mode': NamelistOperation.SET, 'value': 5}, + {'namelist': 'nl2', 'entry': 'e2', 'mode': NamelistOperation.DELETE}, + ], + } + nh = NamelistHandler.from_config(config) + expected = config.copy() + expected['overrides'][1]['value'] = None + assert nh.get_config() == config -@pytest.mark.parametrize('input_path', [ - Path('somewhere/fort.4'), - 'somewhere/namelist' -]) +@pytest.mark.parametrize('input_path', [Path('somewhere/fort.4'), 'somewhere/namelist']) @pytest.mark.parametrize('input_relative', [True, False]) -@pytest.mark.parametrize('output_path', [ - Path('somewhere_else/new_fort.4'), - 'somewhere/namelist', -]) +@pytest.mark.parametrize( + 'output_path', + [ + Path('somewhere_else/new_fort.4'), + 'somewhere/namelist', + ], +) @pytest.mark.parametrize('output_relative', [True, False]) -@pytest.mark.parametrize('overrides', [ - [], - [NamelistOverride('namelist/entry', NamelistOperation.SET, 5)], +@pytest.mark.parametrize( + 'overrides', [ - NamelistOverride('namelist/entry', NamelistOperation.SET, 5), - NamelistOverride('namelist/entry2', NamelistOperation.APPEND, 2), - NamelistOverride('namelist/entry', NamelistOperation.DELETE), - + [], + [NamelistOverride('namelist', 'entry', NamelistOperation.SET, 5)], + [ + NamelistOverride('namelist', 'entry', NamelistOperation.SET, 5), + NamelistOverride('namelist', 'entry2', NamelistOperation.APPEND, 2), + NamelistOverride('namelist', 'entry', NamelistOperation.DELETE), + ], ], -]) - -def test_namelisthandler_execute(tmp_path, initial_namelist, input_path, - input_relative, output_path, output_relative, - overrides): +) +def test_namelisthandler_execute( + tmp_path, + initial_namelist, + input_path, + input_relative, + output_path, + output_relative, + overrides, +): """ Test that the execute function modifies the namelists correctly. @@ -245,29 +326,28 @@ def test_namelisthandler_execute(tmp_path, initial_namelist, input_path, # both). if not input_relative: if isinstance(input_path, str): - input_path = str((tmp_path/input_path).resolve()) + input_path = str((tmp_path / input_path).resolve()) else: - input_path = (tmp_path/input_path).resolve() + input_path = (tmp_path / input_path).resolve() if not output_relative: if isinstance(output_path, str): - output_path = str((tmp_path/output_path).resolve()) + output_path = str((tmp_path / output_path).resolve()) else: - output_path = (tmp_path/output_path).resolve() + output_path = (tmp_path / output_path).resolve() # Create the initial namelist. - abs_input_path = tmp_path/output_path + abs_input_path = tmp_path / output_path abs_input_path.parent.mkdir(parents=True, exist_ok=True) initial_namelist.write(abs_input_path) - # Actually extract the archive. handler = NamelistHandler(input_path, output_path, overrides) handler.execute(tmp_path) if output_relative: - assert (tmp_path/output_path).exists() + assert (tmp_path / output_path).exists() else: assert Path(output_path).exists() diff --git a/tests/test_config_mixin.py b/tests/test_config_mixin.py new file mode 100644 index 0000000..461b496 --- /dev/null +++ b/tests/test_config_mixin.py @@ -0,0 +1,214 @@ +import pytest +from typing import Optional + +from ifsbench import ConfigMixin + + +class TestConfigFromLocals(ConfigMixin): + def __init__(self, field1: int, field2: float, field3: str): + self.set_config_from_init_locals(locals()) + + @classmethod + def config_format(cls): + return cls._format_from_init() + + +class TestConfigSet(ConfigMixin): + def __init__(self, field1: int, field2: float, field3: str): + config = {'field1': field1, 'field2': field2, 'field3': field3} + self.set_config(config) + + @classmethod + def config_format(cls): + del cls + return {'field1': type(int), 'field2': type(float), 'field3': type(str)} + + +class TestConfigNestedConfigFormat(ConfigMixin): + + @classmethod + def config_format(cls): + del cls + return {'field1': int, 'field2': float, 'field3': {'field3a': str, 'field3b': int}} + + +class TestConfigOptional(ConfigMixin): + def __init__(self, field1: int, field2: Optional[str] = None): + self.set_config_from_init_locals(locals()) + + @classmethod + def config_format(cls): + return cls._format_from_init() + + +class TestConfigList(ConfigMixin): + + @classmethod + def config_format(cls): + del cls + return {'field1': int, 'field2': float, 'field3': [{str: int, }, ]} + + +VALUE1 = 3 +VALUE2 = 3.1 +VALUE3 = 'some/path' + + +def test_set_config_succeeds(): + + tc = TestConfigSet(field1=VALUE1, field2=VALUE2, field3=VALUE3) + config = tc.get_config() + + expected = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3} + assert(config == expected) + + +def test_set_config_from_init_locals_succeeds(): + + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + config = tc.get_config() + + expected = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3} + assert(config == expected) + + +def test_set_config_from_init_optional_set_succeeds(): + + tc = TestConfigOptional(field1=VALUE1, field2=VALUE3) + conf = tc.get_config() + + expected = {'field1': VALUE1, 'field2': VALUE3} + assert conf == expected + + +def test_set_config_from_init_optional_none_succeeds(): + + tc = TestConfigOptional(field1=VALUE1) + config = tc.get_config() + + expected = {'field1': VALUE1, 'field2': None} + assert(config == expected) + + +def test_set_config_already_set_fails(): + + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + with pytest.raises(ValueError): + tc.set_config({'something': 'other'}) + + +def test_update_config_succeeds(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + + tc.update_config(field='field1', value=4) + config = tc.get_config() + + expected = {'field1': 4, 'field2': VALUE2, 'field3': VALUE3} + assert(config == expected) + + +def test_update_config_add_field_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + + with pytest.raises(ValueError) as exceptinfo: + tc.update_config(field='field4', value=4) + assert str(exceptinfo.value) == f'field4 not part of config {tc.get_config()}, not setting' + + +def test_update_config_wrong_type_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + + with pytest.raises(ValueError) as exceptinfo: + tc.update_config(field='field1', value='should be int') + assert str(exceptinfo.value) == 'Cannot update config: wrong type for field field1' + + +def test_validate_config_succeeds(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3} + tc.validate_config(config=to_validate) + + +def test_validate_config_unsupported_type_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + to_validate = {'field1': set(), 'field2': VALUE2, 'field3': VALUE3} + with pytest.raises(ValueError) as exceptinfo: + tc.validate_config(config=to_validate) + assert str(exceptinfo.value) == f'Unsupported config value type for {set()}' + + +def test_validate_config_wrong_type_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + to_validate = {'field1': 'some string', 'field2': VALUE2, 'field3': VALUE3} + with pytest.raises(ValueError) as exceptinfo: + tc.validate_config(config=to_validate) + assert str(exceptinfo.value) == '"field1" has type , expected ' + + +def test_validate_config_field_not_in_config_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + to_validate = {'field1': VALUE1, 'field2': VALUE2} + with pytest.raises(ValueError) as exceptinfo: + tc.validate_config(config=to_validate) + assert str(exceptinfo.value) == f'"field3" required but not in {to_validate}' + + +def test_validate_config_field_not_in_format_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3, 'field4': 'unexpected field'} + with pytest.raises(ValueError) as exceptinfo: + tc.validate_config(config=to_validate) + assert str(exceptinfo.value) == f'unexpected key "field4" in config, expected {tc.config_format()}' + + +def test_validate_config_nested_succeedss(): + tc = TestConfigNestedConfigFormat() + to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': {'field3a': 'path', 'field3b': 42}} + tc.validate_config(config=to_validate) + + +def test_validate_config_nested_dict_mismatch_fails(): + tc = TestConfigNestedConfigFormat() + to_validate = {'field1': VALUE1, 'field2': {'field2a': 4.4}, 'field3': {'field3a': 'path', 'field3b': 42}} + with pytest.raises(ValueError) as exceptinfo: + tc.validate_config(config=to_validate) + assert str(exceptinfo.value) == '"field2" has type , expected ' + + +def test_validate_config_nested_config_not_in_format_fails(): + tc = TestConfigNestedConfigFormat() + to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': {'field3a': 'path', 'field3b': 42, 'field3c': 'surplus'}} + with pytest.raises(ValueError) as exceptinfo: + tc.validate_config(config=to_validate) + expected = tc.config_format()['field3'] + assert str(exceptinfo.value) == f'unexpected key "field3c" in config, expected {expected}' + + +def test_validate_config_optional_set_succeeds(): + + tc = TestConfigOptional(field1=VALUE1, field2=VALUE3) + to_validate = {'field1': VALUE1, 'field2': VALUE3} + + tc.validate_config(to_validate) + + +def test_validate_config_optional_not_given_succeeds(): + + tc = TestConfigOptional(field1=VALUE1, field2=VALUE3) + to_validate = {'field1': VALUE1} + + tc.validate_config(to_validate) + + +def test_validate_config_list_succeeds(): + tc = TestConfigList() + to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': [{'field3a': 'path', 'field3b': 'another path'}, {'field3a': 'path2', 'field3b': 'another path2'}]} + tc.validate_config(config=to_validate) + + +def test_validate_config_list_wrong_type_fails(): + tc = TestConfigList() + to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': ['path', 'another path']} + with pytest.raises(ValueError) as exceptinfo: + tc.validate_config(config=to_validate) + assert str(exceptinfo.value) == f'list entries for "field3" have type , expected ' From f2a3474af0c395a24e9e8cfdfb66ba2611dd0ed6 Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 09:53:30 +0000 Subject: [PATCH 2/9] applied formatting --- ifsbench/config_mixin.py | 6 +- ifsbench/data/extracthandler.py | 11 ++- ifsbench/data/namelisthandler.py | 62 +++++++++++----- tests/data/test_extracthandler.py | 96 +++++++++++++++---------- tests/data/test_namelisthandler.py | 8 ++- tests/test_config_mixin.py | 109 ++++++++++++++++++++++------- 6 files changed, 200 insertions(+), 92 deletions(-) diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py index b484b04..d74c675 100644 --- a/ifsbench/config_mixin.py +++ b/ifsbench/config_mixin.py @@ -4,9 +4,7 @@ __all__ = ['ConfigMixin', 'CONF'] -CONF = Union[ - int, float, str, bool, dict, list, None -] +CONF = Union[int, float, str, bool, dict, list, None] def _config_from_locals(config: dict[str, Any]) -> None: @@ -74,7 +72,7 @@ def _validate_config_from_format( if key not in format: raise ValueError(f'unexpected key "{key}" in config, expected {format}') - for key, value in format.items(): + for key, value in format.items(): if (key not in config) and (type(None) not in get_args(value)): # format key has to be in config unless it's optional diff --git a/ifsbench/data/extracthandler.py b/ifsbench/data/extracthandler.py index 28ce86d..6048d0e 100644 --- a/ifsbench/data/extracthandler.py +++ b/ifsbench/data/extracthandler.py @@ -9,14 +9,14 @@ import shutil from typing import Optional, Self -from ifsbench.config_mixin import CONF,ConfigMixin +from ifsbench.config_mixin import CONF, ConfigMixin from ifsbench.data.datahandler import DataHandler from ifsbench.logging import debug __all__ = ['ExtractHandler'] -class ExtractHandler(DataHandler,ConfigMixin): +class ExtractHandler(DataHandler, ConfigMixin): """ DataHandler that extracts a given archive to a specific directory. @@ -33,7 +33,7 @@ class ExtractHandler(DataHandler,ConfigMixin): :meth:`execute`. """ - def __init__(self, archive_path: str, target_dir: Optional[str]=None): + def __init__(self, archive_path: str, target_dir: Optional[str] = None): self.set_config_from_init_locals(locals()) self._archive_path = pathlib.Path(archive_path) if target_dir is None: @@ -46,13 +46,12 @@ def config_format(cls): return cls._format_from_init() @classmethod - def from_config(cls, config: dict[str,CONF]) -> Self: + def from_config(cls, config: dict[str, CONF]) -> Self: cls.validate_config(config) archive_path = config['archive_path'] target_dir = config['target_dir'] if 'target_dir' in config else None return cls(archive_path, target_dir) - def execute(self, wdir, **kwargs): wdir = pathlib.Path(wdir) @@ -61,7 +60,7 @@ def execute(self, wdir, **kwargs): if self._target_dir.is_absolute(): target_dir = self._target_dir else: - target_dir = wdir/self._target_dir + target_dir = wdir / self._target_dir debug(f"Unpack archive {self._archive_path} to {target_dir}.") shutil.unpack_archive(self._archive_path, target_dir) diff --git a/ifsbench/data/namelisthandler.py b/ifsbench/data/namelisthandler.py index 5bee8ab..c674660 100644 --- a/ifsbench/data/namelisthandler.py +++ b/ifsbench/data/namelisthandler.py @@ -12,18 +12,20 @@ import f90nml -from ifsbench.config_mixin import CONF,ConfigMixin +from ifsbench.config_mixin import CONF, ConfigMixin from ifsbench.data.datahandler import DataHandler from ifsbench.logging import debug, info __all__ = ['NamelistOverride', 'NamelistHandler', 'NamelistOperation'] + class NamelistOperation(StrEnum): SET = auto() APPEND = auto() DELETE = auto() + class NamelistOverride(ConfigMixin): """ Specify changes that will be applied to a namelist. @@ -45,8 +47,9 @@ class NamelistOverride(ConfigMixin): The value that is set (SET operation) or appended (APPEND). """ - - def __init__(self, namelist: str, entry: str, mode: NamelistOperation, value: CONF=None): + def __init__( + self, namelist: str, entry: str, mode: NamelistOperation, value: CONF = None + ): self.set_config_from_init_locals(locals()) self._keys = (namelist, entry) @@ -58,29 +61,33 @@ def __init__(self, namelist: str, entry: str, mode: NamelistOperation, value: CO raise ValueError("The new value must not be None!") @classmethod - def from_keytuple(cls, key: tuple[str,str], mode: NamelistOperation, value: CONF=None) -> Self: + def from_keytuple( + cls, key: tuple[str, str], mode: NamelistOperation, value: CONF = None + ) -> Self: if len(key) != 2: raise ValueError(f"The key tuple must be of length two, found key {key}.") return cls(key[0], key[1], mode, value) @classmethod - def from_keystring(cls, key: str, mode: NamelistOperation, value: CONF=None) -> Self: + def from_keystring( + cls, key: str, mode: NamelistOperation, value: CONF = None + ) -> Self: keys = key.split('/') if len(keys) != 2: - raise ValueError(f"The key string must contain single '/', found key {key}.") + raise ValueError( + f"The key string must contain single '/', found key {key}." + ) return cls(keys[0], keys[1], mode, value) @classmethod - def from_config(cls, config: dict[str,CONF]): + def from_config(cls, config: dict[str, CONF]): cls.validate_config(config) value = config['value'] if 'value' in config else None return cls(config['namelist'], config['entry'], config['mode'], value) - @classmethod def config_format(cls): return cls._format_from_init() - def apply(self, namelist): """ @@ -121,7 +128,9 @@ def apply(self, namelist): type_value = type(self._value) if type_list != type_value: - raise ValueError("The given value must have the same type as existing array entries!") + raise ValueError( + "The given value must have the same type as existing array entries!" + ) debug(f"Append {str(self._value)} to namelist entry {str(self._keys)}.") @@ -132,6 +141,7 @@ def apply(self, namelist): debug(f"Delete namelist entry {str(self._keys)}.") del namelist[key] + class NamelistHandler(DataHandler, ConfigMixin): """ DataHandler specialisation that can modify Fortran namelists. @@ -152,10 +162,18 @@ class NamelistHandler(DataHandler, ConfigMixin): The NamelistOverrides that will be applied. """ - def __init__(self, input_path: str, output_path: str, overrides: list[NamelistOverride]): + def __init__( + self, input_path: str, output_path: str, overrides: list[NamelistOverride] + ): override_confs = [no.get_config() for no in overrides] - self.set_config({'input_path': input_path, 'output_path': output_path, 'overrides': override_confs}) + self.set_config( + { + 'input_path': input_path, + 'output_path': output_path, + 'overrides': override_confs, + } + ) self._input_path = pathlib.Path(input_path) self._output_path = pathlib.Path(output_path) @@ -166,12 +184,19 @@ def __init__(self, input_path: str, output_path: str, overrides: list[NamelistOv raise ValueError("Namelist overrides must be NamelistOverride objects!") @classmethod - def config_format(cls) -> dict[str,type|dict]: - return {'input_path': str, 'output_path': str, 'overrides': [{str: CONF, }, ]} - + def config_format(cls) -> dict[str, type | dict]: + return { + 'input_path': str, + 'output_path': str, + 'overrides': [ + { + str: CONF, + }, + ], + } @classmethod - def from_config(cls, config: dict[str,CONF]) -> Self: + def from_config(cls, config: dict[str, CONF]) -> Self: cls.validate_config(config) input_path = config['input_path'] output_path = config['output_path'] @@ -179,14 +204,13 @@ def from_config(cls, config: dict[str,CONF]) -> Self: overrides = [NamelistOverride.from_config(oc) for oc in override_configs] return cls(input_path, output_path, overrides) - def execute(self, wdir, **kwargs): wdir = pathlib.Path(wdir) if self._input_path.is_absolute(): input_path = self._input_path else: - input_path = wdir/self._input_path + input_path = wdir / self._input_path # Do nothing if the input namelist doesn't exist. if not input_path.exists(): @@ -196,7 +220,7 @@ def execute(self, wdir, **kwargs): if self._output_path.is_absolute(): output_path = self._output_path else: - output_path = wdir/self._output_path + output_path = wdir / self._output_path debug(f"Modify namelist {input_path}.") namelist = f90nml.read(input_path) diff --git a/tests/data/test_extracthandler.py b/tests/data/test_extracthandler.py index 88a6c1b..26a9b56 100644 --- a/tests/data/test_extracthandler.py +++ b/tests/data/test_extracthandler.py @@ -15,22 +15,27 @@ import pytest -from ifsbench.data import ( - ExtractHandler -) +from ifsbench.data import ExtractHandler + -@pytest.mark.parametrize('archive_path,archive_valid', [ - (Path('somewhere/archive.tar'), True), - ('somewhere/archive.tar', True), - (None, False), - (2, False) -]) -@pytest.mark.parametrize('target_dir, target_valid', [ - (Path('somewhere/archive.tar'), True), - ('somewhere/archive.tar', True), - (None, True), - (2, False) -]) +@pytest.mark.parametrize( + 'archive_path,archive_valid', + [ + (Path('somewhere/archive.tar'), True), + ('somewhere/archive.tar', True), + (None, False), + (2, False), + ], +) +@pytest.mark.parametrize( + 'target_dir, target_valid', + [ + (Path('somewhere/archive.tar'), True), + ('somewhere/archive.tar', True), + (None, True), + (2, False), + ], +) def test_extracthandler_init(archive_path, archive_valid, target_dir, target_valid): """ Initialise the ExtractHandler and make sure that only correct values are accepted. @@ -43,6 +48,7 @@ def test_extracthandler_init(archive_path, archive_valid, target_dir, target_val with context: ExtractHandler(archive_path, target_dir) + @pytest.fixture(name='archive') def fixture_archive(): paths = [ @@ -55,21 +61,33 @@ def fixture_archive(): return paths -@pytest.mark.parametrize('archive_path', [ - Path('somewhere/archive'), - 'somewhere/archive', -]) +@pytest.mark.parametrize( + 'archive_path', + [ + Path('somewhere/archive'), + 'somewhere/archive', + ], +) @pytest.mark.parametrize('archive_relative', [True, False]) @pytest.mark.parametrize('archive_type', ['zip', 'tar', 'gztar']) - -@pytest.mark.parametrize('target_dir', [ - Path('somewhere/extract'), - 'somewhere/extract', - None, -]) +@pytest.mark.parametrize( + 'target_dir', + [ + Path('somewhere/extract'), + 'somewhere/extract', + None, + ], +) @pytest.mark.parametrize('target_relative', [True, False]) -def test_extracthandler_execute(tmp_path, archive, archive_path, archive_relative, - archive_type, target_dir, target_relative): +def test_extracthandler_execute( + tmp_path, + archive, + archive_path, + archive_relative, + archive_type, + target_dir, + target_relative, +): """ Test that the execute function moves the content of an archive to the right directory. @@ -107,27 +125,29 @@ def test_extracthandler_execute(tmp_path, archive, archive_path, archive_relativ # both). if not archive_relative: if isinstance(archive_path, str): - archive_path = str((tmp_path/archive_path).resolve()) + archive_path = str((tmp_path / archive_path).resolve()) else: - archive_path = (tmp_path/archive_path).resolve() + archive_path = (tmp_path / archive_path).resolve() if not target_relative and target_dir is not None: if isinstance(archive_path, str): - target_dir = str((tmp_path/target_dir).resolve()) + target_dir = str((tmp_path / target_dir).resolve()) else: - target_dir = (tmp_path/target_dir).resolve() + target_dir = (tmp_path / target_dir).resolve() # Build the archive that we will unpack by using pack_path as a directory # that we will compress. Simply touch each file in fixture_archive. - pack_path = tmp_path/'pack' + pack_path = tmp_path / 'pack' for path in archive: - (pack_path/path).parent.mkdir(parents=True, exist_ok=True) - (pack_path/path).touch() + (pack_path / path).parent.mkdir(parents=True, exist_ok=True) + (pack_path / path).touch() if Path(archive_path).is_absolute(): archive_path = shutil.make_archive(archive_path, archive_type, pack_path) else: - archive_path = shutil.make_archive(tmp_path/archive_path, archive_type, pack_path) + archive_path = shutil.make_archive( + tmp_path / archive_path, archive_type, pack_path + ) # Actually extract the archive. handler = ExtractHandler(archive_path, target_dir) @@ -142,10 +162,10 @@ def test_extracthandler_execute(tmp_path, archive, archive_path, archive_relativ extract_path = Path(target_dir) if not extract_path.is_absolute(): - extract_path = tmp_path/extract_path + extract_path = tmp_path / extract_path for path in archive: - assert (extract_path/path).exists() + assert (extract_path / path).exists() def test_from_config_succeeds(): @@ -162,4 +182,4 @@ def test_from_config_target_dir_null_succeeds(): out_conf = eh.get_config() expected = {'archive_path': 'arch/path', 'target_dir': None} - assert out_conf == expected \ No newline at end of file + assert out_conf == expected diff --git a/tests/data/test_namelisthandler.py b/tests/data/test_namelisthandler.py index 159b0bc..2ad0ffd 100644 --- a/tests/data/test_namelisthandler.py +++ b/tests/data/test_namelisthandler.py @@ -254,7 +254,12 @@ def test_namelisthandler_from_config_get_config(): 'input_path': 'in_path', 'output_path': 'out_path', 'overrides': [ - {'namelist': 'nl1', 'entry': 'e1', 'mode': NamelistOperation.SET, 'value': 5}, + { + 'namelist': 'nl1', + 'entry': 'e1', + 'mode': NamelistOperation.SET, + 'value': 5, + }, {'namelist': 'nl2', 'entry': 'e2', 'mode': NamelistOperation.DELETE}, ], } @@ -263,6 +268,7 @@ def test_namelisthandler_from_config_get_config(): expected['overrides'][1]['value'] = None assert nh.get_config() == config + @pytest.mark.parametrize('input_path', [Path('somewhere/fort.4'), 'somewhere/namelist']) @pytest.mark.parametrize('input_relative', [True, False]) @pytest.mark.parametrize( diff --git a/tests/test_config_mixin.py b/tests/test_config_mixin.py index 461b496..3853036 100644 --- a/tests/test_config_mixin.py +++ b/tests/test_config_mixin.py @@ -29,7 +29,11 @@ class TestConfigNestedConfigFormat(ConfigMixin): @classmethod def config_format(cls): del cls - return {'field1': int, 'field2': float, 'field3': {'field3a': str, 'field3b': int}} + return { + 'field1': int, + 'field2': float, + 'field3': {'field3a': str, 'field3b': int}, + } class TestConfigOptional(ConfigMixin): @@ -46,7 +50,15 @@ class TestConfigList(ConfigMixin): @classmethod def config_format(cls): del cls - return {'field1': int, 'field2': float, 'field3': [{str: int, }, ]} + return { + 'field1': int, + 'field2': float, + 'field3': [ + { + str: int, + }, + ], + } VALUE1 = 3 @@ -58,18 +70,18 @@ def test_set_config_succeeds(): tc = TestConfigSet(field1=VALUE1, field2=VALUE2, field3=VALUE3) config = tc.get_config() - + expected = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3} - assert(config == expected) + assert config == expected def test_set_config_from_init_locals_succeeds(): tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) config = tc.get_config() - + expected = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3} - assert(config == expected) + assert config == expected def test_set_config_from_init_optional_set_succeeds(): @@ -85,9 +97,9 @@ def test_set_config_from_init_optional_none_succeeds(): tc = TestConfigOptional(field1=VALUE1) config = tc.get_config() - + expected = {'field1': VALUE1, 'field2': None} - assert(config == expected) + assert config == expected def test_set_config_already_set_fails(): @@ -102,9 +114,9 @@ def test_update_config_succeeds(): tc.update_config(field='field1', value=4) config = tc.get_config() - + expected = {'field1': 4, 'field2': VALUE2, 'field3': VALUE3} - assert(config == expected) + assert config == expected def test_update_config_add_field_fails(): @@ -112,7 +124,10 @@ def test_update_config_add_field_fails(): with pytest.raises(ValueError) as exceptinfo: tc.update_config(field='field4', value=4) - assert str(exceptinfo.value) == f'field4 not part of config {tc.get_config()}, not setting' + assert ( + str(exceptinfo.value) + == f'field4 not part of config {tc.get_config()}, not setting' + ) def test_update_config_wrong_type_fails(): @@ -120,7 +135,10 @@ def test_update_config_wrong_type_fails(): with pytest.raises(ValueError) as exceptinfo: tc.update_config(field='field1', value='should be int') - assert str(exceptinfo.value) == 'Cannot update config: wrong type for field field1' + assert ( + str(exceptinfo.value) + == 'Cannot update config: wrong type for field field1' + ) def test_validate_config_succeeds(): @@ -142,7 +160,10 @@ def test_validate_config_wrong_type_fails(): to_validate = {'field1': 'some string', 'field2': VALUE2, 'field3': VALUE3} with pytest.raises(ValueError) as exceptinfo: tc.validate_config(config=to_validate) - assert str(exceptinfo.value) == '"field1" has type , expected ' + assert ( + str(exceptinfo.value) + == '"field1" has type , expected ' + ) def test_validate_config_field_not_in_config_fails(): @@ -150,38 +171,64 @@ def test_validate_config_field_not_in_config_fails(): to_validate = {'field1': VALUE1, 'field2': VALUE2} with pytest.raises(ValueError) as exceptinfo: tc.validate_config(config=to_validate) - assert str(exceptinfo.value) == f'"field3" required but not in {to_validate}' + assert str(exceptinfo.value) == f'"field3" required but not in {to_validate}' def test_validate_config_field_not_in_format_fails(): tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) - to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3, 'field4': 'unexpected field'} + to_validate = { + 'field1': VALUE1, + 'field2': VALUE2, + 'field3': VALUE3, + 'field4': 'unexpected field', + } with pytest.raises(ValueError) as exceptinfo: tc.validate_config(config=to_validate) - assert str(exceptinfo.value) == f'unexpected key "field4" in config, expected {tc.config_format()}' + assert ( + str(exceptinfo.value) + == f'unexpected key "field4" in config, expected {tc.config_format()}' + ) def test_validate_config_nested_succeedss(): tc = TestConfigNestedConfigFormat() - to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': {'field3a': 'path', 'field3b': 42}} + to_validate = { + 'field1': VALUE1, + 'field2': VALUE2, + 'field3': {'field3a': 'path', 'field3b': 42}, + } tc.validate_config(config=to_validate) def test_validate_config_nested_dict_mismatch_fails(): tc = TestConfigNestedConfigFormat() - to_validate = {'field1': VALUE1, 'field2': {'field2a': 4.4}, 'field3': {'field3a': 'path', 'field3b': 42}} + to_validate = { + 'field1': VALUE1, + 'field2': {'field2a': 4.4}, + 'field3': {'field3a': 'path', 'field3b': 42}, + } with pytest.raises(ValueError) as exceptinfo: tc.validate_config(config=to_validate) - assert str(exceptinfo.value) == '"field2" has type , expected ' + assert ( + str(exceptinfo.value) + == '"field2" has type , expected ' + ) def test_validate_config_nested_config_not_in_format_fails(): tc = TestConfigNestedConfigFormat() - to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': {'field3a': 'path', 'field3b': 42, 'field3c': 'surplus'}} + to_validate = { + 'field1': VALUE1, + 'field2': VALUE2, + 'field3': {'field3a': 'path', 'field3b': 42, 'field3c': 'surplus'}, + } with pytest.raises(ValueError) as exceptinfo: tc.validate_config(config=to_validate) expected = tc.config_format()['field3'] - assert str(exceptinfo.value) == f'unexpected key "field3c" in config, expected {expected}' + assert ( + str(exceptinfo.value) + == f'unexpected key "field3c" in config, expected {expected}' + ) def test_validate_config_optional_set_succeeds(): @@ -202,13 +249,27 @@ def test_validate_config_optional_not_given_succeeds(): def test_validate_config_list_succeeds(): tc = TestConfigList() - to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': [{'field3a': 'path', 'field3b': 'another path'}, {'field3a': 'path2', 'field3b': 'another path2'}]} + to_validate = { + 'field1': VALUE1, + 'field2': VALUE2, + 'field3': [ + {'field3a': 'path', 'field3b': 'another path'}, + {'field3a': 'path2', 'field3b': 'another path2'}, + ], + } tc.validate_config(config=to_validate) def test_validate_config_list_wrong_type_fails(): tc = TestConfigList() - to_validate = {'field1': VALUE1, 'field2': VALUE2, 'field3': ['path', 'another path']} + to_validate = { + 'field1': VALUE1, + 'field2': VALUE2, + 'field3': ['path', 'another path'], + } with pytest.raises(ValueError) as exceptinfo: tc.validate_config(config=to_validate) - assert str(exceptinfo.value) == f'list entries for "field3" have type , expected ' + assert ( + str(exceptinfo.value) + == f'list entries for "field3" have type , expected ' + ) From 692a5019930f15c008984ea6bd2c650b2924ca1a Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 09:53:30 +0000 Subject: [PATCH 3/9] applied formatting --- ifsbench/config_mixin.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py index d74c675..e56d975 100644 --- a/ifsbench/config_mixin.py +++ b/ifsbench/config_mixin.py @@ -16,6 +16,20 @@ def _config_from_locals(config: dict[str, Any]) -> None: class ConfigMixin(ABC): + """ + Base class for handling configurations in a format that can be used for storage. + + The contents of the config are based on the parameters required by the implementing + classes constructor. Because of this, additional entries cannot be added to an existing config. + However, the values of individual entries can be updated with a value of the same type. + + The required format can be either created based on the constructor, or explicitly set by + implementing the `config_format` method. + + Parameters + ---- + config: dictionary containing parameter names and their values + """ _config = None From 7df9174689ac3530f3d47ca5b12eefe6f98695ef Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 11:40:01 +0000 Subject: [PATCH 4/9] add missing tests for config_mixin --- ifsbench/config_mixin.py | 2 -- tests/test_config_mixin.py | 44 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py index e56d975..9accdc2 100644 --- a/ifsbench/config_mixin.py +++ b/ifsbench/config_mixin.py @@ -41,9 +41,7 @@ def config_format(cls) -> dict[str, type | dict]: @classmethod def _format_from_init(cls) -> dict[str, type | dict]: format = dict(get_type_hints(cls.__init__, include_extras=False)) - print(f'format initial={format}, type={type(format)}') format = _config_from_locals(format) - print(f'format cleaned: {format}') return format def set_config_from_init_locals(self, config: dict[str, Any]): diff --git a/tests/test_config_mixin.py b/tests/test_config_mixin.py index 3853036..ac7bf3f 100644 --- a/tests/test_config_mixin.py +++ b/tests/test_config_mixin.py @@ -75,6 +75,16 @@ def test_set_config_succeeds(): assert config == expected +def test_set_config_already_set_fails(): + + tc = TestConfigSet(field1=VALUE1, field2=VALUE2, field3=VALUE3) + config = {'field1': VALUE1, 'field2': VALUE2, 'field3': VALUE3} + + with pytest.raises(ValueError) as exceptinfo: + tc.set_config(config) + assert str(exceptinfo.value) == f'Config already set.' + + def test_set_config_from_init_locals_succeeds(): tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) @@ -273,3 +283,37 @@ def test_validate_config_list_wrong_type_fails(): str(exceptinfo.value) == f'list entries for "field3" have type , expected ' ) + + +def test_update_config_succeeds(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + config = tc.get_config().copy() + + tc.update_config('field1', 5) + + out_conf = tc.get_config() + config['field1'] = 5 + + assert out_conf == config + + +def test_update_config_new_field_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + + with pytest.raises(ValueError) as exceptinfo: + tc.update_config('field4', 5) + assert ( + str(exceptinfo.value) + == f'field4 not part of config {tc.get_config()}, not setting' + ) + + +def test_update_config_wrong_type_fails(): + tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) + + with pytest.raises(ValueError) as exceptinfo: + tc.update_config('field1', 3.3) + assert ( + str(exceptinfo.value) + == f'Cannot update config: wrong type for field field1' + ) From 88d0a3dd9cff3063d2cb869d88bc60aa364d2d49 Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 12:05:42 +0000 Subject: [PATCH 5/9] Cleanup for pylint --- ifsbench/config_mixin.py | 32 +++++++++--------- ifsbench/data/namelisthandler.py | 2 +- tests/test_config_mixin.py | 58 +++++++------------------------- 3 files changed, 28 insertions(+), 64 deletions(-) diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py index 9accdc2..e4b7ff4 100644 --- a/ifsbench/config_mixin.py +++ b/ifsbench/config_mixin.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, get_args, get_origin, get_type_hints, Optional, TypeVar, Union +from typing import Any, get_args, get_origin, get_type_hints, Union __all__ = ['ConfigMixin', 'CONF'] @@ -40,9 +40,9 @@ def config_format(cls) -> dict[str, type | dict]: @classmethod def _format_from_init(cls) -> dict[str, type | dict]: - format = dict(get_type_hints(cls.__init__, include_extras=False)) - format = _config_from_locals(format) - return format + format_definition = dict(get_type_hints(cls.__init__, include_extras=False)) + format_definition = _config_from_locals(format_definition) + return format_definition def set_config_from_init_locals(self, config: dict[str, Any]): config = _config_from_locals(config) @@ -59,7 +59,7 @@ def get_config(self) -> dict[str, CONF]: def update_config(self, field: str, value: CONF) -> None: if field not in self._config: raise ValueError(f'{field} not part of config {self._config}, not setting') - if type(value) != type(self._config[field]): + if not isinstance(value, type(self._config[field])): raise ValueError( f'Cannot update config: wrong type {type(value)} for field {field}' ) @@ -67,27 +67,25 @@ def update_config(self, field: str, value: CONF) -> None: @classmethod def validate_config(cls, config: dict[str, CONF]): - format = cls.config_format() - cls._validate_config_from_format(config, format) + format_definition = cls.config_format() + cls._validate_config_from_format(config, format_definition) @classmethod def _validate_config_from_format( - cls, config: dict[str, CONF], format: dict[str, type | dict] + cls, config: dict[str, CONF], format_definition: dict[str, type | dict] ): - print(f'config: {config}') - print(f'format: {format}') for key, value in config.items(): if not isinstance(value, CONF): # check that the given value is a valid config type raise ValueError(f'Unsupported config value type for {value}') - if key not in format: - raise ValueError(f'unexpected key "{key}" in config, expected {format}') + if key not in format_definition: + raise ValueError(f'unexpected key "{key}" in config, expected {format_definition}') - for key, value in format.items(): + for key, value in format_definition.items(): if (key not in config) and (type(None) not in get_args(value)): - # format key has to be in config unless it's optional + # format_definition key has to be in config unless it's optional raise ValueError(f'"{key}" required but not in {config}') if isinstance(value, dict): # nested, check that field also nested in config, then recursively check dict. @@ -95,14 +93,14 @@ def _validate_config_from_format( raise ValueError( f'"{key}" has type {type(config[key])}, expected {value}' ) - cls._validate_config_from_format(config[key], format[key]) + cls._validate_config_from_format(config[key], format_definition[key]) elif isinstance(value, list): # For now, only check both are lists and first entry type is correct, don't check every entry. if not isinstance(config[key], list): raise ValueError( f'"{key}" has type {type(config[key])}, expected {value}' ) - if type(value[0]) != type(config[key][0]): + if not isinstance(value[0], type(config[key][0])): raise ValueError( f'list entries for "{key}" have type {type(config[key][0])}, expected {type(value[0])}' ) @@ -113,7 +111,7 @@ def _validate_config_from_format( raise ValueError( f'wrong type for optional {type(value)}: {config[key]}' ) - elif type(config[key]) != value: + elif not isinstance(config[key], value): # types of format and config have to match raise ValueError( f'"{key}" has type {type(config[key])}, expected {value}' diff --git a/ifsbench/data/namelisthandler.py b/ifsbench/data/namelisthandler.py index c674660..422b564 100644 --- a/ifsbench/data/namelisthandler.py +++ b/ifsbench/data/namelisthandler.py @@ -7,7 +7,7 @@ from enum import auto, StrEnum import pathlib -from typing import Optional, Self, Union +from typing import Self import f90nml diff --git a/tests/test_config_mixin.py b/tests/test_config_mixin.py index ac7bf3f..af4ed0c 100644 --- a/tests/test_config_mixin.py +++ b/tests/test_config_mixin.py @@ -1,12 +1,17 @@ -import pytest from typing import Optional +import pytest + from ifsbench import ConfigMixin class TestConfigFromLocals(ConfigMixin): def __init__(self, field1: int, field2: float, field3: str): self.set_config_from_init_locals(locals()) + # delete fields so pylint doesn't complain. + del field1 + del field2 + del field3 @classmethod def config_format(cls): @@ -39,6 +44,8 @@ def config_format(cls): class TestConfigOptional(ConfigMixin): def __init__(self, field1: int, field2: Optional[str] = None): self.set_config_from_init_locals(locals()) + del field1 + del field2 @classmethod def config_format(cls): @@ -82,7 +89,7 @@ def test_set_config_already_set_fails(): with pytest.raises(ValueError) as exceptinfo: tc.set_config(config) - assert str(exceptinfo.value) == f'Config already set.' + assert str(exceptinfo.value) == 'Config already set.' def test_set_config_from_init_locals_succeeds(): @@ -112,13 +119,6 @@ def test_set_config_from_init_optional_none_succeeds(): assert config == expected -def test_set_config_already_set_fails(): - - tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) - with pytest.raises(ValueError): - tc.set_config({'something': 'other'}) - - def test_update_config_succeeds(): tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) @@ -144,10 +144,10 @@ def test_update_config_wrong_type_fails(): tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) with pytest.raises(ValueError) as exceptinfo: - tc.update_config(field='field1', value='should be int') + tc.update_config('field1', 3.3) assert ( str(exceptinfo.value) - == 'Cannot update config: wrong type for field field1' + == 'Cannot update config: wrong type for field field1' ) @@ -281,39 +281,5 @@ def test_validate_config_list_wrong_type_fails(): tc.validate_config(config=to_validate) assert ( str(exceptinfo.value) - == f'list entries for "field3" have type , expected ' - ) - - -def test_update_config_succeeds(): - tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) - config = tc.get_config().copy() - - tc.update_config('field1', 5) - - out_conf = tc.get_config() - config['field1'] = 5 - - assert out_conf == config - - -def test_update_config_new_field_fails(): - tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) - - with pytest.raises(ValueError) as exceptinfo: - tc.update_config('field4', 5) - assert ( - str(exceptinfo.value) - == f'field4 not part of config {tc.get_config()}, not setting' - ) - - -def test_update_config_wrong_type_fails(): - tc = TestConfigFromLocals(field1=VALUE1, field2=VALUE2, field3=VALUE3) - - with pytest.raises(ValueError) as exceptinfo: - tc.update_config('field1', 3.3) - assert ( - str(exceptinfo.value) - == f'Cannot update config: wrong type for field field1' + == 'list entries for "field3" have type , expected ' ) From 4e3dd4cce0bc6d8adf2632ad6f269a730f076268 Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 12:19:31 +0000 Subject: [PATCH 6/9] Trying to fix type error on 3.8 --- ifsbench/config_mixin.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py index e4b7ff4..ed058f7 100644 --- a/ifsbench/config_mixin.py +++ b/ifsbench/config_mixin.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod -from typing import Any, get_args, get_origin, get_type_hints, Union +from typing import Any, Dict, get_args, get_origin, get_type_hints, List, Union __all__ = ['ConfigMixin', 'CONF'] -CONF = Union[int, float, str, bool, dict, list, None] +CONF = Union[int, float, str, bool, Dict, List, None] -def _config_from_locals(config: dict[str, Any]) -> None: +def _config_from_locals(config: Dict[str, Any]) -> None: print(f'from locals: config={config}, type={type(config)}') config = config.copy() config.pop('self', None) @@ -35,25 +35,25 @@ class ConfigMixin(ABC): @classmethod @abstractmethod - def config_format(cls) -> dict[str, type | dict]: + def config_format(cls) -> Dict[str, type | Dict]: raise NotImplementedError() @classmethod - def _format_from_init(cls) -> dict[str, type | dict]: + def _format_from_init(cls) -> Dict[str, type | Dict]: format_definition = dict(get_type_hints(cls.__init__, include_extras=False)) format_definition = _config_from_locals(format_definition) return format_definition - def set_config_from_init_locals(self, config: dict[str, Any]): + def set_config_from_init_locals(self, config: Dict[str, Any]): config = _config_from_locals(config) self.set_config(config) - def set_config(self, config: dict[str, CONF]) -> None: + def set_config(self, config: Dict[str, CONF]) -> None: if self._config: raise ValueError('Config already set.') self._config = config - def get_config(self) -> dict[str, CONF]: + def get_config(self) -> Dict[str, CONF]: return self._config def update_config(self, field: str, value: CONF) -> None: @@ -66,13 +66,13 @@ def update_config(self, field: str, value: CONF) -> None: self._config[field] = value @classmethod - def validate_config(cls, config: dict[str, CONF]): + def validate_config(cls, config: Dict[str, CONF]): format_definition = cls.config_format() cls._validate_config_from_format(config, format_definition) @classmethod def _validate_config_from_format( - cls, config: dict[str, CONF], format_definition: dict[str, type | dict] + cls, config: Dict[str, CONF], format_definition: Dict[str, type | Dict] ): for key, value in config.items(): @@ -87,14 +87,14 @@ def _validate_config_from_format( if (key not in config) and (type(None) not in get_args(value)): # format_definition key has to be in config unless it's optional raise ValueError(f'"{key}" required but not in {config}') - if isinstance(value, dict): + if isinstance(value, Dict): # nested, check that field also nested in config, then recursively check dict. - if not isinstance(config[key], dict): + if not isinstance(config[key], Dict): raise ValueError( f'"{key}" has type {type(config[key])}, expected {value}' ) cls._validate_config_from_format(config[key], format_definition[key]) - elif isinstance(value, list): + elif isinstance(value, List): # For now, only check both are lists and first entry type is correct, don't check every entry. if not isinstance(config[key], list): raise ValueError( From e292b088fd35e88f12e1fc4eb270abc63214f649 Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 13:40:48 +0000 Subject: [PATCH 7/9] Trying to fix type error on 3.8 --- ifsbench/config_mixin.py | 6 +++--- ifsbench/data/extracthandler.py | 4 ++-- ifsbench/data/namelisthandler.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py index ed058f7..881b6cd 100644 --- a/ifsbench/config_mixin.py +++ b/ifsbench/config_mixin.py @@ -35,11 +35,11 @@ class ConfigMixin(ABC): @classmethod @abstractmethod - def config_format(cls) -> Dict[str, type | Dict]: + def config_format(cls) -> Dict[str, Any]: raise NotImplementedError() @classmethod - def _format_from_init(cls) -> Dict[str, type | Dict]: + def _format_from_init(cls) -> Dict[str, Any]: format_definition = dict(get_type_hints(cls.__init__, include_extras=False)) format_definition = _config_from_locals(format_definition) return format_definition @@ -72,7 +72,7 @@ def validate_config(cls, config: Dict[str, CONF]): @classmethod def _validate_config_from_format( - cls, config: Dict[str, CONF], format_definition: Dict[str, type | Dict] + cls, config: Dict[str, CONF], format_definition: Dict[str, Any] ): for key, value in config.items(): diff --git a/ifsbench/data/extracthandler.py b/ifsbench/data/extracthandler.py index 6048d0e..b626625 100644 --- a/ifsbench/data/extracthandler.py +++ b/ifsbench/data/extracthandler.py @@ -7,7 +7,7 @@ import pathlib import shutil -from typing import Optional, Self +from typing import Any, Dict, Optional, Self from ifsbench.config_mixin import CONF, ConfigMixin from ifsbench.data.datahandler import DataHandler @@ -42,7 +42,7 @@ def __init__(self, archive_path: str, target_dir: Optional[str] = None): self._target_dir = pathlib.Path(target_dir) @classmethod - def config_format(cls): + def config_format(cls) -> Dict[str, Any]: return cls._format_from_init() @classmethod diff --git a/ifsbench/data/namelisthandler.py b/ifsbench/data/namelisthandler.py index 422b564..5b98a7a 100644 --- a/ifsbench/data/namelisthandler.py +++ b/ifsbench/data/namelisthandler.py @@ -7,7 +7,7 @@ from enum import auto, StrEnum import pathlib -from typing import Self +from typing import Any, Dict, Self import f90nml @@ -80,13 +80,13 @@ def from_keystring( return cls(keys[0], keys[1], mode, value) @classmethod - def from_config(cls, config: dict[str, CONF]): + def from_config(cls, config: dict[str, CONF]) -> Self: cls.validate_config(config) value = config['value'] if 'value' in config else None return cls(config['namelist'], config['entry'], config['mode'], value) @classmethod - def config_format(cls): + def config_format(cls) -> Dict[str, Any]: return cls._format_from_init() def apply(self, namelist): @@ -184,7 +184,7 @@ def __init__( raise ValueError("Namelist overrides must be NamelistOverride objects!") @classmethod - def config_format(cls) -> dict[str, type | dict]: + def config_format(cls) -> dict[str, Any]: return { 'input_path': str, 'output_path': str, From e584677c419519cd02f217d83af2e78210ac603b Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Mon, 27 Jan 2025 15:09:20 +0000 Subject: [PATCH 8/9] remove typing.Self which was only added in 3.11 --- ifsbench/data/extracthandler.py | 4 ++-- ifsbench/data/namelisthandler.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ifsbench/data/extracthandler.py b/ifsbench/data/extracthandler.py index b626625..fff812b 100644 --- a/ifsbench/data/extracthandler.py +++ b/ifsbench/data/extracthandler.py @@ -7,7 +7,7 @@ import pathlib import shutil -from typing import Any, Dict, Optional, Self +from typing import Any, Dict, Optional from ifsbench.config_mixin import CONF, ConfigMixin from ifsbench.data.datahandler import DataHandler @@ -46,7 +46,7 @@ def config_format(cls) -> Dict[str, Any]: return cls._format_from_init() @classmethod - def from_config(cls, config: dict[str, CONF]) -> Self: + def from_config(cls, config: dict[str, CONF]) -> 'ExtractHandler': cls.validate_config(config) archive_path = config['archive_path'] target_dir = config['target_dir'] if 'target_dir' in config else None diff --git a/ifsbench/data/namelisthandler.py b/ifsbench/data/namelisthandler.py index 5b98a7a..ec88af8 100644 --- a/ifsbench/data/namelisthandler.py +++ b/ifsbench/data/namelisthandler.py @@ -7,7 +7,7 @@ from enum import auto, StrEnum import pathlib -from typing import Any, Dict, Self +from typing import Any, Dict import f90nml @@ -63,7 +63,7 @@ def __init__( @classmethod def from_keytuple( cls, key: tuple[str, str], mode: NamelistOperation, value: CONF = None - ) -> Self: + ) -> 'NamelistOverride': if len(key) != 2: raise ValueError(f"The key tuple must be of length two, found key {key}.") return cls(key[0], key[1], mode, value) @@ -71,7 +71,7 @@ def from_keytuple( @classmethod def from_keystring( cls, key: str, mode: NamelistOperation, value: CONF = None - ) -> Self: + ) -> 'NamelistOverride': keys = key.split('/') if len(keys) != 2: raise ValueError( @@ -80,7 +80,7 @@ def from_keystring( return cls(keys[0], keys[1], mode, value) @classmethod - def from_config(cls, config: dict[str, CONF]) -> Self: + def from_config(cls, config: dict[str, CONF]) -> 'NamelistOverride': cls.validate_config(config) value = config['value'] if 'value' in config else None return cls(config['namelist'], config['entry'], config['mode'], value) @@ -196,7 +196,7 @@ def config_format(cls) -> dict[str, Any]: } @classmethod - def from_config(cls, config: dict[str, CONF]) -> Self: + def from_config(cls, config: dict[str, CONF]) -> 'NamelistHandler': cls.validate_config(config) input_path = config['input_path'] output_path = config['output_path'] From 74bb56b23674bf96b33355d1c8804b7ce4d6e517 Mon Sep 17 00:00:00 2001 From: Ulrike Hager Date: Tue, 28 Jan 2025 09:53:09 +0000 Subject: [PATCH 9/9] Rename _config to avoid potential collisions. Remove forgotten debug print --- ifsbench/config_mixin.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/ifsbench/config_mixin.py b/ifsbench/config_mixin.py index 881b6cd..cfd2471 100644 --- a/ifsbench/config_mixin.py +++ b/ifsbench/config_mixin.py @@ -8,7 +8,6 @@ def _config_from_locals(config: Dict[str, Any]) -> None: - print(f'from locals: config={config}, type={type(config)}') config = config.copy() config.pop('self', None) config.pop('cls', None) @@ -31,7 +30,7 @@ class ConfigMixin(ABC): config: dictionary containing parameter names and their values """ - _config = None + _mixin_config = None @classmethod @abstractmethod @@ -49,21 +48,21 @@ def set_config_from_init_locals(self, config: Dict[str, Any]): self.set_config(config) def set_config(self, config: Dict[str, CONF]) -> None: - if self._config: + if self._mixin_config: raise ValueError('Config already set.') - self._config = config + self._mixin_config = config def get_config(self) -> Dict[str, CONF]: - return self._config + return self._mixin_config def update_config(self, field: str, value: CONF) -> None: - if field not in self._config: - raise ValueError(f'{field} not part of config {self._config}, not setting') - if not isinstance(value, type(self._config[field])): + if field not in self._mixin_config: + raise ValueError(f'{field} not part of config {self._mixin_config}, not setting') + if not isinstance(value, type(self._mixin_config[field])): raise ValueError( f'Cannot update config: wrong type {type(value)} for field {field}' ) - self._config[field] = value + self._mixin_config[field] = value @classmethod def validate_config(cls, config: Dict[str, CONF]):