Skip to content

Commit

Permalink
simplify api, and make more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Jul 22, 2024
1 parent 08b59d9 commit 40bdac2
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 7 deletions.
1 change: 0 additions & 1 deletion docs/cli/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ datasets
========



.. argparse::
:module: anemoi.registry.__main__
:func: create_parser
Expand Down
21 changes: 21 additions & 0 deletions src/anemoi/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,24 @@ def config():
default_config = os.path.join(os.path.dirname(__file__), "config.yaml")
config = load_config(secrets=["api_token"], defaults=default_config)
return config.get("registry")


from .entry.dataset import DatasetCatalogueEntry as Dataset
from .entry.dataset import DatasetCatalogueEntryList as DatasetsList
from .entry.experiment import ExperimentCatalogueEntry as Experiment
from .entry.experiment import ExperimentCatalogueEntryList as ExperimentsList
from .entry.weights import WeightCatalogueEntry as Weights
from .entry.weights import WeightsCatalogueEntryList as WeightsList
from .tasks import TaskCatalogueEntry as Task
from .tasks import TaskCatalogueEntryList as TasksList

__all__ = [
"Weights",
"WeightsList",
"Experiment",
"ExperimentsList",
"Dataset",
"DatasetsList",
"Task",
"TasksList",
]
5 changes: 4 additions & 1 deletion src/anemoi/registry/commands/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def is_identifier(self, name_or_path):
except CatalogueEntryNotFound:
return False

def process_task(self, entry, args, k, func_name=None, /, **kwargs):
def process_task(self, entry, args, k, func_name=None, /, _skip_if_not_found=False, **kwargs):
"""
Call the method `k` on the entry object.
The args/kwargs given to the method are extracted from from the argument `k` in the `args` object.
Expand All @@ -46,6 +46,9 @@ def process_task(self, entry, args, k, func_name=None, /, **kwargs):
The provided **kwargs are also passed to the method.
The method name can be changed by providing the `func_name` argument.
"""
if entry is None and _skip_if_not_found:
LOG.warning(f"Cannot find entry {args.NAME_OR_PATH}. Skipping {k}.")
return

assert isinstance(k, str), k
if func_name is None:
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/registry/commands/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def is_path(self, name_or_path):
return True

def _run(self, entry, args):
self.process_task(entry, args, "delete_artefacts")
self.process_task(entry, args, "unregister")
self.process_task(entry, args, "delete_artefacts", _skip_if_not_found=True)
self.process_task(entry, args, "unregister", _skip_if_not_found=True)
self.process_task(entry, args, "register", overwrite=args.overwrite)
self.process_task(entry, args, "add_weights")
self.process_task(entry, args, "add_plots")
Expand Down
14 changes: 13 additions & 1 deletion src/anemoi/registry/entry/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,26 @@
from anemoi.utils.humanize import when

from anemoi.registry import config
from anemoi.registry.rest import RestItemList

from . import CatalogueEntry

LOG = logging.getLogger(__name__)

COLLECTION = "datasets"


class DatasetCatalogueEntryList(RestItemList):
def __init__(self, **kwargs):
super().__init__(COLLECTION, **kwargs)

def __iter__(self):
for v in self.get():
yield DatasetCatalogueEntry(key=v["name"])


class DatasetCatalogueEntry(CatalogueEntry):
collection = "datasets"
collection = COLLECTION
main_key = "name"

def set_status(self, status):
Expand Down
15 changes: 14 additions & 1 deletion src/anemoi/registry/entry/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,28 @@
from anemoi.utils.s3 import download
from anemoi.utils.s3 import upload

from anemoi.registry.rest import RestItemList

from .. import config
from . import CatalogueEntry
from .weights import WeightCatalogueEntry

COLLECTION = "experiments"

LOG = logging.getLogger(__name__)


class ExperimentCatalogueEntryList(RestItemList):
def __init__(self, **kwargs):
super().__init__(COLLECTION, **kwargs)

def __iter__(self):
for v in self.get():
yield ExperimentCatalogueEntry(key=v["expver"])


class ExperimentCatalogueEntry(CatalogueEntry):
collection = "experiments"
collection = COLLECTION
main_key = "expver"

def create_from_new_key(self, key):
Expand Down
15 changes: 14 additions & 1 deletion src/anemoi/registry/entry/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,27 @@
from anemoi.utils.checkpoints import load_metadata as load_checkpoint_metadata
from anemoi.utils.s3 import upload

from anemoi.registry.rest import RestItemList

from .. import config
from . import CatalogueEntry

COLLECTION = "weights"

LOG = logging.getLogger(__name__)


class WeightsCatalogueEntryList(RestItemList):
def __init__(self, **kwargs):
super().__init__(COLLECTION, **kwargs)

def __iter__(self):
for v in self.get():
yield WeightCatalogueEntry(key=v["uuid"])


class WeightCatalogueEntry(CatalogueEntry):
collection = "weights"
collection = COLLECTION
main_key = "uuid"

def add_location(self, platform, path):
Expand Down

0 comments on commit 40bdac2

Please sign in to comment.