diff --git a/pcmdi_metrics/mean_climate/mean_climate_driver.py b/pcmdi_metrics/mean_climate/mean_climate_driver.py index 8fb71f6a1..156012485 100755 --- a/pcmdi_metrics/mean_climate/mean_climate_driver.py +++ b/pcmdi_metrics/mean_climate/mean_climate_driver.py @@ -327,58 +327,35 @@ print("region:", region) # land/sea mask -- conduct masking only for variable data array, not entire data - if ("land" in region.split("_")) or ( - "ocean" in region.split("_") + if any( + keyword in region.split("_") + for keyword in ["land", "ocean"] ): ds_test_tmp = ds_test.copy(deep=True) ds_ref_tmp = ds_ref.copy(deep=True) if "land" in region.split("_"): ds_test_tmp[varname] = apply_landmask( - ds_test, - data_var=varname, + ds_test[varname], landfrac=t_grid["sftlf"], - mask_land=False, - mask_ocean=True, + keep_over="land", ) ds_ref_tmp[varname] = apply_landmask( - ds_ref, - data_var=varname, + ds_ref[varname], landfrac=t_grid["sftlf"], - mask_land=False, - mask_ocean=True, + keep_over="land", ) - """ - ds_test_tmp[varname] = ds_test[varname].where( - t_grid["sftlf"] != 0.0 - ) - ds_ref_tmp[varname] = ds_ref[varname].where( - t_grid["sftlf"] != 0.0 - ) - """ elif "ocean" in region.split("_"): ds_test_tmp[varname] = apply_landmask( - ds_test, - data_var=varname, + ds_test[varname], landfrac=t_grid["sftlf"], - mask_land=True, - mask_ocean=False, + keep_over="ocean", ) ds_ref_tmp[varname] = apply_landmask( - ds_ref, - data_var=varname, + ds_ref[varname], landfrac=t_grid["sftlf"], - mask_land=True, - mask_ocean=False, - ) - """ - ds_test_tmp[varname] = ds_test[varname].where( - t_grid["sftlf"] == 0.0 - ) - ds_ref_tmp[varname] = ds_ref[varname].where( - t_grid["sftlf"] == 0.0 + keep_over="ocean", ) - """ - print("mask done") + print("mask done") else: ds_test_tmp = ds_test ds_ref_tmp = ds_ref diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index 956cd7c6d..a6758da23 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ 
b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -1,15 +1,20 @@ +import warnings +from typing import Union + import regionmask import xarray as xr import xcdat as xc -def create_land_sea_mask(ds: xr.Dataset, as_boolean: bool = False) -> xr.DataArray: - """Generate a land-sea mask (1 for land, 0 for sea) for a given xarray Dataset. +def create_land_sea_mask( + obj: Union[xr.Dataset, xr.DataArray], as_boolean: bool = False +) -> xr.DataArray: + """Generate a land-sea mask (1 for land, 0 for sea) for a given xarray Dataset or DataArray. Parameters ---------- - ds : xr.Dataset - A Dataset object. + obj : Union[xr.Dataset, xr.DataArray] + The Dataset or DataArray object. as_boolean : bool, optional Set mask value to True (land) or False (ocean), by default False, thus 1 (land) and 0 (ocean). @@ -36,11 +41,11 @@ def create_land_sea_mask(ds: xr.Dataset, as_boolean: bool = False) -> xr.DataArr land_mask = regionmask.defined_regions.natural_earth_v5_0_0.land_110 # Get the longitude and latitude from the xarray dataset - key_lon = xc.axis.get_dim_keys(ds, axis="X") - key_lat = xc.axis.get_dim_keys(ds, axis="Y") + key_lon = xc.axis.get_dim_keys(obj, axis="X") + key_lat = xc.axis.get_dim_keys(obj, axis="Y") - lon = ds[key_lon] - lat = ds[key_lat] + lon = obj[key_lon] + lat = obj[key_lat] # Mask the land-sea mask to match the dataset's coordinates land_sea_mask = land_mask.mask(lon, lat) @@ -85,11 +90,10 @@ def find_min(da: xr.DataArray) -> float: def apply_landmask( - ds: xr.Dataset, - data_var: str, - landfrac: xr.DataArray, - mask_land: bool = True, - mask_ocean: bool = False, + obj: Union[xr.Dataset, xr.DataArray], + data_var: str = None, + landfrac: xr.DataArray = None, + keep_over: str = None, land_criteria: float = 0.8, ocean_criteria: float = 0.2, ) -> xr.DataArray: @@ -97,16 +101,14 @@ def apply_landmask( Parameters ---------- - ds : xr.Dataset - Dataset that includes a DataArray to apply a land-sea mask. - data_var : str - Name of DataArray in the Dataset. 
+ obj : Union[xr.Dataset, xr.DataArray] + The Dataset or DataArray object to apply a land-sea mask. landfrac : xr.DataArray Data array for land fraction that consists of 0 for ocean and 1 for land (fraction for grid along coastline). - mask_land : bool, optional - Mask out land region (thus value will exist over ocean only), by default True. - mask_ocean : bool, optional - Mask out ocean region (thus value will exist over land only), by default False. + data_var : str + Name of DataArray in the Dataset, required if obj is a Dataset. + keep_over : str + Specify whether to keep values over "land" or "ocean". land_criteria : float, optional When the fraction is equal to land_criteria or larger, the grid will be considered as land, by default 0.8. ocean_criteria : float, optional @@ -120,27 +122,54 @@ def apply_landmask( Examples -------- Import: - >>> from pcmdi_metrics.utils import apply_landmask - Mask over land (keep values over ocean only): - - >>> da_masked = apply_landmask(ds, data_var="ts", landfrac=mask, mask_land=True, mask_ocean=False) + Keep values over land only (mask over ocean): + >>> da_land = apply_landmask(da, landfrac=mask, keep_over="land") # use DataArray + >>> da_land = apply_landmask(ds, data_var="ts", landfrac=mask, keep_over="land") # use Dataset - Mask over ocean (keep values over land only): - - >>> da_masked = apply_landmask(ds, data_var="ts", landfrac=mask, mask_land=False, mask_ocean=True) + Keep values over ocean only (mask over land): + >>> da_ocean = apply_landmask(da, landfrac=mask, keep_over="ocean") # use DataArray + >>> da_ocean = apply_landmask(ds, data_var="ts", landfrac=mask, keep_over="ocean") # use Dataset """ - data_array = ds[data_var].copy() + + if isinstance(obj, xr.DataArray): + data_array = obj.copy() + elif isinstance(obj, xr.Dataset): + if data_var is None: + raise ValueError("Invalid value for data_var. 
Provide name of DataArray.") + else: + data_array = obj[data_var].copy() + + # Validate landfrac + if landfrac is None: + landfrac = create_land_sea_mask(data_array) + warnings.warn( + "landfrac is not provided thus generated using the 'create_land_sea_mask' function" + ) + + # Check units of landfrac + percentage = False + if find_min(landfrac) == 0 and find_max(landfrac) == 100: + percentage = True + if "units" in list(landfrac.attrs.keys()): + if landfrac.units == "%": + percentage = True # Convert landfrac to a fraction if it's in percentage form - if landfrac.units == "%" or (find_min(landfrac) == 0 and find_max(landfrac) == 100): + if percentage: landfrac /= 100.0 + # Validate keep_over parameter + if keep_over not in ["land", "ocean"]: + raise ValueError( + "Invalid value for keep_over. Choose either 'land' or 'ocean'." + ) + # Apply land and ocean masks - if mask_land: - data_array = data_array.where(landfrac <= ocean_criteria) - if mask_ocean: + if keep_over == "land": data_array = data_array.where(landfrac >= land_criteria) + elif keep_over == "ocean": + data_array = data_array.where(landfrac <= ocean_criteria) return data_array