Skip to content

Commit

Permalink
better support for missing values in debug plots
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 1, 2025
1 parent 75d3bd9 commit ce6da54
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 39 deletions.
10 changes: 4 additions & 6 deletions src/anemoi/utils/devtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ def plot_values(
ax.add_feature(cfeature.BORDERS, linestyle=":")

missing_values = np.isnan(values)

if missing_value is None:
values = values[~missing_values]
longitudes = longitudes[~missing_values]
latitudes = latitudes[~missing_values]
else:
values = np.where(missing_values, missing_value, values)
min = np.nanmin(values)
missing_value = min - np.abs(min) * 0.001

values = np.where(missing_values, missing_value, values)

if max_value is not None:
values = np.where(values > max_value, max_value, values)
Expand Down
14 changes: 4 additions & 10 deletions src/anemoi/utils/humanize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def bytes_to_human(n: float) -> str:

if n < 0:
sign = "-"
n -= 0
n = -n
else:
sign = ""

Expand Down Expand Up @@ -326,15 +326,9 @@ def _(x):
if years > 1:
return _("%d years" % (years,))

month = then.month
if now.year != then.year:
month -= 12

d = abs(now.month - month)
if d >= 12:
return _("a year")
else:
return _("%d month%s" % (d, _plural(d)))
delta = abs(now - then)
if delta.days > 1 and delta.days < 30:
return _("%d days" % (delta.days,))

return "on %s %d %s %d" % (
DOW[then.weekday()],
Expand Down
62 changes: 39 additions & 23 deletions src/anemoi/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import os
import sys
from functools import cached_property

import entrypoints

Expand Down Expand Up @@ -39,7 +40,7 @@ class Registry:
def __init__(self, package, key="_type"):

self.package = package
self.registered = {}
self._registered = {}
self.kind = package.split(".")[-1]
self.key = key
_BY_KIND[self.kind] = self
Expand All @@ -53,10 +54,25 @@ def register(self, name: str, factory: callable = None):
if factory is None:
return Wrapper(name, self)

self.registered[name] = factory
self._registered[name] = factory

# def registered(self, name: str):
# return name in self.registered
def names(self):

package = importlib.import_module(self.package)
root = os.path.dirname(package.__file__)
result = []

for file in os.listdir(root):
if file[0] == ".":
continue
if file == "__init__.py":
continue
if file.endswith(".py"):
result.append(file[:-3])
if os.path.isdir(os.path.join(root, file)):
if os.path.exists(os.path.join(root, file, "__init__.py")):
result.append(file)
return result

def _load(self, file):
name, _ = os.path.splitext(file)
Expand All @@ -67,9 +83,19 @@ def _load(self, file):

def lookup(self, name: str, *, return_none=False) -> callable:

# print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered)
if name in self.registered:
return self.registered[name]
if name not in self.registered:
if return_none:
return None

for e in self._registered:
LOG.info(f"Registered: {e}")

raise ValueError(f"Cannot load '{name}' from {self.package}")

return self.registered[name]

@cached_property
def registered(self):

directory = sys.modules[self.package].__path__[0]

Expand All @@ -92,23 +118,13 @@ def lookup(self, name: str, *, return_none=False) -> callable:

entrypoint_group = f"anemoi.{self.kind}"
for entry_point in entrypoints.get_group_all(entrypoint_group):
if entry_point.name == name:
if name in self.registered:
LOG.warning(
f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'"
)
self.registered[name] = entry_point.load()
if entry_point.name in self._registered:
LOG.warning(
f"Overwriting builtin '{entry_point.name}' from {self.package} with plugin '{entry_point.module_name}'"
)
self._registered[entry_point.name] = entry_point.load()

if name not in self.registered:
if return_none:
return None

for e in self.registered:
LOG.info(f"Registered: {e}")

raise ValueError(f"Cannot load '{name}' from {self.package}")

return self.registered[name]
return self._registered

def create(self, name: str, *args, **kwargs):
factory = self.lookup(name)
Expand Down

0 comments on commit ce6da54

Please sign in to comment.