diff --git a/src/anemoi/registry/commands/worker.py b/src/anemoi/registry/commands/worker.py index c2b866e..1be55a2 100644 --- a/src/anemoi/registry/commands/worker.py +++ b/src/anemoi/registry/commands/worker.py @@ -17,7 +17,7 @@ from anemoi.registry.commands.base import BaseCommand from anemoi.registry.tasks import TaskCatalogueEntry from anemoi.registry.utils import list_to_dict -from anemoi.registry.workers import get_worker_class +from anemoi.registry.workers import run_worker LOG = logging.getLogger(__name__) @@ -32,8 +32,6 @@ class WorkerCommand(BaseCommand): collection = "tasks" def add_arguments(self, command_parser): - command_parser.add_argument("--timeout", help="Die with timeout (SIGALARM) after TIMEOUT seconds.", type=int) - command_parser.add_argument("--wait", help="Check for new task every WAIT seconds.", type=int, default=60) subparsers = command_parser.add_subparsers(dest="action", help="Action to perform") @@ -43,21 +41,24 @@ def add_arguments(self, command_parser): ) transfer.add_argument("--published-target-dir", help="The target directory published in the catalogue.") transfer.add_argument("--destination", help="Platform destination (e.g. leonardo, lumi, marenostrum)") - transfer.add_argument("--threads", help="Number of threads to use", type=int, default=1) + transfer.add_argument("--threads", help="Number of threads to use", type=int) + transfer.add_argument("--filter-tasks", help="Filter tasks to process (key=value list)", nargs="*", default=[]) delete = subparsers.add_parser("delete-dataset", help="Delete dataset") delete.add_argument("--platform", help="Platform destination (e.g. 
leonardo, lumi, marenostrum)") + delete.add_argument("--filter-tasks", help="Filter tasks to process (key=value list)", nargs="*", default=[]) - for subparser in [transfer, delete]: - subparser.add_argument( - "--filter-tasks", help="Filter tasks to process (key=value list)", nargs="*", default=[] - ) - subparser.add_argument("--heartbeat", help="Heartbeat interval", type=int, default=60) + dummy = subparsers.add_parser("dummy", help="Dummy worker for test purposes") + dummy.add_argument("--arg") + + for subparser in [transfer, delete, dummy]: + subparser.add_argument("--timeout", help="Die with timeout (SIGALARM) after TIMEOUT seconds.", type=int) + subparser.add_argument("--wait", help="Check for new task every WAIT seconds.", type=int) + subparser.add_argument("--heartbeat", help="Heartbeat interval", type=int) subparser.add_argument( "--max-no-heartbeat", help="Max interval without heartbeat before considering task needs to be freed.", type=int, - default=0, ) subparser.add_argument("--loop", help="Run in a loop", action="store_true") subparser.add_argument( @@ -68,11 +69,14 @@ def add_arguments(self, command_parser): def run(self, args): kwargs = vars(args) - kwargs["filter_tasks"] = list_to_dict(kwargs["filter_tasks"]) + if "filter_tasks" in kwargs: + kwargs["filter_tasks"] = list_to_dict(kwargs["filter_tasks"]) kwargs.pop("command") kwargs.pop("debug") kwargs.pop("version") - get_worker_class(kwargs.pop("action"))(**kwargs).run() + action = kwargs.pop("action") + kwargs = {k: v for k, v in kwargs.items() if v is not None} + run_worker(action, **kwargs) command = WorkerCommand diff --git a/src/anemoi/registry/config.yaml b/src/anemoi/registry/config.yaml index b7d82ab..7720c09 100644 --- a/src/anemoi/registry/config.yaml +++ b/src/anemoi/registry/config.yaml @@ -8,3 +8,17 @@ registry: datasets_uri_pattern: "s3://ml-datasets/{name}" weights_uri_pattern: "s3://ml-weights/{uuid}.ckpt" weights_platform: "ewc" + + workers: + # These are the default values for the 
workers + # they are experimental and can change in the future + heartbeat: 60 + max_no_heartbeat: -1 + wait: 10 + transfer-dataset: + target_dir: "." + published_target_dir: null + threads: 1 + auto_register: true + dummy: + arg: default_value diff --git a/src/anemoi/registry/workers/__init__.py b/src/anemoi/registry/workers/__init__.py index c600f50..b39e50f 100644 --- a/src/anemoi/registry/workers/__init__.py +++ b/src/anemoi/registry/workers/__init__.py @@ -14,6 +14,7 @@ from anemoi.utils.humanize import when +from anemoi.registry import config from anemoi.registry.tasks import TaskCatalogueEntryList # from anemoi.utils.provenance import trace_info @@ -26,13 +27,12 @@ class Worker: def __init__( self, - heartbeat=60, - max_no_heartbeat=0, + heartbeat, + max_no_heartbeat, + wait, loop=False, check_todo=False, timeout=None, - wait=60, - stop_if_finished=True, ): """Run a worker that will process tasks in the queue. timeout: Kill itself after `timeout` seconds. @@ -44,7 +44,6 @@ def __init__( self.check_todo = check_todo self.wait = wait - self.stop_if_finished = stop_if_finished if timeout: signal.alarm(timeout) self.filter_tasks = {"action": self.name} @@ -66,14 +65,15 @@ def run(self): if self.loop: # Process tasks in a loop for ever while True: - res = self.process_one_task() - - if self.stop_if_finished and res is None: - LOG.info("All tasks have been processed, stopping.") - return + try: + self.process_one_task() + LOG.info(f"Waiting {self.wait} seconds before checking again.") + time.sleep(self.wait) + except Exception as e: + LOG.error(f"Error for task {task}: {e}") + LOG.error("Waiting 60 seconds after this error before checking again.") + time.sleep(60) - LOG.info(f"Waiting {self.wait} seconds before checking again.") - time.sleep(self.wait) else: # Process one task self.process_one_task() @@ -114,7 +114,7 @@ def send_heartbeat(): thread.start() try: - self.process_task(task) + self.worker_process_task(task) finally: STOP.append(1) # stop the heartbeat 
thread thread.join() @@ -155,21 +155,47 @@ def choose_task(self): for task in cat: updated = datetime.datetime.fromisoformat(task.record["updated"]) LOG.info(f"Task {task.key} is already running, last update {when(updated, use_utc=True)}.") - if (datetime.datetime.utcnow() - updated).total_seconds() > self.max_no_heartbeat: + if ( + self.max_no_heartbeat >= 0 + and (datetime.datetime.utcnow() - updated).total_seconds() > self.max_no_heartbeat + ): LOG.warning( f"Task {task.key} has been running for more than {self.max_no_heartbeat} seconds, freeing it." ) task.release_ownership() - def process_task(self, task): + def worker_process_task(self, task): raise NotImplementedError("Subclasses must implement this method.") -def get_worker_class(action): +def run_worker(action, **kwargs): + from anemoi.registry.workers.dummy import DummyWorker + from .delete_dataset import DeleteDatasetWorker from .transfer_dataset import TransferDatasetWorker - return { + workers_config = config().get("workers", {}) + worker_config = workers_config.get(action, {}) + + LOG.debug(kwargs) + + for k, v in worker_config.items(): + if k not in kwargs: + kwargs[k] = v + + LOG.debug(kwargs) + + for k, v in workers_config.items(): + if isinstance(v, dict): + continue + if k not in kwargs: + kwargs[k] = v + + LOG.info(f"Running worker {action} with kwargs {kwargs}") + + cls = { "transfer-dataset": TransferDatasetWorker, "delete-dataset": DeleteDatasetWorker, + "dummy": DummyWorker, }[action] + cls(**kwargs).run() diff --git a/src/anemoi/registry/workers/delete_dataset.py b/src/anemoi/registry/workers/delete_dataset.py index 75b4d4f..6bfc820 100644 --- a/src/anemoi/registry/workers/delete_dataset.py +++ b/src/anemoi/registry/workers/delete_dataset.py @@ -33,7 +33,7 @@ def __init__( self.filter_tasks.update(filter_tasks) self.filter_tasks["platform"] = self.platform - def process_task(self, task): + def worker_process_task(self, task): platform, dataset = self.parse_task(task) entry = 
DatasetCatalogueEntry(key=dataset) assert platform == self.platform, (platform, self.platform) diff --git a/src/anemoi/registry/workers/dummy.py b/src/anemoi/registry/workers/dummy.py new file mode 100644 index 0000000..19c3a63 --- /dev/null +++ b/src/anemoi/registry/workers/dummy.py @@ -0,0 +1,23 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +from . import Worker + +LOG = logging.getLogger(__name__) + + +class DummyWorker(Worker): + name = "dummy" + + def __init__(self, arg, **kwargs): + super().__init__(**kwargs) + LOG.warning(f"Dummy worker initialized with kwargs:{kwargs} and args:{arg}") + + def worker_process_task(self, task): + LOG.warning(f"Dummy worker processing task={task}") diff --git a/src/anemoi/registry/workers/transfer_dataset.py b/src/anemoi/registry/workers/transfer_dataset.py index 1c05bfb..0b98abc 100644 --- a/src/anemoi/registry/workers/transfer_dataset.py +++ b/src/anemoi/registry/workers/transfer_dataset.py @@ -78,21 +78,25 @@ def __init__( ): super().__init__(**kwargs) - if not destination: - raise ValueError("No destination platform specified") - - if not os.path.exists(target_dir): - raise ValueError(f"Target directory {target_dir} must already exist") - self.destination = destination self.target_dir = target_dir - self.published_target_dir = published_target_dir or target_dir + self.published_target_dir = published_target_dir self.threads = threads + self.auto_register = auto_register + + if self.published_target_dir is None: + self.published_target_dir = self.target_dir + self.filter_tasks.update(filter_tasks) self.filter_tasks["destination"] 
= self.destination - self.auto_register = auto_register - def process_task(self, task): + if not self.destination: + raise ValueError("No destination platform specified") + + if not os.path.exists(self.target_dir): + raise ValueError(f"Target directory {self.target_dir} must already exist") + + def worker_process_task(self, task): destination, source, dataset = self.parse_task(task) entry = DatasetCatalogueEntry(key=dataset)