From 9e454e07557d91df49d05ef915f0d2078ad11a00 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 6 Feb 2025 12:32:03 +0000 Subject: [PATCH] feat: allow using json files to define datasets --- src/anemoi/datasets/data/misc.py | 32 ++++++++++++++++-- src/anemoi/datasets/data/stores.py | 52 ++++++++++++++++++------------ tests/test_data.py | 5 +++ 3 files changed, 66 insertions(+), 23 deletions(-) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index a82d6bfa..844cd60b 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -10,7 +10,9 @@ import calendar import datetime +import json import logging +import os from pathlib import PurePath import numpy as np @@ -195,7 +197,6 @@ def _concat_or_join(datasets, kwargs): def _open(a): from .stores import Zarr - from .stores import zarr_lookup if isinstance(a, Dataset): return a.mutate() @@ -204,7 +205,34 @@ def _open(a): return Zarr(a).mutate() if isinstance(a, str): - return Zarr(zarr_lookup(a)).mutate() + from .stores import DATASET_FINDER + + tried = [] + for name in DATASET_FINDER.ls(a): + tried.append(name) + + if name.endswith(".json"): + DATASET_FINDER.log_open(a, name) + if not os.path.exists(name): + continue + + obj = json.load(open(name)) + if isinstance(obj, dict): + return _open_dataset(**obj).mutate() + elif isinstance(obj, (list, tuple)): + return _open_dataset(*obj).mutate() + raise ValueError(f"Invalid content: {type(obj)} in {name}") + + if name.endswith(".zarr") or name.endswith(".zip"): + try: + DATASET_FINDER.log_open(a, name) + return Zarr(name).mutate() + except zarr.errors.PathNotFoundError: + pass + + raise ValueError(f"Unsupported file: {name}") + + raise ValueError(f"Cannot find a dataset that matched '{a}'. Tried: {tried}") if isinstance(a, PurePath): return _open(str(a)).mutate() diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index c8340e6c..772e5a3b 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -16,6 +16,7 @@ import numpy as np import zarr +from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta from . import MissingDateError @@ -289,8 +290,6 @@ def statistics_tendencies(self, delta=None): delta = self.frequency if isinstance(delta, int): delta = f"{delta}h" - from anemoi.utils.dates import frequency_to_string - from anemoi.utils.dates import frequency_to_timedelta delta = frequency_to_timedelta(delta) delta = frequency_to_string(delta) @@ -450,36 +449,47 @@ def label(self): return "zarr*" -QUIET = set() +class DatasetFinder: + QUIET = set() + @cached_property + def _config(self): + return load_config()["datasets"] -def zarr_lookup(name, fail=True): + def ls(self, name): + if name in self._config["named"]: + yield self._config["named"][name] + return + + if name.endswith(".zip") or name.endswith(".zarr") or name.endswith(".json"): + yield name + return - if name.endswith(".zarr") or name.endswith(".zip"): - return name + for location in self._config["path"]: + if not location.endswith("/"): + location += "/" - config = load_config()["datasets"] + yield location + name + ".json" + yield location + name + ".zarr" - if name in config["named"]: - if name not in QUIET: - LOG.info("Opening `%s` as `%s`", name, config["named"][name]) - QUIET.add(name) - return config["named"][name] + def log_open(self, name, full): + if name not in self.QUIET: + LOG.info("Opening `%s` as `%s`", name, full) + self.QUIET.add(name) + + +DATASET_FINDER = DatasetFinder() + + +def zarr_lookup(name, fail=True): tried = [] - for location in config["path"]: - if not location.endswith("/"): - location += "/" - full = location + name + ".zarr" + for full in DATASET_FINDER.ls(name): tried.append(full) try: + DATASET_FINDER.log_open(name, full) z = open_zarr(full, dont_fail=True) if z is not None: - # Cache for next time - config["named"][name] = full - if name not in QUIET: - LOG.info("Opening `%s` as `%s`", name, full) - QUIET.add(name) return full except zarr.errors.PathNotFoundError: pass diff --git a/tests/test_data.py b/tests/test_data.py index 2ab56ed0..9ee24c40 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -9,6 +9,7 @@ import datetime +import os from functools import cache from functools import wraps from unittest.mock import patch @@ -156,6 +157,10 @@ def create_zarr( def zarr_from_str(name, mode): # Format: test-2021-2021-6h-o96-abcd-0 + if name.endswith(".zarr"): + name = os.path.basename(name) + name = os.path.splitext(name)[0] + args = dict( test="test", start=2021,