diff --git a/src/anemoi/registry/commands/experiments.py b/src/anemoi/registry/commands/experiments.py index 3cddabd..1ebd1e0 100644 --- a/src/anemoi/registry/commands/experiments.py +++ b/src/anemoi/registry/commands/experiments.py @@ -56,25 +56,35 @@ def add_arguments(self, command_parser): "Add weights to the experiment and upload them do s3." "Skip upload if these weights are already uploaded." ), + metavar="FILE", ) - command_parser.add_argument("--add-plots", nargs="+", help="Add plots to the experiment.") + command_parser.add_argument("--add-plots", nargs="+", help="Add plots to the experiment.", metavar="FILE") command_parser.add_argument( - "--set-archive", help="Input file to register as an archive metadata file to the catalogue." + "--set-archive", help="Input file to register as an archive metadata file to the catalogue.", metavar="FILE" ) command_parser.add_argument( - "--get-archive", help="Output file to save the archive metadata file from the catalogue." + "--get-archive", + help="Output file to save the archive metadata file from the catalogue. Merge metadata file if there are multiple run numbers.", + metavar="FILE", ) + command_parser.add_argument("--remove-archive", help="Delete the archive metadata.", action="store_true") command_parser.add_argument( - "--remove-archive", help="Remove the archive metadata file from the catalogue.", action="store_true" + "--archive-moved", + help="When archive moved to a new location, move the metadata file and update the catalogue.", + nargs=2, + metavar=("OLD", "NEW"), ) + command_parser.add_argument( "--archive-platform", help="Archive platform. Only relevant for --set-archive and --get-archive and --remove-archive.", + metavar="PLATFORM", ) command_parser.add_argument( "--run-number", - help="The run number of the experiment. Relevant --set-archive and --get-archive and --remove-archive.", + help="The run number of the experiment. Relevant --set-archive and --get-archive and --remove-archive. Can be 'all' or 'latest' when applicable.", + metavar="N", ) command_parser.add_argument( "--archive-extra-metadata", help="Extra metadata. A list of key=value pairs.", nargs="+", default={} @@ -106,6 +116,7 @@ def _run(self, entry, args): ) self.process_task(entry, args, "get_archive", run_number=args.run_number, platform=args.archive_platform) self.process_task(entry, args, "remove_archive", run_number=args.run_number, platform=args.archive_platform) + self.process_task(entry, args, "archive_moved", run_number=args.run_number) if args.url: print(entry.url) diff --git a/src/anemoi/registry/entry/experiment.py b/src/anemoi/registry/entry/experiment.py index 1630efa..ba69bc5 100644 --- a/src/anemoi/registry/entry/experiment.py +++ b/src/anemoi/registry/entry/experiment.py @@ -8,6 +8,7 @@ import datetime import logging import os +import tempfile from getpass import getuser import yaml @@ -123,33 +124,82 @@ def set_archive(self, path, platform, run_number, overwrite=True, extras={}): self.rest_item.patch([{"op": "add", "path": f"/runs/{run_number}/archives/{platform}", "value": dic}]) def remove_archive(self, platform, run_number): - if run_number is None: - raise ValueError("run_number must be set") - run_number = str(run_number) - if platform is None: raise ValueError("platform must be set") - LOG.info(f"Removing archive for run {run_number} and platform {platform}") - self.rest_item.patch([{"op": "remove", "path": f"/runs/{run_number}/archives/{platform}"}]) + run_numbers = self._parse_run_number(run_number) - def get_archive(self, path, *, platform, run_number): - if os.path.exists(path): - raise FileExistsError(f"Path {path} already exists") + for run_number in run_numbers: + LOG.info(f"Removing archive for run {run_number} and platform {platform}") + if run_number not in self.record["runs"]: + LOG.info(f"Archive: skipping run {run_number} because it does not exist") + continue + run_record = self.record["runs"][run_number] + + if platform not in run_record.get("archives", {}): + LOG.info(f"Archive: skipping {platform} for run {run_number} because it does not exist") + continue + + url = run_record["archives"][platform]["url"] + delete(url) + self.rest_item.patch([{"op": "remove", "path": f"/runs/{run_number}/archives/{platform}"}]) + + def _list_run_numbers(self): + return [int(k) for k in self.record.get("runs", {}).keys()] + def _parse_run_number(self, run_number): + assert isinstance(run_number, (str, int)), "run_number must be a string or an integer" run_number = str(run_number) + + if run_number.lower() == "all": + return [str(i) for i in self._list_run_numbers()] + if run_number == "latest": - run_number = str(max([int(k) for k in self.record["runs"].keys()])) + run_number = str(max(self._list_run_numbers())) LOG.info(f"Using latest run number {run_number}") + if run_number not in self.record["runs"]: raise ValueError(f"Run number {run_number} not found") - if platform not in self.record["runs"][run_number]["archives"]: - raise ValueError(f"Platform {platform} not found") + return [run_number] + + def archive_moved(self, old, new, run_number, overwrite=None): + run_numbers = self._parse_run_number(run_number) + + with tempfile.TemporaryDirectory() as tmpdir: + print(tmpdir) + for run_number in run_numbers: + tmp_path = os.path.join(tmpdir, str(run_number)) + self.get_archive(tmp_path, platform=old, run_number=run_number) + self.set_archive(tmp_path, platform=new, run_number=run_number, overwrite=overwrite) + self.remove_archive(old, run_number) + + def _get_run_record(self, run_number): + print(self.record.get("runs", {}), run_number, type(run_number)) + print(self.record.get("runs", {}).get(run_number, {})) + return self.record.get("runs", {}).get(run_number, {}) + + def get_archive(self, path, *, platform, run_number): + if os.path.exists(path): + raise FileExistsError(f"Path {path} already exists") + + with tempfile.TemporaryDirectory() as tmpdir: + run_numbers = self._parse_run_number(run_number) + for run_number in run_numbers: + run_record = self._get_run_record(run_number) + + if platform not in run_record.get("archives", {}): + LOG.info(f"Archive: skipping {platform} for run {run_number} because it does not exist") + continue + + tmp_path = os.path.join(tmpdir, str(run_number)) - url = self.record["runs"][run_number]["archives"][platform]["url"] - LOG.info(f"Downloading {url} to {path}.") - download(url, path) + url = run_record["archives"][platform]["url"] + LOG.info(f"Downloading {url} to {tmp_path}.") + download(url, tmp_path) + with open(path, "a+") as f: + with open(tmp_path, "r") as tmp: + f.write(tmp.read()) def delete_artefacts(self): self.delete_all_plots()