Skip to content

Commit

Permalink
clean workers
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Jul 18, 2024
1 parent 1e6db51 commit 73a65f1
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 39 deletions.
28 changes: 16 additions & 12 deletions src/anemoi/registry/commands/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from anemoi.registry.commands.base import BaseCommand
from anemoi.registry.tasks import TaskCatalogueEntry
from anemoi.registry.utils import list_to_dict
from anemoi.registry.workers import get_worker_class
from anemoi.registry.workers import run_worker

LOG = logging.getLogger(__name__)

Expand All @@ -32,8 +32,6 @@ class WorkerCommand(BaseCommand):
collection = "tasks"

def add_arguments(self, command_parser):
command_parser.add_argument("--timeout", help="Die with timeout (SIGALARM) after TIMEOUT seconds.", type=int)
command_parser.add_argument("--wait", help="Check for new task every WAIT seconds.", type=int, default=60)

subparsers = command_parser.add_subparsers(dest="action", help="Action to perform")

Expand All @@ -43,21 +41,24 @@ def add_arguments(self, command_parser):
)
transfer.add_argument("--published-target-dir", help="The target directory published in the catalogue.")
transfer.add_argument("--destination", help="Platform destination (e.g. leonardo, lumi, marenostrum)")
transfer.add_argument("--threads", help="Number of threads to use", type=int, default=1)
transfer.add_argument("--threads", help="Number of threads to use", type=int)
transfer.add_argument("--filter-tasks", help="Filter tasks to process (key=value list)", nargs="*", default=[])

delete = subparsers.add_parser("delete-dataset", help="Delete dataset")
delete.add_argument("--platform", help="Platform destination (e.g. leonardo, lumi, marenostrum)")
delete.add_argument("--filter-tasks", help="Filter tasks to process (key=value list)", nargs="*", default=[])

for subparser in [transfer, delete]:
subparser.add_argument(
"--filter-tasks", help="Filter tasks to process (key=value list)", nargs="*", default=[]
)
subparser.add_argument("--heartbeat", help="Heartbeat interval", type=int, default=60)
dummy = subparsers.add_parser("dummy", help="Dummy worker for test purposes")
dummy.add_argument("--arg")

