Skip to content

Commit

Permalink
compute constants
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 20, 2024
1 parent d2468eb commit 6d8cc2a
Showing 1 changed file with 71 additions and 28 deletions.
99 changes: 71 additions & 28 deletions src/anemoi/registry/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ def add_arguments(self, command_parser):

group = command_parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--catalogue-from-recipe-file", help="Update the catalogue entry from the recipe.", action="store_true"
"-R",
"--catalogue-from-recipe-file",
help="Update the catalogue entry from the recipe.",
action="store_true",
)

group.add_argument(
"-Z",
"--zarr-file-from-catalogue",
help="Update the zarr file metadata from a catalogue entry.",
action="store_true",
Expand All @@ -51,6 +55,8 @@ def add_arguments(self, command_parser):
command_parser.add_argument("--force", help="Force.", action="store_true")
command_parser.add_argument("--update", help="Update.", action="store_true")
command_parser.add_argument("--ignore", help="Ignore some trivial errors.", action="store_true")
command_parser.add_argument("--resume", help="Resume from progress", action="store_true")
command_parser.add_argument("--progress", help="Progress file")

command_parser.add_argument(
"--continue", help="Continue to the next file on error.", action="store_true", dest="continue_"
Expand All @@ -60,27 +66,37 @@ def add_arguments(self, command_parser):
command_parser.add_argument("paths", nargs="*", help="Paths to update.")

def run(self, args):
if args.catalogue_from_recipe_file:
for path in args.paths:
try:
self.catalogue_from_recipe_file(path, args)
except Exception as e:
if args.continue_:
LOG.exception(e)
continue
raise
return

if args.zarr_file_from_catalogue:
for path in args.paths:
try:
self.zarr_file_from_catalogue(path, args)
except Exception as e:
if args.continue_:
LOG.exception(e)
continue
raise
return
if args.resume:
if args.progress is None:
LOG.error("Progress file is required for --resume")
return

done = set()
if os.path.exists(args.progress):
with open(args.progress) as f:
for line in f:
done.add(line.strip())

if args.catalogue_from_recipe_file:
method = self.catalogue_from_recipe_file
elif args.zarr_file_from_catalogue:
method = self.zarr_file_from_catalogue

for path in args.paths:
if args.resume and path in done:
LOG.info(f"Skipping {path}")
continue
try:
method(path, args)
except Exception as e:
if args.continue_:
LOG.exception(e)
continue
raise
if args.progress:
with open(args.progress, "a") as f:
print(path, file=f)

def _error(self, args, message):
LOG.error(message)
Expand All @@ -92,6 +108,7 @@ def _error(self, args, message):
def catalogue_from_recipe_file(self, path, args):
"""Update the catalogue entry a recipe file."""

from anemoi.datasets import open_dataset
from anemoi.datasets.create import creator_factory

def entry_set_value(path, value):
Expand All @@ -108,6 +125,7 @@ def entry_set_value(path, value):

if "name" not in recipe:
self._error(args, "Recipe does not contain a 'name' field.")
return

name = recipe["name"]
base, _ = os.path.splitext(os.path.basename(path))
Expand All @@ -118,7 +136,7 @@ def entry_set_value(path, value):
try:
entry = Dataset(name, params={"_": True})
except CatalogueEntryNotFound:
if args.force:
if args.ignore:
LOG.error(f"Entry not found: {name}")
return
raise
Expand All @@ -131,6 +149,7 @@ def entry_set_value(path, value):
return

if "recipe" not in entry.record["metadata"] or args.force:
LOG.info("%s, setting `constant_fields` πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯", name)
if args.dry_run:
LOG.info("Would set recipe %s", name)
else:
Expand All @@ -139,12 +158,29 @@ def entry_set_value(path, value):
entry_set_value("/metadata/recipe", recipe)
entry_set_value("/metadata/updated", updated + 1)

if "variables_metadata" in entry.record["_original"]["metadata"]:
LOG.info("%s: `variables_metadata` already in original. Use --force and --update to update", name)
if not args.update or not args.force:
return
if "constant_fields" in entry.record["metadata"] and "variables_metadata" in entry.record["metadata"]:
LOG.info("%s, setting `variables_metadata` and `constant_fields`")
constants = entry.record["metadata"]["constant_fields"]
variables_metadata = entry.record["metadata"]["variables_metadata"]

changed = False
for k, v in variables_metadata.items():

if k in constants and v.get("constant_in_time") is not True:
v["constant_in_time"] = True
changed = True
LOG.info(f"Setting {k} constant_in_time to True")

if "is_constant_in_time" in v:
del v["is_constant_in_time"]
changed = True

if changed:
entry_set_value("/metadata/variables_metadata", variables_metadata)
entry_set_value("/metadata/updated", updated + 1)

if "variables_metadata" not in entry.record["metadata"] or args.force:
LOG.info("%s, setting `variables_metadata` πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯", name)

if args.dry_run:
LOG.info("Would set `variables_metadata` %s", name)
Expand All @@ -156,8 +192,7 @@ def entry_set_value(path, value):
try:
tmp = os.path.join(dir, "tmp.zarr")

c = creator_factory("init", config=path, path=tmp, overwrite=True)
c.run()
creator_factory("init", config=path, path=tmp, overwrite=True).run()

with open(f"{tmp}/.zattrs") as f:
variables_metadata = yaml.safe_load(f)["variables_metadata"]
Expand All @@ -169,6 +204,14 @@ def entry_set_value(path, value):
finally:
shutil.rmtree(dir)

if "constant_fields" not in entry.record["metadata"] or args.force:
LOG.info("%s, setting `constant_fields` πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯", name)
ds = open_dataset(name)
constant_fields = ds.computed_constant_fields()
LOG.info("%s", constant_fields)
entry_set_value("/metadata/constant_fields", constant_fields)
entry_set_value("/metadata/updated", updated + 1)

def zarr_file_from_catalogue(self, path, args):
import zarr

Expand Down

0 comments on commit 6d8cc2a

Please sign in to comment.