Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed Dec 17, 2023
1 parent 69883ef commit 33c7df9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 68 deletions.
47 changes: 12 additions & 35 deletions pcmdi_metrics/mean_climate/mean_climate_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,58 +327,35 @@
print("region:", region)

# land/sea mask -- conduct masking only for variable data array, not entire data
if ("land" in region.split("_")) or (
"ocean" in region.split("_")
if any(
keyword in region.split("_")
for keyword in ["land", "ocean"]
):
ds_test_tmp = ds_test.copy(deep=True)
ds_ref_tmp = ds_ref.copy(deep=True)
if "land" in region.split("_"):
ds_test_tmp[varname] = apply_landmask(
ds_test,
data_var=varname,
ds_test[varname],
landfrac=t_grid["sftlf"],
mask_land=False,
mask_ocean=True,
keep_over="land",
)
ds_ref_tmp[varname] = apply_landmask(
ds_ref,
data_var=varname,
ds_ref[varname],
landfrac=t_grid["sftlf"],
mask_land=False,
mask_ocean=True,
keep_over="land",
)
"""
ds_test_tmp[varname] = ds_test[varname].where(
t_grid["sftlf"] != 0.0
)
ds_ref_tmp[varname] = ds_ref[varname].where(
t_grid["sftlf"] != 0.0
)
"""
elif "ocean" in region.split("_"):
ds_test_tmp[varname] = apply_landmask(
ds_test,
data_var=varname,
ds_test[varname],
landfrac=t_grid["sftlf"],
mask_land=True,
mask_ocean=False,
keep_over="ocean",
)
ds_ref_tmp[varname] = apply_landmask(
ds_ref,
data_var=varname,
ds_ref[varname],
landfrac=t_grid["sftlf"],
mask_land=True,
mask_ocean=False,
)
"""
ds_test_tmp[varname] = ds_test[varname].where(
t_grid["sftlf"] == 0.0
)
ds_ref_tmp[varname] = ds_ref[varname].where(
t_grid["sftlf"] == 0.0
keep_over="ocean",
)
"""
print("mask done")
print("mask done")
else:
ds_test_tmp = ds_test
ds_ref_tmp = ds_ref
Expand Down
95 changes: 62 additions & 33 deletions pcmdi_metrics/utils/create_land_sea_mask.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import warnings
from typing import Union

import regionmask
import xarray as xr
import xcdat as xc


def create_land_sea_mask(ds: xr.Dataset, as_boolean: bool = False) -> xr.DataArray:
"""Generate a land-sea mask (1 for land, 0 for sea) for a given xarray Dataset.
def create_land_sea_mask(
obj: Union[xr.Dataset, xr.DataArray], as_boolean: bool = False
) -> xr.DataArray:
"""Generate a land-sea mask (1 for land, 0 for sea) for a given xarray Dataset or DataArray.
Parameters
----------
ds : xr.Dataset
A Dataset object.
obj : Union[xr.Dataset, xr.DataArray]
The Dataset or DataArray object.
as_boolean : bool, optional
Set mask value to True (land) or False (ocean), by default False, thus 1 (land) and 0 (ocean).
Expand All @@ -36,11 +41,11 @@ def create_land_sea_mask(ds: xr.Dataset, as_boolean: bool = False) -> xr.DataArr
land_mask = regionmask.defined_regions.natural_earth_v5_0_0.land_110

# Get the longitude and latitude from the xarray dataset
key_lon = xc.axis.get_dim_keys(ds, axis="X")
key_lat = xc.axis.get_dim_keys(ds, axis="Y")
key_lon = xc.axis.get_dim_keys(obj, axis="X")
key_lat = xc.axis.get_dim_keys(obj, axis="Y")

lon = ds[key_lon]
lat = ds[key_lat]
lon = obj[key_lon]
lat = obj[key_lat]

# Mask the land-sea mask to match the dataset's coordinates
land_sea_mask = land_mask.mask(lon, lat)
Expand Down Expand Up @@ -85,28 +90,25 @@ def find_min(da: xr.DataArray) -> float:


def apply_landmask(
ds: xr.Dataset,
data_var: str,
landfrac: xr.DataArray,
mask_land: bool = True,
mask_ocean: bool = False,
obj: Union[xr.Dataset, xr.DataArray],
data_var: str = None,
landfrac: xr.DataArray = None,
keep_over: str = None,
land_criteria: float = 0.8,
ocean_criteria: float = 0.2,
) -> xr.DataArray:
"""Apply a land-sea mask to a given DataArray in an xarray Dataset.
Parameters
----------
ds : xr.Dataset
Dataset that includes a DataArray to apply a land-sea mask.
data_var : str
Name of DataArray in the Dataset.
obj : Union[xr.Dataset, xr.DataArray]
The Dataset or DataArray object to apply a land-sea mask.
landfrac : xr.DataArray
Data array for land fraction that consists of 0 for ocean and 1 for land (fraction for grid along coastline).
mask_land : bool, optional
Mask out land region (thus value will exist over ocean only), by default True.
mask_ocean : bool, optional
Mask out ocean region (thus value will exist over land only), by default False.
data_var : str
Name of DataArray in the Dataset, required if obj is a Dataset.
keep_over : str
Specify whether to keep values "land" or "ocean".
land_criteria : float, optional
When the fraction is equal to land_criteria or larger, the grid will be considered as land, by default 0.8.
ocean_criteria : float, optional
Expand All @@ -120,27 +122,54 @@ def apply_landmask(
Examples
--------
Import:
>>> from pcmdi_metrics.utils import apply_landmask
Mask over land (keep values over ocean only):
>>> da_masked = apply_landmask(ds, data_var="ts", landfrac=mask, mask_land=True, mask_ocean=False)
Keep values over land only (mask over ocean):
>>> da_land = apply_landmask(da, landfrac=mask, keep_over="land") # use DataArray
>>> da_land = apply_landmask(ds, data_var="ts", landfrac=mask, keep_over="land") # use Dataset
Mask over ocean (keep values over land only):
>>> da_masked = apply_landmask(ds, data_var="ts", landfrac=mask, mask_land=False, mask_ocean=True)
Keep values over ocean only (mask over land):
>>> da_ocean = apply_landmask(da, landfrac=mask, keep_over="ocean") # use DataArray
>>> da_ocean = apply_landmask(ds, data_var="ts", landfrac=mask, keep_over="ocean") # use Dataset
"""
data_array = ds[data_var].copy()

if isinstance(obj, xr.DataArray):
data_array = obj.copy()
elif isinstance(obj, xr.Dataset):
if data_var is None:
raise ValueError("Invalid value for data_var. Provide name of DataArray.")
else:
data_array = obj[data_var].copy()

# Validate landfrac
if landfrac is None:
landfrac = create_land_sea_mask(data_array)
warnings.warn(
"landfrac is not provided thus generated using the 'create_land_sea_mask' function"
)

# Check units of landfrac
percentage = False
if find_min(landfrac) == 0 and find_max(landfrac) == 100:
percentage = True
if "units" in list(landfrac.attrs.keys()):
if landfrac.units == "%":
percentage = True

# Convert landfrac to a fraction if it's in percentage form
if landfrac.units == "%" or (find_min(landfrac) == 0 and find_max(landfrac) == 100):
if percentage:
landfrac /= 100.0

# Validate keep_over parameter
if keep_over not in ["land", "ocean"]:
raise ValueError(
"Invalid value for keep_over. Choose either 'land' or 'ocean'."
)

# Apply land and ocean masks
if mask_land:
data_array = data_array.where(landfrac <= ocean_criteria)
if mask_ocean:
if keep_over == "land":
data_array = data_array.where(landfrac >= land_criteria)
elif keep_over == "ocean":
data_array = data_array.where(landfrac <= ocean_criteria)

return data_array

0 comments on commit 33c7df9

Please sign in to comment.