Skip to content

Commit

Permalink
chore(qa): Fix docsigs
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Feb 10, 2025
1 parent 4b99ea6 commit b08a138
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 44 deletions.
6 changes: 1 addition & 5 deletions src/anemoi/datasets/create/functions/filters/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ def __repr__(self) -> str:


class RenamedFieldFormat:
"""Rename a field based on a format string.
Args:
format (str): A string that defines the new name of the field.
"""
"""Rename a field based on a format string."""

def __init__(self, field, what, format):
self.field = field
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# nor does it submit to any jurisdiction.


from typing import Union

import numpy as np
from earthkit.data.indexing.fieldlist import FieldArray
from earthkit.meteo import constants
Expand Down Expand Up @@ -43,7 +45,7 @@ def __getattr__(self, name):
return getattr(self.field, name)


def model_level_pressure(A, B, surface_pressure):
def model_level_pressure(A, B, surface_pressure) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Calculates:
- pressure at the model full- and half-levels
- delta: depth of log(pressure) at full levels
Expand Down Expand Up @@ -114,7 +116,7 @@ def model_level_pressure(A, B, surface_pressure):
return p_full_level, p_half_level, delta, alpha


def calc_specific_gas_constant(q):
def calc_specific_gas_constant(q) -> Union[float, np.ndarray]:
"""Calculates the specific gas constant of moist air
(specific content of cloud particles and hydrometeors are neglected)
Expand All @@ -133,7 +135,7 @@ def calc_specific_gas_constant(q):
return R


def relative_geopotential_thickness(alpha, q, T):
def relative_geopotential_thickness(alpha, q, T) -> np.ndarray:
"""Calculates the geopotential thickness w.r.t the surface on model full-levels
Parameters
Expand All @@ -158,7 +160,7 @@ def relative_geopotential_thickness(alpha, q, T):
return dphi


def pressure_at_height_level(height, q, T, sp, A, B):
def pressure_at_height_level(height, q, T, sp, A, B) -> Union[float, np.ndarray]:
"""Calculates the pressure at a height level given in meters above surface.
This is done by finding the model level above and below the specified height
and interpolating the pressure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# nor does it submit to any jurisdiction.


from __future__ import annotations

import datetime
import logging

Expand Down Expand Up @@ -78,7 +80,7 @@ def __repr__(self):
self.variable.shape,
)

def reduced(self, i):
def reduced(self, i) -> Coordinate:
"""Create a new coordinate with a single value
Parameters
Expand All @@ -96,7 +98,7 @@ def reduced(self, i):
**self.kwargs,
)

def index(self, value):
def index(self, value) -> Coordinate:
"""Return the index of the value in the coordinate
Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,14 @@ def _skip_attr(v, attr_name):

return cls(ds, variables)

def sel(self, **kwargs):
def sel(self, **kwargs) -> FieldList:
"""Override the FieldList's sel method
Parameters
----------
kwargs : dict
The selection criteria
Returns
-------
FieldList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,7 @@ def __init__(self, variable, **kwargs):

