From 260e5f0702aa099054c0a9ef725107b06ac90324 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 2 Jul 2024 15:48:30 +0200 Subject: [PATCH] refactor --- src/anemoi/registry/commands/base.py | 53 +++++---- src/anemoi/registry/commands/datasets.py | 73 ++---------- src/anemoi/registry/commands/experiments.py | 67 +---------- src/anemoi/registry/commands/queues.py | 121 ++++++++++++++++++++ src/anemoi/registry/commands/weights.py | 59 +--------- src/anemoi/registry/entry/__init__.py | 20 ++-- src/anemoi/registry/entry/dataset.py | 5 +- src/anemoi/registry/entry/experiment.py | 6 +- src/anemoi/registry/entry/weights.py | 5 +- 9 files changed, 196 insertions(+), 213 deletions(-) create mode 100644 src/anemoi/registry/commands/queues.py diff --git a/src/anemoi/registry/commands/base.py b/src/anemoi/registry/commands/base.py index f2ce94a..bd756bc 100644 --- a/src/anemoi/registry/commands/base.py +++ b/src/anemoi/registry/commands/base.py @@ -14,7 +14,6 @@ import logging import os -import sys from ..entry import CatalogueEntryNotFound from . import Command @@ -39,35 +38,45 @@ def is_identifier(self, name_or_path): except CatalogueEntryNotFound: return False + def process_task(self, entry, args, k, **kwargs): + assert isinstance(k, str), k + + v = getattr(args, k) + + if v is None: + return + if v is True: + LOG.info(f"{entry.key} : Processing task {k}") + return getattr(entry, k)(**kwargs) + if v is False: + return + if isinstance(v, (str, int, float)): + LOG.info(f"{entry.key} : Processing task {k} with {v}") + return getattr(entry, k)(v, **kwargs) + if isinstance(v, list): + v_str = ", ".join(str(x) for x in v) + LOG.info(f"{entry.key} : Processing task {k} with {v_str}") + return getattr(entry, k)(*v, **kwargs) + if isinstance(v, dict): + v_str = ", ".join(f"{k_}={v_}" for k_, v_ in v.items()) + LOG.info(f"{entry.key} : Processing task {k} with {v_str}") + return getattr(entry, k)(**v, **kwargs) + raise ValueError(f"Invalid task {k}={v}. type(v)= {type(v)}") + def run(self, args): - args = vars(args) LOG.debug("anemoi-registry args:", args) - if "command" in args: - args.pop("command") - name_or_path = args.pop("NAME_OR_PATH") - - if args.get("add_location"): - args["add_location"] = self.parse_location(args["add_location"]) - if args.get("remove_location"): - args["remove_location"] = self.parse_location(args["remove_location"]) + name_or_path = args.NAME_OR_PATH + entry = self.get_entry(name_or_path) + self._run(entry, args) + def get_entry(self, name_or_path): if self.is_path(name_or_path): LOG.info(f"Found local {self.kind} at {name_or_path}") - self.run_from_path(name_or_path, **args) - return + return self.entry_class(path=name_or_path) if self.is_identifier(name_or_path): LOG.info(f"Processing {self.kind} with identifier '{name_or_path}'") - self.run_from_identifier(name_or_path, **args) - return - LOG.error(f"Cannot find any {self.kind} from '{name_or_path}'") - sys.exit(1) - - def parse_location(self, location): - for x in location: - if "=" not in x: - raise ValueError(f"Invalid location format '{x}', use 'key1=value1 key2=value2' list.") - return {x.split("=")[0]: x.split("=")[1] for x in location} + return self.entry_class(key=name_or_path) def warn_unused_arguments(self, kwargs): for k, v in kwargs.items(): diff --git a/src/anemoi/registry/commands/datasets.py b/src/anemoi/registry/commands/datasets.py index 49966af..8ea5507 100644 --- a/src/anemoi/registry/commands/datasets.py +++ b/src/anemoi/registry/commands/datasets.py @@ -44,70 +44,15 @@ def add_arguments(self, command_parser): def check_arguments(self, args): pass - def run_from_identifier( - self, - identifier, - add_location, - add_recipe, - set_status, - unregister, - json, - remove_location=False, - **kwargs, - ): - self.warn_unused_arguments(kwargs) - - entry = self.entry_class(key=identifier) - - if unregister: - entry.unregister() - if add_location: - entry.add_location(**add_location) - if remove_location: - entry.remove_location(**remove_location) - if set_status: - entry.set_status(set_status) - if add_recipe: - entry.add_recipe(add_recipe) - - if json: - print(entry.as_json()) - - def run_from_path( - self, - path, - register, - unregister, - add_location, - add_recipe, - json, - set_status, - # remove_location, - # upload, - # upload_uri_pattern, - **kwargs, - ): - self.warn_unused_arguments(kwargs) - - entry = self.entry_class(path=path) - - if register: - entry.register() - if unregister: - entry.unregister() - if add_location: - entry.add_location(**add_location) - # if remove_location: - # entry.remove_location(**remove_location) - if set_status: - entry.set_status(set_status) - # if delete: - # entry.delete() - if add_recipe: - entry.add_recipe(add_recipe) - - if json: - print(entry.as_json()) + def _run(self, entry, args): + # order matters + self.process_task(entry, args, "unregister") + self.process_task(entry, args, "register") + # self.process_task(entry, args, "remove_location") + self.process_task(entry, args, "add_location") + self.process_task(entry, args, "add_recipe") + self.process_task(entry, args, "set_status") + self.process_task(entry, args, "json") command = Datasets diff --git a/src/anemoi/registry/commands/experiments.py b/src/anemoi/registry/commands/experiments.py index bcef31c..837882e 100644 --- a/src/anemoi/registry/commands/experiments.py +++ b/src/anemoi/registry/commands/experiments.py @@ -52,67 +52,12 @@ def is_path(self, name_or_path): return False return True - def run_from_identifier( - self, - identifier, - json, - add_weights, - add_plots, - unregister, - overwrite, - **kwargs, - ): - self.warn_unused_arguments(kwargs) - - entry = self.entry_class(key=identifier) - - if add_weights: - for w in add_weights: - entry.add_weights(w) - if add_plots: - for p in add_plots: - entry.add_plots(p) - - if unregister: - entry.unregister() - - # if delete: - # entry.delete() - - if json: - print(entry.as_json()) - - def run_from_path( - self, - path, - register, - unregister, - add_weights, - add_plots, - overwrite, - json, - **kwargs, - ): - self.warn_unused_arguments(kwargs) - - entry = self.entry_class(path=path) - - if unregister: - entry.unregister() - if register: - entry.register() - if add_weights: - for w in add_weights: - entry.add_weights(w) - if add_plots: - for p in add_plots: - entry.add_plots(p) - - # if delete: - # entry.delete() - - if json: - print(entry.as_json()) + def _run(self, entry, args): + self.process_task(entry, args, "unregister") + self.process_task(entry, args, "register", overwrite=args.overwrite) + self.process_task(entry, args, "add_weights") + self.process_task(entry, args, "add_plots") + self.process_task(entry, args, "json") command = Experiments diff --git a/src/anemoi/registry/commands/queues.py b/src/anemoi/registry/commands/queues.py new file mode 100644 index 0000000..0040d83 --- /dev/null +++ b/src/anemoi/registry/commands/queues.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +"""Command place holder. Delete when we have real commands. + +""" + +import datetime +import logging + +from anemoi.utils.humanize import when +from anemoi.utils.text import table + +from ..queue_manager import add +from ..queue_manager import disown +from ..queue_manager import get_list +from ..queue_manager import own +from ..queue_manager import remove +from ..queue_manager import set_progress +from ..queue_manager import set_status +from . import Command + +LOG = logging.getLogger(__name__) + + +class Queues(Command): + internal = True + timestamp = True + + def add_arguments(self, command_parser): + command_parser.add_argument("--new", help="add to queue (key=value list)", nargs="+", metavar="K=V") + command_parser.add_argument("--remove", help="remove from queue", nargs=1) + command_parser.add_argument("--set-status", help="--set-status uuid ", nargs=2) + command_parser.add_argument("--set-progress", help="--set-progress uuid ", nargs=2) + + command_parser.add_argument("--own", help="Take ownership of the oldest entry with status=queued", nargs="*") + command_parser.add_argument("--disown", help="Release a task and requeue it", metavar="UUID") + + command_parser.add_argument("--sort", help="Sort by date", choices=["created", "updated"], default="updated") + + command_parser.add_argument("--list", help="List some queue entries", nargs="*") + command_parser.add_argument("-l", "--long", help="Details", action="store_true") + + def run(self, args): + if args.list is not None: + request = {v.split("=")[0]: v.split("=")[1] for v in args.list} + return self.list(request=args.list, long=args.long, sort=args.sort) + + if args.disown: + disown(args.disown) + + if args.own is not None: + request = {v.split("=")[0]: v.split("=")[1] for v in args.own} + self.own(request, sort=args.sort) + + if args.remove: + res = remove(args.remove[0]) + print(res) + return + + if args.new: + record = {v.split("=")[0]: v.split("=")[1] for v in args.new} + res = add(record) + print(res) + + if args.set_status: + uuid, status = args.set_status + set_status(uuid, status) + + if args.set_progress: + uuid, progress = args.set_progress + set_progress(uuid, int(progress)) + + def list(self, request, long=False, sort="updated"): + res = get_list(request) + res = sorted(res, key=lambda x: x[sort]) + + rows = [] + for v in res: + if not isinstance(v, dict): + raise ValueError(v) + created = datetime.datetime.fromisoformat(v.pop("created")) + updated = datetime.datetime.fromisoformat(v.pop("updated")) + + uuid = v.pop("uuid") + content = " ".join(f"{k}={v}" for k, v in v.items()) + if not long: + content = content[:20] + "..." + uuid = uuid[:5] + "..." + rows.append( + [ + when(created), + when(updated), + v.pop("status"), + v.pop("progress", ""), + content, + uuid, + ] + ) + print(table(rows, ["Created", "Updated", "Status", "%", "Details", "UUID"], ["<", "<", "<", "<", "<", "<"])) + return + + def own(request, sort): + if not request: + request = {"status": "queued"} + res = get_list(request) + res = sorted(res, key=lambda x: x[sort]) + uuids = [v["uuid"] for v in res] + latest = uuids.pop() + + own(latest) + + +command = Queues diff --git a/src/anemoi/registry/commands/weights.py b/src/anemoi/registry/commands/weights.py index 0e7473a..8026365 100644 --- a/src/anemoi/registry/commands/weights.py +++ b/src/anemoi/registry/commands/weights.py @@ -55,60 +55,11 @@ def warn_unused_arguments(self, kwargs): if v: LOG.info(f"Ignoring argument {k}={v}") - def run_from_identifier( - self, - identifier, - add_location, - json, - unregister, - remove_location=False, - **kwargs, - ): - self.warn_unused_arguments(kwargs) - - entry = self.entry_class(key=identifier) - - if add_location: - entry.add_location(**add_location) - if remove_location: - entry.remove_location(**remove_location) - if unregister: - entry.unregister() - - if json: - print(entry.as_json()) - - def run_from_path( - self, - path, - unregister, - register, - add_location, - overwrite, - json, - remove_location=False, - **kwargs, - ): - self.warn_unused_arguments(kwargs) - - entry = self.entry_class(path=path) - - if unregister: - entry.unregister() - if register: - entry.register(overwrite=overwrite) - # if upload: - # entry.upload(upload_uri_pattern, **upload) - - if add_location: - entry.add_location(**add_location) - # if remove_location: - # entry.remove_location(**remove_location) - # if delete: - # entry.delete() - - if json: - print(entry.as_json()) + def _run(self, entry, args): + self.process_task(entry, args, "unregister") + self.process_task(entry, args, "register", overwrite=args.overwrite) + self.process_task(entry, args, "add_location") + self.process_task(entry, args, "json") command = Weights diff --git a/src/anemoi/registry/entry/__init__.py b/src/anemoi/registry/entry/__init__.py index d187f4e..44a3b82 100644 --- a/src/anemoi/registry/entry/__init__.py +++ b/src/anemoi/registry/entry/__init__.py @@ -50,6 +50,12 @@ def __init__(self, key=None, path=None): def as_json(self): return json_pretty_dump(self.record) + def list_to_dict(cls, lst): + for x in lst: + if "=" not in x: + raise ValueError(f"Invalid location format '{x}', use 'key1=value1 key2=value2' list.") + return {x.split("=")[0]: x.split("=")[1] for x in lst} + @classmethod def key_exists(cls, key): return RestItem(cls.collection, key).exists() @@ -66,23 +72,21 @@ def load_from_key(self, key): def main_key(self): raise NotImplementedError("Subclasses must implement this property") - def register(self, ignore_existing=True, overwrite=False): + def register(self, overwrite=False, ignore_existing=True): assert self.record, "record must be set" try: return self.rest_collection.post(self.record) except AlreadyExists: + if overwrite is True: + LOG.warning(f"{self.key} already exists. Deleting existing one to overwrite it.") + return self.rest_item.put(self.record) + # self.rest_item.delete() + # return self.register(overwrite=overwrite, ignore_existing=ignore_existing) if ignore_existing: LOG.info(f"{self.key} already exists. Ok.") return - if overwrite is True: - LOG.warning(f"{self.key} already exists. Deleting existing one to overwrite it.") - return self.replace() raise - def replace(self): - assert self.record, "record must be set" - return self.rest_item.put(self.record) - def patch(self, data): return self.rest_item.patch(data) diff --git a/src/anemoi/registry/entry/dataset.py b/src/anemoi/registry/entry/dataset.py index 81cbf11..a8b1cbb 100644 --- a/src/anemoi/registry/entry/dataset.py +++ b/src/anemoi/registry/entry/dataset.py @@ -24,7 +24,10 @@ def set_status(self, status): patch = [{"op": "add", "path": "/status", "value": status}] self.patch(patch) - def add_location(self, platform, path): + def add_location(self, *args, platform=None, path=None): + if args: + return self.add_location(**self.list_to_dict(args)) + patch = [{"op": "add", "path": f"/locations/{platform}", "value": {"path": path}}] self.patch(patch) diff --git a/src/anemoi/registry/entry/experiment.py b/src/anemoi/registry/entry/experiment.py index 63d78ac..3044473 100644 --- a/src/anemoi/registry/entry/experiment.py +++ b/src/anemoi/registry/entry/experiment.py @@ -54,9 +54,11 @@ def add_plots(self, path, target=None): patch = [{"op": "add", "path": "/plots/-", "value": dic}] self.patch(patch) - def add_weights(self, path): - """target is a pattern: s3://bucket/{uuid}""" + def add_weights(self, *paths, **kwargs): + for path in paths: + self.add_one_weights(path, **kwargs) + def _add_one_weights(self, path, **kwargs): weights = WeightCatalogueEntry(path=path) if not WeightCatalogueEntry.key_exists(weights.key): weights.register(ignore_existing=False, overwrite=False) diff --git a/src/anemoi/registry/entry/weights.py b/src/anemoi/registry/entry/weights.py index a66e25d..1220faf 100644 --- a/src/anemoi/registry/entry/weights.py +++ b/src/anemoi/registry/entry/weights.py @@ -21,7 +21,10 @@ class WeightCatalogueEntry(CatalogueEntry): collection = "weights" main_key = "uuid" - def add_location(self, platform, path): + def add_location(self, *args, platform=None, path=None): + if args: + return self.add_location(**self.list_to_dict(args)) + patch = [{"op": "add", "path": f"/locations/{platform}", "value": {"path": path}}] self.patch(patch)