for subparser in [transfer, delete, dummy]:
subparser.add_argument("--timeout", help="Die with timeout (SIGALARM) after TIMEOUT seconds.", type=int)
subparser.add_argument("--wait", help="Check for new task every WAIT seconds.", type=int)
subparser.add_argument("--heartbeat", help="Heartbeat interval", type=int)
subparser.add_argument(
"--max-no-heartbeat",
help="Max interval without heartbeat before considering task needs to be freed.",
type=int,
default=0,
)
subparser.add_argument("--loop", help="Run in a loop", action="store_true")
subparser.add_argument(
Expand All @@ -68,11 +69,14 @@ def add_arguments(self, command_parser):

def run(self, args):
kwargs = vars(args)
kwargs["filter_tasks"] = list_to_dict(kwargs["filter_tasks"])
if "filter_tasks" in kwargs:
kwargs["filter_tasks"] = list_to_dict(kwargs["filter_tasks"])
kwargs.pop("command")
kwargs.pop("debug")
kwargs.pop("version")
get_worker_class(kwargs.pop("action"))(**kwargs).run()
action = kwargs.pop("action")
kwargs = {k: v for k, v in kwargs.items() if v is not None}
run_worker(action, **kwargs)


command = WorkerCommand
14 changes: 14 additions & 0 deletions src/anemoi/registry/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,17 @@ registry:
datasets_uri_pattern: "s3://ml-datasets/{name}"
weights_uri_pattern: "s3://ml-weights/{uuid}.ckpt"
weights_platform: "ewc"

workers:
# These are the default values for the workers
# they are experimental and can change in the future
heartbeat: 60
max_no_heartbeat: -1
wait: 10
transfer-dataset:
target_dir: "."
published_target_dir: null
threads: 1
auto_register: true
dummy:
arg: default_value
60 changes: 43 additions & 17 deletions src/anemoi/registry/workers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from anemoi.utils.humanize import when

from anemoi.registry import config
from anemoi.registry.tasks import TaskCatalogueEntryList

# from anemoi.utils.provenance import trace_info
Expand All @@ -26,13 +27,12 @@ class Worker:

def __init__(
self,
heartbeat=60,
max_no_heartbeat=0,
heartbeat,
max_no_heartbeat,
wait,
loop=False,
check_todo=False,
timeout=None,
wait=60,
stop_if_finished=True,
):
"""Run a worker that will process tasks in the queue.
timeout: Kill itself after `timeout` seconds.
Expand All @@ -44,7 +44,6 @@ def __init__(
self.check_todo = check_todo

self.wait = wait
self.stop_if_finished = stop_if_finished
if timeout:
signal.alarm(timeout)
self.filter_tasks = {"action": self.name}
Expand All @@ -66,14 +65,15 @@ def run(self):
if self.loop:
# Process tasks in a loop for ever
while True:
res = self.process_one_task()

if self.stop_if_finished and res is None:
LOG.info("All tasks have been processed, stopping.")
return
try:
self.process_one_task()
LOG.info(f"Waiting {self.wait} seconds before checking again.")
time.sleep(self.wait)
except Exception as e:
LOG.error(f"Error for task {task}: {e}")
LOG.error("Waiting 60 seconds after this error before checking again.")
time.sleep(60)

LOG.info(f"Waiting {self.wait} seconds before checking again.")
time.sleep(self.wait)
else:
# Process one task
self.process_one_task()
Expand Down Expand Up @@ -114,7 +114,7 @@ def send_heartbeat():
thread.start()

try:
self.process_task(task)
self.worker_process_task(task)
finally:
STOP.append(1) # stop the heartbeat thread
thread.join()
Expand Down Expand Up @@ -155,21 +155,47 @@ def choose_task(self):
for task in cat:
updated = datetime.datetime.fromisoformat(task.record["updated"])
LOG.info(f"Task {task.key} is already running, last update {when(updated, use_utc=True)}.")
if (datetime.datetime.utcnow() - updated).total_seconds() > self.max_no_heartbeat:
if (
self.max_no_heartbeat >= 0
and (datetime.datetime.utcnow() - updated).total_seconds() > self.max_no_heartbeat
):
LOG.warning(
f"Task {task.key} has been running for more than {self.max_no_heartbeat} seconds, freeing it."
)
task.release_ownership()

def process_task(self, task):
def worker_process_task(self, task):
raise NotImplementedError("Subclasses must implement this method.")


def get_worker_class(action):
def run_worker(action, **kwargs):
from anemoi.registry.workers.dummy import DummyWorker

from .delete_dataset import DeleteDatasetWorker
from .transfer_dataset import TransferDatasetWorker

return {
workers_config = config().get("workers", {})
worker_config = workers_config.get(action, {})

LOG.debug(kwargs)

for k, v in worker_config.items():
if k not in kwargs:
kwargs[k] = v

LOG.debug(kwargs)

for k, v in workers_config.items():
if isinstance(v, dict):
continue
if k not in kwargs:
kwargs[k] = v

LOG.info(f"Running worker {action} with kwargs {kwargs}")

cls = {
"transfer-dataset": TransferDatasetWorker,
"delete-dataset": DeleteDatasetWorker,
"dummy": DummyWorker,
}[action]
cls(**kwargs).run()
2 changes: 1 addition & 1 deletion src/anemoi/registry/workers/delete_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
self.filter_tasks.update(filter_tasks)
self.filter_tasks["platform"] = self.platform

def process_task(self, task):
def worker_process_task(self, task):
platform, dataset = self.parse_task(task)
entry = DatasetCatalogueEntry(key=dataset)
assert platform == self.platform, (platform, self.platform)
Expand Down
23 changes: 23 additions & 0 deletions src/anemoi/registry/workers/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

from . import Worker

LOG = logging.getLogger(__name__)


class DummyWorker(Worker):
    """Dummy worker for test purposes: it only logs what it is given."""

    name = "dummy"

    def __init__(self, arg, **kwargs):
        """Set up the base worker, then log the arguments that were received.

        arg: Extra value taken by this worker; it is only written to the log.
        kwargs: Passed through unchanged to the parent ``Worker`` constructor.
        """
        super().__init__(**kwargs)
        # Logged after the parent constructor so a failing base init is not
        # followed by a misleading "initialized" message.
        init_message = f"Dummy worker initialized with kwargs:{kwargs} and args:{arg}"
        LOG.warning(init_message)

    def worker_process_task(self, task):
        """Log the task that was handed over; no other work is performed."""
        task_message = f"Dummy worker processing task={task}"
        LOG.warning(task_message)
22 changes: 13 additions & 9 deletions src/anemoi/registry/workers/transfer_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,25 @@ def __init__(
):
super().__init__(**kwargs)

if not destination:
raise ValueError("No destination platform specified")

if not os.path.exists(target_dir):
raise ValueError(f"Target directory {target_dir} must already exist")

self.destination = destination
self.target_dir = target_dir
self.published_target_dir = published_target_dir or target_dir
self.published_target_dir = published_target_dir
self.threads = threads
self.auto_register = auto_register

if self.published_target_dir is None:
self.published_target_dir = self.target_dir

self.filter_tasks.update(filter_tasks)
self.filter_tasks["destination"] = self.destination
self.auto_register = auto_register

def process_task(self, task):
if not self.destination:
raise ValueError("No destination platform specified")

if not os.path.exists(self.target_dir):
raise ValueError(f"Target directory {self.target_dir} must already exist")

def worker_process_task(self, task):
destination, source, dataset = self.parse_task(task)
entry = DatasetCatalogueEntry(key=dataset)

Expand Down

0 comments on commit 73a65f1

Please sign in to comment.