@cached_property
def fields(self):
"""Filter the fields of a variable based on metadata.
Returns
-------
list
A list of fields that match the metadata.
"""
"""Filter the fields of a variable based on metadata."""
return [
field
for field in self.variable
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/datasets/create/statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
LOG = logging.getLogger(__name__)


def default_statistics_dates(dates):
def default_statistics_dates(dates) -> tuple:
"""Calculate default statistics dates based on the given list of dates.
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/datasets/data/complement.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def missing(self):
missing = missing | self.target.missing
return set(missing)

def tree(self):
def tree(self) -> Node:
"""Generates a hierarchical tree structure for the `Cutout` instance and
its associated datasets.
Expand Down
8 changes: 6 additions & 2 deletions src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
import pprint
import warnings
from functools import cached_property
from typing import TYPE_CHECKING

import numpy as np
from anemoi.utils.dates import frequency_to_seconds
from anemoi.utils.dates import frequency_to_string
from anemoi.utils.dates import frequency_to_timedelta

if TYPE_CHECKING:
import matplotlib

LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -519,7 +523,7 @@ def _compute_constant_fields_from_statistics(self):

return result

def plot(self, date, variable, member=0, **kwargs):
def plot(self, date, variable, member=0, **kwargs) -> "matplotlib.pyplot.Axes":
"""For debugging purposes, plot a field.
Parameters
Expand All @@ -537,7 +541,7 @@ def plot(self, date, variable, member=0, **kwargs):
Returns
-------
matplotlib.pyplot.Axes
axes : matplotlib.pyplot.Axes
"""

from anemoi.utils.devtools import plot_values
Expand Down
30 changes: 9 additions & 21 deletions src/anemoi/datasets/data/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _initialize_masks(self):
lam_current_mask[~lam_overlap_mask] = False
self.masks.append(lam_current_mask)

def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0):
def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0) -> bool:
"""Checks for overlapping points between two sets of latitudes and
longitudes within a specified distance threshold.
Expand All @@ -261,7 +261,7 @@ def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0):
# Check if any distance is less than the specified threshold
return np.any(distances < distance_threshold)

def __getitem__(self, index):
def __getitem__(self, index) -> np.ndarray:
"""Retrieves data from the masked LAMs and global dataset based on the
given index.
Expand All @@ -276,7 +276,7 @@ def __getitem__(self, index):
index = (index, slice(None), slice(None), slice(None))
return self._get_tuple(index)

def _get_tuple(self, index):
def _get_tuple(self, index) -> np.ndarray:
"""Helper method that applies masks and retrieves data from each dataset
according to the specified index.
Expand All @@ -300,7 +300,7 @@ def _get_tuple(self, index):

return apply_index_to_slices_changes(result, changes)

def collect_supporting_arrays(self, collected, *path):
def collect_supporting_arrays(self, collected, *path) -> None:
"""Collects supporting arrays, including masks for each LAM and the global
dataset.
Expand All @@ -316,12 +316,9 @@ def collect_supporting_arrays(self, collected, *path):
collected.append((path + ("global",), "cutout_mask", self.global_mask))

@cached_property
def shape(self):
def shape(self) -> tuple:
"""Returns the shape of the Cutout, accounting for retained grid points
across all LAMs and the global dataset.
Returns:
tuple: Shape of the concatenated masked datasets.
"""
shapes = [np.sum(mask) for mask in self.masks]
global_shape = np.sum(self.global_mask)
Expand All @@ -333,24 +330,18 @@ def check_same_resolution(self, d1, d2):
pass

@property
def grids(self):
def grids(self) -> tuple:
"""Returns the number of grid points for each LAM and the global dataset
after applying masks.
Returns:
tuple: Count of retained grid points for each dataset.
"""
grids = [np.sum(mask) for mask in self.masks]
grids.append(np.sum(self.global_mask))
return tuple(grids)

@property
def latitudes(self):
def latitudes(self) -> np.ndarray:
"""Returns the concatenated latitudes of each LAM and the global dataset
after applying masks.
Returns:
np.ndarray: Concatenated latitude array for the masked datasets.
"""
lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)])

Expand All @@ -362,12 +353,9 @@ def latitudes(self):
return latitudes

@property
def longitudes(self):
def longitudes(self) -> np.ndarray:
"""Returns the concatenated longitudes of each LAM and the global dataset
after applying masks.
Returns:
np.ndarray: Concatenated longitude array for the masked datasets.
"""
lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)])

Expand All @@ -378,7 +366,7 @@ def longitudes(self):
longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]])
return longitudes

def tree(self):
def tree(self) -> Node:
"""Generates a hierarchical tree structure for the `Cutout` instance and
its associated datasets.
Expand Down

0 comments on commit b08a138

Please sign in to comment.