[PR]: Add Z axis support for spatial averaging #606

Open · wants to merge 2 commits into main
48 changes: 47 additions & 1 deletion tests/test_spatial.py
@@ -3,7 +3,7 @@
import xarray as xr

from tests import requires_dask
from tests.fixtures import generate_dataset
from tests.fixtures import generate_dataset, generate_lev_dataset
from xcdat.spatial import SpatialAccessor


@@ -45,6 +45,35 @@ def test_raises_error_if_data_var_not_in_dataset(self):
with pytest.raises(KeyError):
self.ds.spatial.average("not_a_data_var", axis=["Y", "incorrect_axis"])

def test_vertical_average_with_weights(self):
# check that vertical averaging returns the correct answer
# get dataset with vertical levels
ds = generate_lev_dataset()
# subset to one column for testing (and shake up data)
ds = ds.isel(time=[0], lat=[0], lon=[0]).squeeze()
so = ds["so"]
so[:] = np.array([1, 2, 3, 4])
ds["so"] = so
Comment on lines +54 to +56

@tomvothecoder (Collaborator), Mar 21, 2024:

Suggested change
so = ds["so"]
so[:] = np.array([1, 2, 3, 4])
ds["so"] = so
ds["so"].values = np.array([1, 2, 3, 4])

Assigning the numpy array directly to the DataArray will work too
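
For illustration, a minimal standalone sketch (not part of the test) showing that the in-place slice assignment used above and the suggested .values assignment produce the same DataArray:

import numpy as np
import xarray as xr

# In-place slice assignment, as in the current test code:
da1 = xr.DataArray(np.zeros(4), dims="lev")
da1[:] = np.array([1, 2, 3, 4])

# Assigning the numpy array to .values directly, as suggested, is equivalent:
da2 = xr.DataArray(np.zeros(4), dims="lev")
da2.values = np.array([1.0, 2.0, 3.0, 4.0])

assert (da1 == da2).all()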

result = ds.spatial.average(
"so", lev_bounds=(4000, 10000), axis=["Z"], keep_weights=True
)
# specify expected result
expected = xr.DataArray(
data=np.array(1.8), coords={"time": ds.time, "lat": ds.lat, "lon": ds.lon}
)
# compare
xr.testing.assert_allclose(result["so"], expected)

# check that vertical averaging returns the correct weights
expected = xr.DataArray(
data=np.array([2000, 2000, 1000, 0.0]),
coords={"time": ds.time, "lev": ds.lev, "lat": ds.lat, "lon": ds.lon},
dims=["lev"],
attrs={"xcdat_bounds": True},
)

xr.testing.assert_allclose(result["lev_wts"], expected)
Comment on lines +65 to +75

Collaborator Author: Are two-in-one tests permitted?

Collaborator: Yeah, there's no problem with that here, as long as the tests are relatively easy to maintain.
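
For reference, the expected value of 1.8 follows directly from the data values and the expected level weights; a quick standalone check with plain numpy (not part of the test suite):

import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0])            # ds["so"] after the assignment above
weights = np.array([2000.0, 2000.0, 1000.0, 0.0])  # the expected "lev_wts"

# Weighted vertical mean: (1*2000 + 2*2000 + 3*1000 + 4*0) / 5000 = 9000 / 5000
print((values * weights).sum() / weights.sum())    # 1.8, matching the expected DataArray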


def test_raises_error_if_axis_list_contains_unsupported_axis(self):
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["Y", "incorrect_axis"])
@@ -313,6 +342,23 @@ def test_raises_error_if_dataset_has_multiple_bounds_variables_for_an_axis(self)
with pytest.raises(TypeError):
ds.spatial.get_weights(axis=["Y", "X"])

def test_vertical_weighting(self):
# get dataset with vertical coordinate
ds = generate_lev_dataset()
# call _get_vertical_weights
result = ds.spatial._get_vertical_weights(
domain_bounds=ds.lev_bnds, region_bounds=np.array([4000, 10000])
)
# specify expected result
expected = xr.DataArray(
data=np.array([2000, 2000, 1000, 0.0]),
coords={"lev": ds.lev},
dims=["lev"],
attrs={"units": "m", "positive": "down", "axis": "Z", "bounds": "lev_bnds"},
)
# compare
xr.testing.assert_allclose(result, expected)

def test_data_var_weights_for_region_in_lat_and_lon_domains(self):
ds = self.ds.copy()

81 changes: 70 additions & 11 deletions xcdat/spatial.py
@@ -29,7 +29,7 @@
#: Type alias for a dictionary of axis keys mapped to their bounds.
AxisWeights = Dict[Hashable, xr.DataArray]
#: Type alias for supported spatial axis keys.
SpatialAxis = Literal["X", "Y"]
SpatialAxis = Literal["X", "Y", "Z"]
SPATIAL_AXES: Tuple[SpatialAxis, ...] = get_args(SpatialAxis)
#: Type alias for a tuple of floats/ints for the regional selection bounds.
RegionAxisBounds = Tuple[float, float]
@@ -73,10 +73,12 @@ def average(
keep_weights: bool = False,
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
lev_bounds: Optional[RegionAxisBounds] = None,
Collaborator Author: I think lev_bounds is a good generic name, but I'm open to other possibilities.

Collaborator: Sounds good to me.

) -> xr.Dataset:
"""
Calculates the spatial average for a rectilinear grid over an optionally
specified regional domain.
Calculates the weighted spatial and/or vertical average for a
rectilinear grid over an optionally specified regional and/or vertical
domain.
tomvothecoder marked this conversation as resolved.

Operations include:

@@ -101,7 +103,7 @@ def average(
average.
axis : List[SpatialAxis]
List of axis dimensions to average over, by default ["X", "Y"].
Valid axis keys include "X" and "Y".
Valid axis keys include "X", "Y", and "Z".
weights : {"generate", xr.DataArray}, optional
If "generate", then weights are generated. Otherwise, pass a
DataArray containing the regional weights used for weighted
@@ -122,6 +124,10 @@
ignored if ``weights`` are supplied. The lower bound can be larger
than the upper bound (e.g., across the prime meridian, dateline), by
default None.
lev_bounds : Optional[RegionAxisBounds], optional
A tuple of floats/ints for the regional lower and upper level
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The default is None.

Returns
-------
@@ -143,11 +149,15 @@
>>>
>>> ds.lon.attrs["axis"]
>>> X
>>>
>>> ds.level.attrs["axis"]
>>> Z

Set the 'axis' attribute for the required coordinates if it isn't:

>>> ds.lat.attrs["axis"] = "Y"
>>> ds.lon.attrs["axis"] = "X"
>>> ds.level.attrs["axis"] = "Z"

Call spatial averaging method:

@@ -167,6 +177,10 @@

>>> ts_zonal = ds.spatial.average("tas", axis=["X"])["tas"]

Get the vertical average (between 100 and 1000 hPa):

>>> ta_column = ds.spatial.average("ta", axis=["Z"], lev_bounds=(100, 1000))["ta"]

Using custom weights for averaging:

>>> # The shape of the weights must align with the data var.
@@ -178,6 +192,13 @@
>>>
>>> ts_global = ds.spatial.average("tas", axis=["X", "Y"],
>>> weights=weights)["tas"]

Notes
-----
Weights are generally computed as the difference between the bounds. If
sub-selecting a region, the units must match the axis units (e.g.,
Pa/hPa or m/km). The sub-selected region must be in numerical order
(e.g., (100, 1000) and not (1000, 100)).
"""
ds = self._dataset.copy()
dv = _get_data_var(ds, data_var)
@@ -188,7 +209,11 @@
self._validate_region_bounds("Y", lat_bounds)
if lon_bounds is not None:
self._validate_region_bounds("X", lon_bounds)
self._weights = self.get_weights(axis, lat_bounds, lon_bounds, data_var)
if lev_bounds is not None:
self._validate_region_bounds("Z", lev_bounds)
self._weights = self.get_weights(
axis, lat_bounds, lon_bounds, lev_bounds, data_var
)
elif isinstance(weights, xr.DataArray):
self._weights = weights

@@ -205,6 +230,7 @@ def get_weights(
axis: List[SpatialAxis],
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
lev_bounds: Optional[RegionAxisBounds] = None,
data_var: Optional[str] = None,
) -> xr.DataArray:
"""
Expand All @@ -216,9 +242,9 @@ def get_weights(
weights are then combined to form a DataArray of weights that can be
used to perform a weighted (spatial) average.

If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells
outside this selected regional domain are given zero weight. Grid cells
that are partially in this domain are given partial weight.
If ``lat_bounds``, ``lon_bounds``, or ``lev_bounds`` are supplied, then
grid cells outside this selected regional domain are given zero weight.
Grid cells that are partially in this domain are given partial weight.

Parameters
----------
Expand All @@ -230,6 +256,9 @@ def get_weights(
lon_bounds : Optional[RegionAxisBounds]
Tuple of longitude boundaries for regional selection, by default
None.
lev_bounds : Optional[RegionAxisBounds]
Tuple of level boundaries for vertical selection, by default
None.
data_var: Optional[str]
The key of the data variable, by default None. Pass this argument
when the dataset has more than one bounds per axis (e.g., "lon"
@@ -246,9 +275,7 @@ def get_weights(
Notes
-----
This method was developed for rectilinear grids only. ``get_weights()``
recognizes and operate on latitude and longitude, but could be extended
to work with other standard geophysical dimensions (e.g., time, depth,
and pressure).
recognizes and operates on latitude, longitude, and vertical levels.
"""
Bounds = TypedDict(
"Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]}
@@ -267,6 +294,12 @@
if lat_bounds is not None
else None,
},
"Z": {
"weights_method": self._get_vertical_weights,
"region": np.array(lev_bounds, dtype="float")
if lev_bounds is not None
else None,
},
}

axis_weights: AxisWeights = {}
@@ -476,6 +509,32 @@ def _get_latitude_weights(
weights = self._calculate_weights(d_bounds)
return weights

def _get_vertical_weights(
self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
) -> xr.DataArray:
"""Gets weights for the vertical axis.

This method scales the domain to a region (if selected) and returns weights
proportional to the difference between each pair of level bounds.

Parameters
----------
domain_bounds : xr.DataArray
The array of bounds for the vertical domain.
region_bounds : Optional[np.ndarray]
The array of bounds for vertical selection.

Returns
-------
xr.DataArray
The vertical axis weights.
"""
if region_bounds is not None:
domain_bounds = self._scale_domain_to_region(domain_bounds, region_bounds)

weights = self._calculate_weights(domain_bounds)
return weights

Comment on lines +512 to +537

Collaborator Author: I think this is sufficiently different from _get_longitude_weights (which deals with the prime meridian) and _get_latitude_weights (sine of bounds) to justify its own function.

Collaborator: I agree.
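
For illustration, a standalone numpy sketch (not xcdat's actual implementation) of the clip-and-difference idea behind _get_vertical_weights: each pair of level bounds is clipped to the requested region, and the width of the clipped interval becomes the weight, so levels fully outside the region get zero weight and partially overlapping levels get partial weight. The bounds below are hypothetical values chosen only for the example.

import numpy as np

def vertical_weights_sketch(domain_bounds: np.ndarray, region: tuple) -> np.ndarray:
    """Clip (n, 2) level bounds to the region and return the clipped widths."""
    lower, upper = min(region), max(region)
    clipped = np.clip(domain_bounds, lower, upper)
    return np.abs(clipped[:, 1] - clipped[:, 0])

# Hypothetical level bounds in meters (not the actual fixture values):
bnds = np.array([[5000.0, 7000.0], [7000.0, 9000.0], [9000.0, 11000.0], [11000.0, 13000.0]])
print(vertical_weights_sketch(bnds, (4000, 10000)))  # [2000. 2000. 1000.    0.]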

def _calculate_weights(self, domain_bounds: xr.DataArray):
"""Calculate weights for the domain.
