From 4fa2158049c4c9d3955faf3e0949650e74cd1595 Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Sun, 4 Feb 2024 14:35:05 -0800 Subject: [PATCH 1/2] Initial work on #596 --- xcdat/spatial.py | 80 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 2c50595a..c46593eb 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -29,7 +29,7 @@ #: Type alias for a dictionary of axis keys mapped to their bounds. AxisWeights = Dict[Hashable, xr.DataArray] #: Type alias for supported spatial axis keys. -SpatialAxis = Literal["X", "Y"] +SpatialAxis = Literal["X", "Y", "Z"] SPATIAL_AXES: Tuple[SpatialAxis, ...] = get_args(SpatialAxis) #: Type alias for a tuple of floats/ints for the regional selection bounds. RegionAxisBounds = Tuple[float, float] @@ -73,10 +73,12 @@ def average( keep_weights: bool = False, lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, + lev_bounds: Optional[RegionAxisBounds] = None, ) -> xr.Dataset: """ - Calculates the spatial average for a rectilinear grid over an optionally - specified regional domain. + Calculates the weighted spatial and/or vertical average for a + rectilinear grid over an optionally specified regional and/or vertical + domain. Operations include: @@ -101,7 +103,7 @@ def average( average. axis : List[SpatialAxis] List of axis dimensions to average over, by default ["X", "Y"]. - Valid axis keys include "X" and "Y". + Valid axis keys include "X", "Y", and "Z". weights : {"generate", xr.DataArray}, optional If "generate", then weights are generated. Otherwise, pass a DataArray containing the regional weights used for weighted @@ -122,6 +124,10 @@ def average( ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. + lev_bounds : Optional[RegionAxisBounds], optional + A tuple of floats/ints for the regional lower and upper level + boundaries. This arg is used when calculating axis weights, but is + ignored if ``weights`` are supplied. The default is None. Returns ------- @@ -143,11 +149,15 @@ def average( >>> >>> ds.lon.attrs["axis"] >>> X + >>> + >>> ds.level.attrs["axis"] + >>> Z Set the 'axis' attribute for the required coordinates if it isn't: >>> ds.lat.attrs["axis"] = "Y" >>> ds.lon.attrs["axis"] = "X" + >>> ds.level.attrs["axis"] = "Z" Call spatial averaging method: @@ -167,6 +177,10 @@ def average( >>> ts_zonal = ds.spatial.average("tas", axis=["X"])["tas"] + Get the vertical average (between 100 and 1000 hPa): + + >>> ta_column = ds.spatial.average("ta", axis=["Z"], lev_bounds=(100, 1000))["ta"] + Using custom weights for averaging: >>> # The shape of the weights must align with the data var. @@ -178,6 +192,12 @@ def average( >>> >>> ts_global = ds.spatial.average("tas", axis=["X", "Y"], >>> weights=weights)["tas"] + + Notes: + ------ + Weights are generally computed as the difference between the bounds. If + sub-selecting a region, the units must match the axis units (e.g., + Pa/hPa or m/km). """ ds = self._dataset.copy() dv = _get_data_var(ds, data_var) @@ -188,7 +208,11 @@ def average( self._validate_region_bounds("Y", lat_bounds) if lon_bounds is not None: self._validate_region_bounds("X", lon_bounds) - self._weights = self.get_weights(axis, lat_bounds, lon_bounds, data_var) + if lev_bounds is not None: + self._validate_region_bounds("Z", lev_bounds) + self._weights = self.get_weights( + axis, lat_bounds, lon_bounds, lev_bounds, data_var + ) elif isinstance(weights, xr.DataArray): self._weights = weights @@ -205,6 +229,7 @@ def get_weights( axis: List[SpatialAxis], lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, + lev_bounds: Optional[RegionAxisBounds] = None, data_var: Optional[str] = None, ) -> xr.DataArray: """ @@ -216,9 +241,9 @@ def get_weights( weights are then combined to form a DataArray of weights that can be used to perform a weighted (spatial) average. - If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells - outside this selected regional domain are given zero weight. Grid cells - that are partially in this domain are given partial weight. + If ``lat_bounds``, ``lon_bounds``, or ``lev_bounds`` are supplied, then + grid cells outside this selected regional domain are given zero weight. + Grid cells that are partially in this domain are given partial weight. Parameters ---------- @@ -230,6 +255,9 @@ def get_weights( lon_bounds : Optional[RegionAxisBounds] Tuple of longitude boundaries for regional selection, by default None. + lev_bounds : Optional[RegionAxisBounds] + Tuple of level boundaries for vertical selection, by default + None. data_var: Optional[str] The key of the data variable, by default None. Pass this argument when the dataset has more than one bounds per axis (e.g., "lon" @@ -246,9 +274,7 @@ def get_weights( Notes ----- This method was developed for rectilinear grids only. ``get_weights()`` - recognizes and operate on latitude and longitude, but could be extended - to work with other standard geophysical dimensions (e.g., time, depth, - and pressure). + recognizes and operate on latitude, longitude, and vertical levels. """ Bounds = TypedDict( "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]} @@ -267,6 +293,12 @@ def get_weights( if lat_bounds is not None else None, }, + "Z": { + "weights_method": self._get_vertical_weights, + "region": np.array(lev_bounds, dtype="float") + if lev_bounds is not None + else None, + }, } axis_weights: AxisWeights = {} @@ -476,6 +508,32 @@ def _get_latitude_weights( weights = self._calculate_weights(d_bounds) return weights + def _get_vertical_weights( + self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + ) -> xr.DataArray: + """Gets weights for the vertical axis. + + This method scales the domain to a region (if selected) and returns weights + proportional to the difference between each pair of level bounds. + + Parameters + ---------- + domain_bounds : xr.DataArray + The array of bounds for the vertical domain. + region_bounds : Optional[np.ndarray] + The array of bounds for vertical selection. + + Returns + ------- + xr.DataArray + The vertical axis weights. + """ + if region_bounds is not None: + domain_bounds = self._scale_domain_to_region(domain_bounds, region_bounds) + + weights = self._calculate_weights(domain_bounds) + return weights + def _calculate_weights(self, domain_bounds: xr.DataArray): """Calculate weights for the domain. From 98ea3d77cc369bbc5535dff033c6a8dd89b77be2 Mon Sep 17 00:00:00 2001 From: Stephen Po-Chedley Date: Tue, 13 Feb 2024 18:13:21 -0800 Subject: [PATCH 2/2] add unit tests for vertical averaging --- tests/test_spatial.py | 48 ++++++++++++++++++++++++++++++++++++++++++- xcdat/spatial.py | 3 ++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index fe0361cd..3c19a5fd 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -3,7 +3,7 @@ import xarray as xr from tests import requires_dask -from tests.fixtures import generate_dataset +from tests.fixtures import generate_dataset, generate_lev_dataset from xcdat.spatial import SpatialAccessor @@ -45,6 +45,35 @@ def test_raises_error_if_data_var_not_in_dataset(self): with pytest.raises(KeyError): self.ds.spatial.average("not_a_data_var", axis=["Y", "incorrect_axis"]) + def test_vertical_average_with_weights(self): + # check that vertical averaging returns the correct answer + # get dataset with vertical levels + ds = generate_lev_dataset() + # subset to one column for testing (and shake up data) + ds = ds.isel(time=[0], lat=[0], lon=[0]).squeeze() + so = ds["so"] + so[:] = np.array([1, 2, 3, 4]) + ds["so"] = so + result = ds.spatial.average( + "so", lev_bounds=(4000, 10000), axis=["Z"], keep_weights=True + ) + # specify expected result + expected = xr.DataArray( + data=np.array(1.8), coords={"time": ds.time, "lat": ds.lat, "lon": ds.lon} + ) + # compare + xr.testing.assert_allclose(result["so"], expected) + + # check that vertical averaging returns the correct weights + expected = xr.DataArray( + data=np.array([2000, 2000, 1000, 0.0]), + coords={"time": ds.time, "lev": ds.lev, "lat": ds.lat, "lon": ds.lon}, + dims=["lev"], + attrs={"xcdat_bounds": True}, + ) + + xr.testing.assert_allclose(result["lev_wts"], expected) + def test_raises_error_if_axis_list_contains_unsupported_axis(self): with pytest.raises(ValueError): self.ds.spatial.average("ts", axis=["Y", "incorrect_axis"]) @@ -313,6 +342,23 @@ def test_raises_error_if_dataset_has_multiple_bounds_variables_for_an_axis(self) with pytest.raises(TypeError): ds.spatial.get_weights(axis=["Y", "X"]) + def test_vertical_weighting(self): + # get dataset with vertical coordinate + ds = generate_lev_dataset() + # call _get_vertical_weights + result = ds.spatial._get_vertical_weights( + domain_bounds=ds.lev_bnds, region_bounds=np.array([4000, 10000]) + ) + # specify expected result + expected = xr.DataArray( + data=np.array([2000, 2000, 1000, 0.0]), + coords={"lev": ds.lev}, + dims=["lev"], + attrs={"units": "m", "positive": "down", "axis": "Z", "bounds": "lev_bnds"}, + ) + # compare + xr.testing.assert_allclose(result, expected) + def test_data_var_weights_for_region_in_lat_and_lon_domains(self): ds = self.ds.copy() diff --git a/xcdat/spatial.py b/xcdat/spatial.py index c46593eb..94ae06e3 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -197,7 +197,8 @@ def average( ------ Weights are generally computed as the difference between the bounds. If sub-selecting a region, the units must match the axis units (e.g., - Pa/hPa or m/km). + Pa/hPa or m/km). The sub-selected region must be in numerical order + (e.g., (100, 1000) and not (1000, 100)). """ ds = self._dataset.copy() dv = _get_data_var(ds, data_var)