From ce6da549dcb025dde132e984c6b513213abccd37 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 1 Feb 2025 19:28:07 +0000 Subject: [PATCH] better support for missing values in debug plots --- src/anemoi/utils/devtools.py | 10 +++--- src/anemoi/utils/humanize.py | 14 +++----- src/anemoi/utils/registry.py | 62 +++++++++++++++++++++++------------- 3 files changed, 47 insertions(+), 39 deletions(-) diff --git a/src/anemoi/utils/devtools.py b/src/anemoi/utils/devtools.py index e79b486..233a3fc 100644 --- a/src/anemoi/utils/devtools.py +++ b/src/anemoi/utils/devtools.py @@ -36,13 +36,11 @@ def plot_values( ax.add_feature(cfeature.BORDERS, linestyle=":") missing_values = np.isnan(values) - if missing_value is None: - values = values[~missing_values] - longitudes = longitudes[~missing_values] - latitudes = latitudes[~missing_values] - else: - values = np.where(missing_values, missing_value, values) + min = np.nanmin(values) + missing_value = min - np.abs(min) * 0.001 + + values = np.where(missing_values, missing_value, values) if max_value is not None: values = np.where(values > max_value, max_value, values) diff --git a/src/anemoi/utils/humanize.py b/src/anemoi/utils/humanize.py index da81e2b..7603c35 100644 --- a/src/anemoi/utils/humanize.py +++ b/src/anemoi/utils/humanize.py @@ -47,7 +47,7 @@ def bytes_to_human(n: float) -> str: if n < 0: sign = "-" - n -= 0 + n = -n else: sign = "" @@ -326,15 +326,9 @@ def _(x): if years > 1: return _("%d years" % (years,)) - month = then.month - if now.year != then.year: - month -= 12 - - d = abs(now.month - month) - if d >= 12: - return _("a year") - else: - return _("%d month%s" % (d, _plural(d))) + delta = abs(now - then) + if delta.days > 1 and delta.days < 30: + return _("%d days" % (delta.days,)) return "on %s %d %s %d" % ( DOW[then.weekday()], diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 5ab88e0..cce8617 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -12,6 +12,7 @@ import logging import os import sys +from functools import cached_property import entrypoints @@ -39,7 +40,7 @@ class Registry: def __init__(self, package, key="_type"): self.package = package - self.registered = {} + self._registered = {} self.kind = package.split(".")[-1] self.key = key _BY_KIND[self.kind] = self @@ -53,10 +54,25 @@ def register(self, name: str, factory: callable = None): if factory is None: return Wrapper(name, self) - self.registered[name] = factory + self._registered[name] = factory - # def registered(self, name: str): - # return name in self.registered + def names(self): + + package = importlib.import_module(self.package) + root = os.path.dirname(package.__file__) + result = [] + + for file in os.listdir(root): + if file[0] == ".": + continue + if file == "__init__.py": + continue + if file.endswith(".py"): + result.append(file[:-3]) + if os.path.isdir(os.path.join(root, file)): + if os.path.exists(os.path.join(root, file, "__init__.py")): + result.append(file) + return result def _load(self, file): name, _ = os.path.splitext(file) @@ -67,9 +83,19 @@ def _load(self, file): def lookup(self, name: str, *, return_none=False) -> callable: - # print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered) - if name in self.registered: - return self.registered[name] + if name not in self.registered: + if return_none: + return None + + for e in self._registered: + LOG.info(f"Registered: {e}") + + raise ValueError(f"Cannot load '{name}' from {self.package}") + + return self.registered[name] + + @cached_property + def registered(self): directory = sys.modules[self.package].__path__[0] @@ -92,23 +118,13 @@ def lookup(self, name: str, *, return_none=False) -> callable: entrypoint_group = f"anemoi.{self.kind}" for entry_point in entrypoints.get_group_all(entrypoint_group): - if entry_point.name == name: - if name in self.registered: - LOG.warning( - f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'" - ) - self.registered[name] = entry_point.load() + if entry_point.name in self._registered: + LOG.warning( + f"Overwriting builtin '{entry_point.name}' from {self.package} with plugin '{entry_point.module_name}'" + ) + self._registered[entry_point.name] = entry_point.load() - if name not in self.registered: - if return_none: - return None - - for e in self.registered: - LOG.info(f"Registered: {e}") - - raise ValueError(f"Cannot load '{name}' from {self.package}") - - return self.registered[name] + return self._registered def create(self, name: str, *args, **kwargs): factory = self.lookup(name)