diff --git a/pyproject.toml b/pyproject.toml index 17d4c55..0d604ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,8 @@ dependencies = [ "requests", "pytest >=8.0.1", "ruamel.yaml >=0.18.5", - "jsonschema >=4.21.1" + "jsonschema >=4.21.1", + "payu >=1.1.3" ] [project.optional-dependencies] diff --git a/src/model_config_tests/models/__init__.py b/src/model_config_tests/models/__init__.py index efb7367..c704378 100644 --- a/src/model_config_tests/models/__init__.py +++ b/src/model_config_tests/models/__init__.py @@ -1,3 +1,4 @@ from model_config_tests.models.accessom2 import AccessOm2 +from model_config_tests.models.accessom3 import AccessOm3 -index = {"access-om2": AccessOm2} +index = {"access-om2": AccessOm2, "access-om3": AccessOm3} diff --git a/src/model_config_tests/models/accessom2.py b/src/model_config_tests/models/accessom2.py index 0a98800..a1cc48a 100644 --- a/src/model_config_tests/models/accessom2.py +++ b/src/model_config_tests/models/accessom2.py @@ -7,13 +7,7 @@ import f90nml -from model_config_tests.models.model import Model - -BASE_SCHEMA_URL = "https://raw.githubusercontent.com/ACCESS-NRI/schema/main/au.org.access-nri/model/access-om2/experiment/reproducibility/checksums" - -SCHEMA_VERSION_1_0_0 = "1-0-0" -DEFAULT_SCHEMA_VERSION = SCHEMA_VERSION_1_0_0 -SUPPORTED_SCHEMA_VERSIONS = [SCHEMA_VERSION_1_0_0] +from model_config_tests.models.model import SCHEMA_VERSION_1_0_0, Model class AccessOm2(Model): @@ -23,7 +17,6 @@ def __init__(self, experiment): self.accessom2_config = experiment.control_path / "accessom2.nml" self.ocean_config = experiment.control_path / "ocean" / "input.nml" - self.default_schema_version = DEFAULT_SCHEMA_VERSION def set_model_runtime(self, years: int = 0, months: int = 0, seconds: int = 10800): """Set config files to a short time period for experiment run. @@ -31,6 +24,13 @@ def set_model_runtime(self, years: int = 0, months: int = 0, seconds: int = 1080 with open(self.accessom2_config) as f: nml = f90nml.read(f) + # Check that two of years, months, seconds is zero + if sum(x == 0 for x in (years, months, seconds)) != 2: + raise NotImplementedError( + "Cannot specify runtime in seconds and years and months" + + " at the same time. Two of which must be zero" + ) + nml["date_manager_nml"]["restart_period"] = [years, months, seconds] nml.write(self.accessom2_config, force=True) @@ -75,7 +75,7 @@ def extract_checksums( output_checksums[field].append(checksum) if schema_version is None: - schema_version = DEFAULT_SCHEMA_VERSION + schema_version = self.default_schema_version if schema_version == SCHEMA_VERSION_1_0_0: checksums = { @@ -88,30 +88,3 @@ def extract_checksums( ) return checksums - - def check_checksums_over_restarts( - self, - long_run_checksum: dict[str, Any], - short_run_checksum_0: dict[str, Any], - short_run_checksum_1: dict[str, Any], - ) -> bool: - """Compare a checksums from a long run (e.g. 2 days) against - checksums from 2 short runs (e.g. 1 day)""" - short_run_checksums = short_run_checksum_0["output"] - for field, checksums in short_run_checksum_1["output"].items(): - if field not in short_run_checksums: - short_run_checksums[field] = checksums - else: - short_run_checksums[field].extend(checksums) - - matching_checksums = True - for field, checksums in long_run_checksum["output"].items(): - for checksum in checksums: - if ( - field not in short_run_checksums - or checksum not in short_run_checksums[field] - ): - print(f"Unequal checksum: {field}: {checksum}") - matching_checksums = False - - return matching_checksums diff --git a/src/model_config_tests/models/accessom3.py b/src/model_config_tests/models/accessom3.py new file mode 100644 index 0000000..9a5df35 --- /dev/null +++ b/src/model_config_tests/models/accessom3.py @@ -0,0 +1,90 @@ +"""Specific Access-OM3 Model setup and post-processing""" + +import re +from collections import defaultdict +from pathlib import Path +from typing import Any + +from payu.models.cesm_cmeps import Runconfig + +from model_config_tests.models.model import SCHEMA_VERSION_1_0_0, Model + + +class AccessOm3(Model): + def __init__(self, experiment): + super().__init__(experiment) + self.output_file = self.experiment.output000 / "ocean.stats" + + self.runconfig = experiment.control_path / "nuopc.runconfig" + self.ocean_config = experiment.control_path / "input.nml" + + def set_model_runtime(self, years: int = 0, months: int = 0, seconds: int = 10800): + """Set config files to a short time period for experiment run. + Default is 3 hours""" + runconfig = Runconfig(self.runconfig) + + if years == months == 0: + freq = "nseconds" + n = str(seconds) + elif seconds == 0: + freq = "nmonths" + n = str(12 * years + months) + else: + raise NotImplementedError( + "Cannot specify runtime in seconds and year/months at the same time" + ) + + runconfig.set("CLOCK_attributes", "restart_n", n) + runconfig.set("CLOCK_attributes", "restart_option", freq) + runconfig.set("CLOCK_attributes", "stop_n", n) + runconfig.set("CLOCK_attributes", "stop_option", freq) + + runconfig.write() + + def output_exists(self) -> bool: + """Check for existing output file""" + return self.output_file.exists() + + def extract_checksums( + self, output_directory: Path = None, schema_version: str = None + ) -> dict[str, Any]: + """Parse output file and create checksum using defined schema""" + if output_directory: + output_filename = output_directory / "ocean.stats" + else: + output_filename = self.output_file + + # ocean.stats is used for regression testing in MOM6's own test suite + # See https://github.com/mom-ocean/MOM6/blob/2ab885eddfc47fc0c8c0bae46bc61531104428d5/.testing/Makefile#L495-L501 + # Rows in ocean.stats look like: + # 0, 693135.000, 0, En 3.0745627134675957E-23, CFL 0.00000, ... + # where the first three columns are Step, Day, Truncs and the remaining + # columns include a label for what they are (e.g. En = Energy/Mass) + # Header info is only included for new runs so can't be relied on + output_checksums: dict[str, list[any]] = defaultdict(list) + + with open(output_filename) as f: + lines = f.readlines() + # Skip header if it exists (for new runs) + istart = 2 if "Step" in lines[0] else 0 + for line in lines[istart:]: + for col in line.split(","): + # Only keep columns with labels (ie not Step, Day, Truncs) + col = re.split(" +", col.strip().rstrip("\n")) + if len(col) > 1: + output_checksums[col[0]].append(col[-1]) + + if schema_version is None: + schema_version = self.default_schema_version + + if schema_version == SCHEMA_VERSION_1_0_0: + checksums = { + "schema_version": schema_version, + "output": dict(output_checksums), + } + else: + raise NotImplementedError( + f"Unsupported checksum schema version: {schema_version}" + ) + + return checksums diff --git a/src/model_config_tests/models/model.py b/src/model_config_tests/models/model.py index ee5bedd..05f818f 100644 --- a/src/model_config_tests/models/model.py +++ b/src/model_config_tests/models/model.py @@ -2,11 +2,19 @@ from pathlib import Path +SCHEMA_VERSION_1_0_0 = "1-0-0" +SCHEMA_1_0_0_URL = "https://raw.githubusercontent.com/ACCESS-NRI/schema/7666d95967de4dfd19b0d271f167fdcfd3f46962/au.org.access-nri/model/reproducibility/checksums/1-0-0.json" +SCHEMA_VERSION_TO_URL = {SCHEMA_VERSION_1_0_0: SCHEMA_1_0_0_URL} +DEFAULT_SCHEMA_VERSION = "1-0-0" + class Model: def __init__(self, experiment): self.experiment = experiment + self.default_schema_version = DEFAULT_SCHEMA_VERSION + self.schema_version_to_url = SCHEMA_VERSION_TO_URL + def extract_checksums(self, output_directory: Path, schema_version: str): """Extract checksums from output directory""" raise NotImplementedError @@ -24,4 +32,21 @@ def check_checksums_over_restarts( ) -> bool: """Compare a checksums from a long run (e.g. 2 days) against checksums from 2 short runs (e.g. 1 day)""" - raise NotImplementedError + short_run_checksums = short_run_checksum_0["output"] + for field, checksums in short_run_checksum_1["output"].items(): + if field not in short_run_checksums: + short_run_checksums[field] = checksums + else: + short_run_checksums[field].extend(checksums) + + matching_checksums = True + for field, checksums in long_run_checksum["output"].items(): + for checksum in checksums: + if ( + field not in short_run_checksums + or checksum not in short_run_checksums[field] + ): + print(f"Unequal checksum: {field}: {checksum}") + matching_checksums = False + + return matching_checksums diff --git a/tests/resources/access-om2-checksums-1-0-0.json b/tests/resources/access-om2/checksums/1-0-0.json similarity index 100% rename from tests/resources/access-om2-checksums-1-0-0.json rename to tests/resources/access-om2/checksums/1-0-0.json diff --git a/tests/resources/access-om2.out b/tests/resources/access-om2/output000/access-om2.out similarity index 100% rename from tests/resources/access-om2.out rename to tests/resources/access-om2/output000/access-om2.out diff --git a/tests/resources/access-om3/checksums/1-0-0.json b/tests/resources/access-om3/checksums/1-0-0.json new file mode 100644 index 0000000..61eeb82 --- /dev/null +++ b/tests/resources/access-om3/checksums/1-0-0.json @@ -0,0 +1,32 @@ +{ + "schema_version": "1-0-0", + "output": { + "En": [ + "3.0745627134675957E-23" + ], + "CFL": [ + "0.00000" + ], + "SL": [ + "1.5112E-10" + ], + "M": [ + "1.36404E+21" + ], + "S": [ + "34.7263" + ], + "T": [ + "3.6362" + ], + "Me": [ + "0.00E+00" + ], + "Se": [ + "0.00E+00" + ], + "Te": [ + "0.00E+00" + ] + } +} diff --git a/tests/resources/access-om3/output000/ocean.stats b/tests/resources/access-om3/output000/ocean.stats new file mode 100644 index 0000000..ee9e004 --- /dev/null +++ b/tests/resources/access-om3/output000/ocean.stats @@ -0,0 +1,3 @@ + Step, Day, Truncs, Energy/Mass, Maximum CFL, Mean Sea Level, Total Mass, Mean Salin, Mean Temp, Frac Mass Err, Salin Err, Temp Err + [days] [m2 s-2] [Nondim] [m] [kg] [PSU] [degC] [Nondim] [PSU] [degC] + 0, 693135.000, 0, En 3.0745627134675957E-23, CFL 0.00000, SL 1.5112E-10, M 1.36404E+21, S 34.7263, T 3.6362, Me 0.00E+00, Se 0.00E+00, Te 0.00E+00 diff --git a/tests/test_access_om2_extract_checksums.py b/tests/test_access_om2_extract_checksums.py deleted file mode 100644 index 1f7088c..0000000 --- a/tests/test_access_om2_extract_checksums.py +++ /dev/null @@ -1,43 +0,0 @@ -import json -from pathlib import Path -from unittest.mock import Mock - -import pytest -import requests - -from model_config_tests.models import AccessOm2 -from model_config_tests.models.accessom2 import SUPPORTED_SCHEMA_VERSIONS - - -@pytest.mark.parametrize("version", SUPPORTED_SCHEMA_VERSIONS) -def test_extract_checksums(version): - # Mock ExpTestHelper - mock_experiment = Mock() - mock_experiment.output000 = Path("tests/resources") - mock_experiment.control_path = Path("tests/tmp") - - model = AccessOm2(mock_experiment) - - checksums = model.extract_checksums(schema_version=version) - - # Assert version is set as expected - assert checksums["schema_version"] == version - - # Check the entire checksum file is expected - with open("tests/resources/access-om2-checksums-1-0-0.json") as file: - expected_checksums = json.load(file) - - assert checksums == expected_checksums - - # Validate checksum file with schema - # schema = get_schema_from_url(expected_checksums["schema"]) - - # Validate checksums against schema - # jsonschema.validate(instance=checksums, schema=schema) - - -def get_schema_from_url(url): - """Retrieve schema from github""" - response = requests.get(url) - assert response.status_code == 200 - return response.json() diff --git a/tests/test_model_extract_checksums.py b/tests/test_model_extract_checksums.py new file mode 100644 index 0000000..167e800 --- /dev/null +++ b/tests/test_model_extract_checksums.py @@ -0,0 +1,73 @@ +import json +import os +from pathlib import Path +from unittest.mock import Mock + +import jsonschema +import pytest +import requests + +from model_config_tests.models import index as model_index + +MODEL_NAMES = model_index.keys() +HERE = os.path.dirname(__file__) +RESOURCES_DIR = Path(f"{HERE}/resources") + + +@pytest.mark.parametrize("model_name", MODEL_NAMES) +def test_extract_checksums(model_name): + resources_dir = RESOURCES_DIR / model_name + + # Mock ExpTestHelper + mock_experiment = Mock() + mock_experiment.output000 = resources_dir / "output000" + mock_experiment.control_path = Path("test/tmp") + + # Create Model instance + ModelType = model_index[model_name] + model = ModelType(mock_experiment) + + # Test extract checksums for each schema version + for version, url in model.schema_version_to_url.items(): + checksums = model.extract_checksums(schema_version=version) + + # Assert version is set as expected + assert checksums["schema_version"] == version + + # Check the entire checksum file is expected + checksum_file = resources_dir / "checksums" / f"{version}.json" + with open(checksum_file) as file: + expected_checksums = json.load(file) + + assert checksums == expected_checksums + + # Validate checksum file with schema + schema = get_schema_from_url(url) + + # Validate checksums against schema + jsonschema.validate(instance=checksums, schema=schema) + + +@pytest.mark.parametrize("model_name", MODEL_NAMES) +def test_extract_checksums_unsupported_version(model_name): + resources_dir = RESOURCES_DIR / model_name + + # Mock ExpTestHelper + mock_experiment = Mock() + mock_experiment.output000 = resources_dir / "output000" + mock_experiment.control_path = Path("test/tmp") + + # Create Model instance + ModelType = model_index[model_name] + model = ModelType(mock_experiment) + + # Test NotImplementedError gets raised for unsupported versions + with pytest.raises(NotImplementedError): + model.extract_checksums(schema_version="test-version") + + +def get_schema_from_url(url): + """Retrieve schema from GitHub""" + response = requests.get(url) + assert response.status_code == 200 + return response.json()