Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed Jan 4, 2024
1 parent 932b33d commit e25b49e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pcmdi_metrics/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .base import MV2Json # noqa
from .default_regions_define import load_regions_specs # noqa
from .default_regions_define import region_subset # noqa
from .xcdat_xarray_dataset_io import ( # noqa
from .xcdat_dataset_io import ( # noqa
get_axis_list,
get_latitude_bounds_key,
get_latitude_key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@
import xarray as xr
import xcdat as xc

# Internal function


def _find_key(
ds: Union[xr.Dataset, xr.DataArray], axis: str, potential_names: list
) -> str:
try:
key = xc.get_dim_keys(ds, axis)
except Exception:
axes = get_axis_list(ds)
key_candidates = [k for k in axes if k.lower() in potential_names]
if len(key_candidates) > 0:
key = key_candidates[0]
else:
key_candidates = [k for k in axes if k.lower() in potential_names]
if len(key_candidates) > 0:
key = key_candidates[0]
return key


# Retrieve coordinate key names


Expand All @@ -11,31 +31,26 @@ def get_axis_list(ds: Union[xr.Dataset, xr.DataArray]) -> list[str]:
return axes


def get_data_list(ds: Union[xr.Dataset, xr.DataArray]) -> list[str]:
return list(ds.data_vars.keys())


def get_time_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
try:
time_key = xc.get_dim_keys(ds, "T")
except Exception:
axes = get_axis_list(ds)
time_key = [k for k in axes if k.lower() in ["time"]][0]
return time_key
axis = "T"
potential_names = ["time", "t"]
return _find_key(ds, axis, potential_names)


def get_latitude_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
try:
lat_key = xc.get_dim_keys(ds, "Y")
except Exception:
axes = get_axis_list(ds)
lat_key = [k for k in axes if k.lower() in ["lat", "latitude"]][0]
return lat_key
axis = "Y"
potential_names = ["lat", "latitude"]
return _find_key(ds, axis, potential_names)


def get_longitude_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
try:
lon_key = xc.get_dim_keys(ds, "X")
except Exception:
axes = get_axis_list(ds)
lon_key = [k for k in axes if k.lower() in ["lon", "longitude"]][0]
return lon_key
axis = "X"
potential_names = ["lon", "longitude"]
return _find_key(ds, axis, potential_names)


# Retrieve bounds key names
Expand Down

0 comments on commit e25b49e

Please sign in to comment.