From 40bdac2786f7a45eeaa9a36b7a501140037e79af Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 22 Jul 2024 14:51:22 +0000 Subject: [PATCH] simplify api, and make more robust --- docs/cli/datasets.rst | 1 - src/anemoi/registry/__init__.py | 21 +++++++++++++++++++++ src/anemoi/registry/commands/base.py | 5 ++++- src/anemoi/registry/commands/experiments.py | 4 ++-- src/anemoi/registry/entry/dataset.py | 14 +++++++++++++- src/anemoi/registry/entry/experiment.py | 15 ++++++++++++++- src/anemoi/registry/entry/weights.py | 15 ++++++++++++++- 7 files changed, 68 insertions(+), 7 deletions(-) diff --git a/docs/cli/datasets.rst b/docs/cli/datasets.rst index 7511501..f43dcfc 100644 --- a/docs/cli/datasets.rst +++ b/docs/cli/datasets.rst @@ -2,7 +2,6 @@ datasets ======== - .. argparse:: :module: anemoi.registry.__main__ :func: create_parser diff --git a/src/anemoi/registry/__init__.py b/src/anemoi/registry/__init__.py index b5c5815..763ed82 100644 --- a/src/anemoi/registry/__init__.py +++ b/src/anemoi/registry/__init__.py @@ -20,3 +20,24 @@ def config(): default_config = os.path.join(os.path.dirname(__file__), "config.yaml") config = load_config(secrets=["api_token"], defaults=default_config) return config.get("registry") + + +from .entry.dataset import DatasetCatalogueEntry as Dataset +from .entry.dataset import DatasetCatalogueEntryList as DatasetsList +from .entry.experiment import ExperimentCatalogueEntry as Experiment +from .entry.experiment import ExperimentCatalogueEntryList as ExperimentsList +from .entry.weights import WeightCatalogueEntry as Weights +from .entry.weights import WeightsCatalogueEntryList as WeightsList +from .tasks import TaskCatalogueEntry as Task +from .tasks import TaskCatalogueEntryList as TasksList + +__all__ = [ + "Weights", + "WeightsList", + "Experiment", + "ExperimentsList", + "Dataset", + "DatasetsList", + "Task", + "TasksList", +] diff --git a/src/anemoi/registry/commands/base.py b/src/anemoi/registry/commands/base.py index 4564907..0260d3e 100644 --- a/src/anemoi/registry/commands/base.py +++ b/src/anemoi/registry/commands/base.py @@ -35,7 +35,7 @@ def is_identifier(self, name_or_path): except CatalogueEntryNotFound: return False - def process_task(self, entry, args, k, func_name=None, /, **kwargs): + def process_task(self, entry, args, k, func_name=None, /, _skip_if_not_found=False, **kwargs): """ Call the method `k` on the entry object. The args/kwargs given to the method are extracted from from the argument `k` in the `args` object. @@ -46,6 +46,9 @@ def process_task(self, entry, args, k, func_name=None, /, **kwargs): The provided **kwargs are also passed to the method. The method name can be changed by providing the `func_name` argument. """ + if entry is None and _skip_if_not_found: + LOG.warning(f"Cannot find entry {args.NAME_OR_PATH}. Skipping {k}.") + return assert isinstance(k, str), k if func_name is None: diff --git a/src/anemoi/registry/commands/experiments.py b/src/anemoi/registry/commands/experiments.py index 73aa59e..5c5559f 100644 --- a/src/anemoi/registry/commands/experiments.py +++ b/src/anemoi/registry/commands/experiments.py @@ -86,8 +86,8 @@ def is_path(self, name_or_path): return True def _run(self, entry, args): - self.process_task(entry, args, "delete_artefacts") - self.process_task(entry, args, "unregister") + self.process_task(entry, args, "delete_artefacts", _skip_if_not_found=True) + self.process_task(entry, args, "unregister", _skip_if_not_found=True) self.process_task(entry, args, "register", overwrite=args.overwrite) self.process_task(entry, args, "add_weights") self.process_task(entry, args, "add_plots") diff --git a/src/anemoi/registry/entry/dataset.py b/src/anemoi/registry/entry/dataset.py index 0d929bc..6ca9d8c 100644 --- a/src/anemoi/registry/entry/dataset.py +++ b/src/anemoi/registry/entry/dataset.py @@ -14,14 +14,26 @@ from anemoi.utils.humanize import when from anemoi.registry import config +from anemoi.registry.rest import RestItemList from . import CatalogueEntry LOG = logging.getLogger(__name__) +COLLECTION = "datasets" + + +class DatasetCatalogueEntryList(RestItemList): + def __init__(self, **kwargs): + super().__init__(COLLECTION, **kwargs) + + def __iter__(self): + for v in self.get(): + yield DatasetCatalogueEntry(key=v["name"]) + class DatasetCatalogueEntry(CatalogueEntry): - collection = "datasets" + collection = COLLECTION main_key = "name" def set_status(self, status): diff --git a/src/anemoi/registry/entry/experiment.py b/src/anemoi/registry/entry/experiment.py index 9dfff14..7530343 100644 --- a/src/anemoi/registry/entry/experiment.py +++ b/src/anemoi/registry/entry/experiment.py @@ -15,15 +15,28 @@ from anemoi.utils.s3 import download from anemoi.utils.s3 import upload +from anemoi.registry.rest import RestItemList + from .. import config from . import CatalogueEntry from .weights import WeightCatalogueEntry +COLLECTION = "experiments" + LOG = logging.getLogger(__name__) +class ExperimentCatalogueEntryList(RestItemList): + def __init__(self, **kwargs): + super().__init__(COLLECTION, **kwargs) + + def __iter__(self): + for v in self.get(): + yield ExperimentCatalogueEntry(key=v["expver"]) + + class ExperimentCatalogueEntry(CatalogueEntry): - collection = "experiments" + collection = COLLECTION main_key = "expver" def create_from_new_key(self, key): diff --git a/src/anemoi/registry/entry/weights.py b/src/anemoi/registry/entry/weights.py index 70b975a..755b77e 100644 --- a/src/anemoi/registry/entry/weights.py +++ b/src/anemoi/registry/entry/weights.py @@ -11,14 +11,27 @@ from anemoi.utils.checkpoints import load_metadata as load_checkpoint_metadata from anemoi.utils.s3 import upload +from anemoi.registry.rest import RestItemList + from .. import config from . import CatalogueEntry +COLLECTION = "weights" + LOG = logging.getLogger(__name__) +class WeightsCatalogueEntryList(RestItemList): + def __init__(self, **kwargs): + super().__init__(COLLECTION, **kwargs) + + def __iter__(self): + for v in self.get(): + yield WeightCatalogueEntry(key=v["uuid"]) + + class WeightCatalogueEntry(CatalogueEntry): - collection = "weights" + collection = COLLECTION main_key = "uuid" def add_location(self, platform, path):