Commit: workers

up
floriankrb committed Jul 15, 2024
1 parent 75c2f91 commit c19de87
Showing 5 changed files with 280 additions and 177 deletions.
59 changes: 31 additions & 28 deletions src/anemoi/registry/commands/worker.py
@@ -17,7 +17,7 @@
from anemoi.registry.commands.base import BaseCommand
from anemoi.registry.tasks import TaskCatalogueEntry
from anemoi.registry.utils import list_to_dict
from anemoi.registry.workers import TransferDatasetWorker
from anemoi.registry.workers import get_worker_class

LOG = logging.getLogger(__name__)

@@ -35,41 +35,44 @@ def add_arguments(self, command_parser):
command_parser.add_argument("--timeout", help="Die with timeout (SIGALARM) after TIMEOUT seconds.", type=int)
command_parser.add_argument("--wait", help="Check for new task every WAIT seconds.", type=int, default=60)

command_parser.add_argument(
"action",
help="Action to perform",
choices=["transfer-dataset", "delete-dataset"],
nargs="?",
)
command_parser.add_argument(
subparsers = command_parser.add_subparsers(dest="action", help="Action to perform")

transfer = subparsers.add_parser("transfer-dataset", help="Transfer dataset")
transfer.add_argument(
"--target-dir", help="The actual target directory where the worker will write.", default="."
)
command_parser.add_argument("--published-target-dir", help="The target directory published in the catalogue.")
command_parser.add_argument("--destination", help="Platform destination (e.g. leonardo, lumi, marenostrum)")
command_parser.add_argument("--request", help="Filter tasks to process (key=value list)", nargs="*", default=[])
command_parser.add_argument("--threads", help="Number of threads to use", type=int, default=1)
command_parser.add_argument("--heartbeat", help="Heartbeat interval", type=int, default=60)
command_parser.add_argument(
"--max-no-heartbeat",
help="Max interval without heartbeat before considering task needs to be freed.",
type=int,
default=0,
)
command_parser.add_argument("--loop", help="Run in a loop", action="store_true")
command_parser.add_argument(
"--check-todo",
help="See if there are tasks for this worker and exit with 0 if there are task to do.",
action="store_true",
)
transfer.add_argument("--published-target-dir", help="The target directory published in the catalogue.")
transfer.add_argument("--destination", help="Platform destination (e.g. leonardo, lumi, marenostrum)")
transfer.add_argument("--threads", help="Number of threads to use", type=int, default=1)

delete = subparsers.add_parser("delete-dataset", help="Delete dataset")
delete.add_argument("--platform", help="Platform destination (e.g. leonardo, lumi, marenostrum)")

for subparser in [transfer, delete]:
subparser.add_argument(
"--filter-tasks", help="Filter tasks to process (key=value list)", nargs="*", default=[]
)
subparser.add_argument("--heartbeat", help="Heartbeat interval", type=int, default=60)
subparser.add_argument(
"--max-no-heartbeat",
help="Max interval without heartbeat before considering task needs to be freed.",
type=int,
default=0,
)
subparser.add_argument("--loop", help="Run in a loop", action="store_true")
subparser.add_argument(
"--check-todo",
help="See if there are tasks for this worker and exit with 0 if there are task to do.",
action="store_true",
)

def run(self, args):
kwargs = vars(args)
kwargs["request"] = list_to_dict(kwargs["request"])
kwargs["filter_tasks"] = list_to_dict(kwargs["filter_tasks"])
kwargs.pop("command")
kwargs.pop("debug")
kwargs.pop("version")

TransferDatasetWorker(**kwargs).run()
get_worker_class(kwargs.pop("action"))(**kwargs).run()


command = WorkerCommand
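
With the move from a positional "action" argument to add_subparsers, each worker type now owns its specific options, while the shared ones are attached to both subparsers in the loop above. A sketch of the resulting invocations, assuming the console entry point is called anemoi-registry and this command is registered as "worker" (neither name appears in this diff):

    # run a transfer worker in a loop, writing into an existing local directory
    anemoi-registry worker transfer-dataset --destination leonardo --target-dir /data/anemoi --loop

    # exit with status 0 if there are queued delete tasks for this platform (e.g. for a cron check)
    anemoi-registry worker delete-dataset --platform leonardo --check-todo
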
3 changes: 3 additions & 0 deletions src/anemoi/registry/entry/dataset.py
@@ -26,6 +26,9 @@ def set_status(self, status):
def add_location(self, path, platform):
self.rest_item.patch([{"op": "add", "path": f"/locations/{platform}", "value": {"path": path}}])

def remove_location(self, platform):
self.rest_item.patch([{"op": "remove", "path": f"/locations/{platform}"}])

def set_recipe(self, file):
if not os.path.exists(file):
raise FileNotFoundError(f"Recipe file not found: {file}")
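
The new remove_location mirrors add_location: both send JSON Patch (RFC 6902) operations to the catalogue's REST API via rest_item.patch. A minimal sketch of the patch documents they produce (the path and platform values are illustrative):

    # add_location(path="/data/my-dataset.zarr", platform="leonardo") sends:
    [{"op": "add", "path": "/locations/leonardo", "value": {"path": "/data/my-dataset.zarr"}}]

    # remove_location(platform="leonardo") sends:
    [{"op": "remove", "path": "/locations/leonardo"}]
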
205 changes: 56 additions & 149 deletions src/anemoi/registry/workers/__init__.py
@@ -7,15 +7,13 @@

import datetime
import logging
import os
import signal
import sys
import threading
import time

from anemoi.utils.humanize import when

from anemoi.registry.entry.dataset import DatasetCatalogueEntry
from anemoi.registry.tasks import TaskCatalogueEntryList

# from anemoi.utils.provenance import trace_info
@@ -24,10 +24,10 @@


class Worker:
name = None

def __init__(
self,
action,
# generic worker options
heartbeat=60,
max_no_heartbeat=0,
loop=False,
@@ -49,9 +49,14 @@ def __init__(self,
self.stop_if_finished = stop_if_finished
if timeout:
signal.alarm(timeout)
self.filter_tasks = {"action": self.name}

def run(self):

if self.check_todo:
# Check if there are tasks to do
# exit with 0 if there are.
# exit with 1 if there are none.
task = self.choose_task()
if task:
LOG.info("There are tasks to do.")
@@ -61,6 +64,7 @@ def run(self):
sys.exit(1)

if self.loop:
# Process tasks in a loop forever
while True:
res = self.process_one_task()

@@ -71,59 +75,33 @@
LOG.info(f"Waiting {self.wait} seconds before checking again.")
time.sleep(self.wait)
else:
# Process one task
self.process_one_task()

def choose_task(self):
request = self.request.copy()
request["destination"] = request.get("destination", self.destination)
request["action"] = "transfer-dataset"

# if a task is queued, take it
for entry in TaskCatalogueEntryList(status="queued", **request):
return entry

# else if a task is running, check if it has been running for too long, and free it
if self.max_no_heartbeat == 0:
return None

cat = TaskCatalogueEntryList(status="running", **request)
if not cat:
LOG.info("No queued tasks found")
else:
LOG.info(cat.to_str(long=True))
for entry in cat:
updated = datetime.datetime.fromisoformat(entry.record["updated"])
LOG.info(f"Task {entry.key} is already running, last update {when(updated, use_utc=True)}.")
if (datetime.datetime.utcnow() - updated).total_seconds() > self.max_no_heartbeat:
LOG.warning(
f"Task {entry.key} has been running for more than {self.max_no_heartbeat} seconds, freeing it."
)
entry.release_ownership()

def process_one_task(self):
entry = self.choose_task()
if not entry:
task = self.choose_task()
if not task:
return False

uuid = entry.key
LOG.info(f"Processing task {uuid}: {entry}")
self.parse_entry(entry) # for checking only
uuid = task.key
LOG.info(f"Processing task {uuid}: {task}")
self.parse_task(task) # for checking only

entry.take_ownership()
self.process_entry_with_heartbeat(entry)
task.take_ownership()
self.process_task_with_heartbeat(task)
LOG.info(f"Task {uuid} completed.")
entry.unregister()
task.unregister()
LOG.info(f"Task {uuid} deleted.")
return True

def process_entry_with_heartbeat(self, entry):
def process_task_with_heartbeat(self, task):
STOP = []

# create another thread to send heartbeat
def send_heartbeat():
while True:
try:
entry.set_status("running")
task.set_status("running")
except Exception:
return
for _ in range(self.heartbeat):
@@ -136,133 +114,62 @@ def send_heartbeat():
thread.start()

try:
self.process_entry(entry)
self.process_task(task)
finally:
STOP.append(1) # stop the heartbeat thread
thread.join()
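
process_task_with_heartbeat keeps the task's status fresh from a daemon thread while the real work runs, using a shared list (STOP) as a stop flag polled once per second. A self-contained sketch of the same pattern written with threading.Event, the more idiomatic stop flag (illustrative only, not the code in this commit):

    import threading

    def run_with_heartbeat(work, beat, interval=60):
        stop = threading.Event()

        def heartbeat():
            while not stop.is_set():
                beat()               # e.g. task.set_status("running")
                stop.wait(interval)  # returns early as soon as stop is set

        thread = threading.Thread(target=heartbeat, daemon=True)
        thread.start()
        try:
            work()                   # e.g. self.process_task(task)
        finally:
            stop.set()               # stop the heartbeat thread
            thread.join()
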

def process_entry(self, entry):
destination, source, dataset = self.parse_entry(entry)
dataset_entry = DatasetCatalogueEntry(key=dataset)

LOG.info(f"Transferring {dataset} from '{source}' to '{destination}'")

def get_source_path():
e = dataset_entry.record
if "locations" not in e:
raise ValueError(f"Dataset {dataset} has no locations")
locations = e["locations"]

if source not in locations:
raise ValueError(
f"Dataset {dataset} is not available at {source}. Available locations: {list(locations.keys())}"
)

if "path" not in locations[source]:
raise ValueError(f"Dataset {dataset} has no path at {source}")

path = locations[source]["path"]

return path

source_path = get_source_path()
basename = os.path.basename(source_path)
target_path = os.path.join(self.target_dir, basename)
if os.path.exists(target_path):
LOG.error(f"Target path {target_path} already exists, skipping.")
return

from anemoi.utils.s3 import download

LOG.info(f"Source path: {source_path}")
LOG.info(f"Target path: {target_path}")

if source_path.startswith("s3://"):
source_path = source_path + "/" if not source_path.endswith("/") else source_path

if target_path.startswith("s3://"):
LOG.warning("Uploading to S3 is experimental and has not been tested yet.")
download(source_path, target_path, resume=True, threads=self.threads)
return
else:
target_tmp_path = os.path.join(self.target_dir + "-downloading", basename)
os.makedirs(os.path.dirname(target_tmp_path), exist_ok=True)
download(source_path, target_tmp_path, resume=True, threads=self.threads)
os.rename(target_tmp_path, target_path)

if self.auto_register:
published_target_path = os.path.join(self.published_target_dir, basename)
dataset_entry.add_location(platform=destination, path=published_target_path)

@classmethod
def parse_entry(cls, entry):
data = entry.record.copy()

def parse_task(cls, task, *keys):
data = task.record.copy()
assert isinstance(data, dict), data
assert data["action"] == "transfer-dataset", data["action"]

def is_alphanumeric(s):
assert isinstance(s, str), s
return all(c.isalnum() or c in ("-", "_") for c in s)

destination = data.pop("destination")
source = data.pop("source")
dataset = data.pop("dataset")
assert is_alphanumeric(destination), destination
assert is_alphanumeric(source), source
assert is_alphanumeric(dataset), dataset
for k in keys:
value = data.pop(k)
assert is_alphanumeric(value), (k, value)
for k in data:
if k not in ("action", "status", "progress", "created", "updated", "uuid"):
LOG.warning(f"Unknown key {k}=data[k]")
data = None

if "/" in destination:
raise ValueError(f"Destination {destination} must not contain '/', this is a platform name")
if "." in destination:
raise ValueError(f"Destination {destination} must not contain '.', this is a platform name")
return [task.record[k] for k in keys]
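
parse_task validates the task record and returns the requested fields in order, so each subclass can name exactly the keys its action requires. A hedged example of how the subclasses might call it (the subclass modules are not shown in the loaded part of this diff, so these call sites are assumptions):

    # in TransferDatasetWorker.process_task (assumed):
    destination, source, dataset = self.parse_task(task, "destination", "source", "dataset")

    # in DeleteDatasetWorker.process_task (assumed):
    platform, dataset = self.parse_task(task, "platform", "dataset")
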

if "/" in source:
raise ValueError(f"Source {source} must not contain '/', this is a platform name")
if "." in source:
raise ValueError(f"Source {source} must not contain '.', this is a platform name")

if "." in dataset:
raise ValueError(f"The dataset {dataset} must not contain a '.', this is the name of the dataset.")
def choose_task(self):
for task in TaskCatalogueEntryList(status="queued", **self.filter_tasks):
LOG.info("Found task")
return task
LOG.info("No queued tasks found")

assert isinstance(destination, str), destination
assert isinstance(source, str), source
assert isinstance(dataset, str), dataset
return destination, source, dataset
if self.max_no_heartbeat == 0:
return None

cat = TaskCatalogueEntryList(status="running", **self.filter_tasks)
if not cat:
LOG.info("No queued tasks found")
else:
LOG.info(cat.to_str(long=True))

class TransferDatasetWorker(Worker):
def __init__(
self,
action,
# specific worker options
destination,
target_dir=".",
published_target_dir=None,
auto_register=True,
threads=1,
request={},
**kwargs,
):
super().__init__(action, **kwargs)
# if a task is running, check if it has been running for too long, and free it
for task in cat:
updated = datetime.datetime.fromisoformat(task.record["updated"])
LOG.info(f"Task {task.key} is already running, last update {when(updated, use_utc=True)}.")
if (datetime.datetime.utcnow() - updated).total_seconds() > self.max_no_heartbeat:
LOG.warning(
f"Task {task.key} has been running for more than {self.max_no_heartbeat} seconds, freeing it."
)
task.release_ownership()

assert action == "transfer-dataset", action
def process_task(self, task):
raise NotImplementedError("Subclasses must implement this method.")

if not destination:
raise ValueError("No destination platform specified")
if not action:
raise ValueError("No action specified")

self.destination = destination
self.target_dir = target_dir
self.published_target_dir = published_target_dir or target_dir
self.request = request
self.threads = threads
def get_worker_class(action):
from .delete_dataset import DeleteDatasetWorker
from .transfer_dataset import TransferDatasetWorker

self.auto_register = auto_register
if not os.path.exists(target_dir):
raise ValueError(f"Target directory {target_dir} must already exist")
return {
"transfer-dataset": TransferDatasetWorker,
"delete-dataset": DeleteDatasetWorker,
}[action]
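
The refactor turns Worker into a small template: a subclass declares its task "action" via the name class attribute (used to build self.filter_tasks) and implements process_task, while queueing, ownership, heartbeats, and cleanup stay in the base class. A minimal sketch of what delete_dataset.py might contain, combining the pieces visible in this diff; the actual module is among the files not loaded above, so everything beyond the imported names is an assumption:

    from anemoi.registry.entry.dataset import DatasetCatalogueEntry

    from . import Worker

    class DeleteDatasetWorker(Worker):
        name = "delete-dataset"  # matches the task "action" used in self.filter_tasks

        def __init__(self, platform=None, **kwargs):  # signature assumed from the CLI options
            super().__init__(**kwargs)
            if not platform:
                raise ValueError("No platform specified")
            self.platform = platform

        def process_task(self, task):
            platform, dataset = self.parse_task(task, "platform", "dataset")
            entry = DatasetCatalogueEntry(key=dataset)
            # remove_location is the method added to DatasetCatalogueEntry in this commit
            entry.remove_location(platform=platform)
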