Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Oct 8, 2024
1 parent 86a247d commit 379135d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
11 changes: 7 additions & 4 deletions tests/test_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self):
def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object_for_non_cf_axis(
self,
):
# Can only map to "lat" dim name
ds = xr.Dataset(
coords={
"lat": xr.DataArray(
Expand All @@ -354,18 +355,20 @@ def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object_for_non_cf_axi
"bounds": "lat_bnds",
},
),
"lat2": xr.DataArray(
"latitude": xr.DataArray(
data=np.ones(3),
dims="lat2",
dims="latitude",
attrs={
"bounds": "lat2_bnds",
"bounds": "latitude_bnds",
},
),
},
data_vars={
"var": xr.DataArray(data=np.ones(3), dims=["lat"]),
"lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]),
"lat2_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat2", "bnds"]),
"latitude_bnds": xr.DataArray(
data=np.ones((3, 3)), dims=["latitude", "bnds"]
),
},
)

Expand Down
49 changes: 36 additions & 13 deletions xcdat/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,19 +502,42 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]:
def _get_bounds_from_attr(
self, obj: xr.DataArray | xr.Dataset, axis: CFAxisKey
) -> List[str]:
bounds_keys = []
coords = get_dim_coords(obj, axis)

if isinstance(coords, xr.DataArray):
bnds_key = coords.attrs.get("bounds")
if bnds_key is not None:
bounds_keys.append(bnds_key)
elif isinstance(coords, xr.Dataset):
for coord in coords.values():
bnds_key = coord.attrs.get("bounds")

if bnds_key is not None:
bounds_keys.append(bnds_key)
"""Retrieve bounds attribute keys from the given xarray object.
This method extracts the "bounds" attribute keys from the coordinates
of the specified axis in the provided xarray DataArray or Dataset.
Parameters:
-----------
obj : xr.DataArray | xr.Dataset
The xarray object from which to retrieve the bounds attribute keys.
axis : CFAxisKey
The CF axis key ("X", "Y", "T", or "Z").
Returns:
--------
List[str]
A list of bounds attribute keys found in the coordinates of the
specified axis. Otherwise, an empty list is returned.
"""
coords_obj = get_dim_coords(obj, axis)
bounds_keys: List[str] = []

if isinstance(coords_obj, xr.DataArray):
bounds_keys = self._extract_bounds_key(coords_obj, bounds_keys)
elif isinstance(coords_obj, xr.Dataset):
for coord in coords_obj.coords.values():
bounds_keys = self._extract_bounds_key(coord, bounds_keys)

return bounds_keys

def _extract_bounds_key(
self, coords_obj: xr.DataArray, bounds_keys: List[str]
) -> List[str]:
bnds_key = coords_obj.attrs.get("bounds")

if bnds_key is not None:
bounds_keys.append(bnds_key)

return bounds_keys

Expand Down

0 comments on commit 379135d

Please sign in to comment.