Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Land sea mask generation #1006

Merged
merged 11 commits into from
Dec 18, 2023
165 changes: 56 additions & 109 deletions pcmdi_metrics/mean_climate/mean_climate_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
from collections import OrderedDict
from re import split

import cdms2
import cdutil
import numpy as np
import xcdat as xc

from pcmdi_metrics import resources
from pcmdi_metrics.io import load_regions_specs, region_subset
from pcmdi_metrics.mean_climate.lib import (
Expand All @@ -19,8 +14,11 @@
load_and_regrid,
mean_climate_metrics_to_json,
)
from pcmdi_metrics.utils import apply_landmask, create_land_sea_mask, create_target_grid
from pcmdi_metrics.variability_mode.lib import sort_human, tree

print("--- prepare mean climate metrics calculation ---")

parser = create_mean_climate_parser()
parameter = parser.get_parameter(argparse_vals_only=False)

Expand Down Expand Up @@ -78,103 +76,45 @@
regions_specs = load_regions_specs()

default_regions = ["global", "NHEX", "SHEX", "TROPICS"]
print(
"case_id: ",
case_id,
"\n",
"test_data_set:",
test_data_set,
"\n",
"realization:",
realization,
"\n",
"vars:",
vars,
"\n",
"varname_in_test_data:",
varname_in_test_data,
"\n",
"reference_data_set:",
reference_data_set,
"\n",
"target_grid:",
target_grid,
"\n",
"regrid_tool:",
regrid_tool,
"\n",
"regrid_tool_ocn:",
regrid_tool_ocn,
"\n",
"save_test_clims:",
save_test_clims,
"\n",
"test_clims_interpolated_output:",
test_clims_interpolated_output,
"\n",
"filename_template:",
filename_template,
"\n",
"sftlf_filename_template:",
sftlf_filename_template,
"\n",
"generate_sftlf:",
generate_sftlf,
"\n",
"regions_specs:",
regions_specs,
"\n",
"regions:",
regions,
"\n",
"test_data_path:",
test_data_path,
"\n",
"reference_data_path:",
reference_data_path,
"\n",
"metrics_output_path:",
metrics_output_path,
"\n",
"diagnostics_output_path:",
diagnostics_output_path,
"\n",
"debug:",
debug,
"\n",

config_variables = OrderedDict(
[
("case_id", case_id),
("test_data_set", test_data_set),
("realization", realization),
("vars", vars),
("varname_in_test_data", varname_in_test_data),
("reference_data_set", reference_data_set),
("target_grid", target_grid),
("regrid_tool", regrid_tool),
("regrid_tool_ocn", regrid_tool_ocn),
("save_test_clims", save_test_clims),
("test_clims_interpolated_output", test_clims_interpolated_output),
("filename_template", filename_template),
("sftlf_filename_template", sftlf_filename_template),
("generate_sftlf", generate_sftlf),
("regions_specs", regions_specs),
("regions", regions),
("test_data_path", test_data_path),
("reference_data_path", reference_data_path),
("metrics_output_path", metrics_output_path),
("diagnostics_output_path", diagnostics_output_path),
("debug", debug),
]
)

print("--- prepare mean climate metrics calculation ---")
for key, value in config_variables.items():
print(f"{key}: {value}")

# generate target grid
res = target_grid.split("x")
lat_res = float(res[0])
lon_res = float(res[1])
start_lat = -90.0 + lat_res / 2
start_lon = 0.0
end_lat = 90.0 - lat_res / 2
end_lon = 360.0 - lon_res
nlat = ((end_lat - start_lat) * 1.0 / lat_res) + 1
nlon = ((end_lon - start_lon) * 1.0 / lon_res) + 1
t_grid = xc.create_uniform_grid(
start_lat, end_lat, lat_res, start_lon, end_lon, lon_res
)
if debug:
print(
"type(t_grid):", type(t_grid)
) # Expected type is 'xarray.core.dataset.Dataset'
print("t_grid:", t_grid)
# identical target grid in cdms2, needed to use the generateLandSeaMask function, which does not yet exist in xcdat
t_grid_cdms2 = cdms2.createUniformGrid(
start_lat, nlat, lat_res, start_lon, nlon, lon_res
)
t_grid = create_target_grid(target_grid_resolution=target_grid)

# generate land sea mask for the target grid
sft = cdutil.generateLandSeaMask(t_grid_cdms2)
if debug:
print("sft:", sft)
print("sft.getAxisList():", sft.getAxisList())
sft = create_land_sea_mask(t_grid)

# add sft to target grid dataset
t_grid["sftlf"] = (["lat", "lon"], np.array(sft))
t_grid["sftlf"] = sft

if debug:
print("t_grid (after sftlf added):", t_grid)
t_grid.to_netcdf("target_grid.nc")
Expand All @@ -188,8 +128,6 @@
obs_file_path = os.path.join(egg_pth, obs_file_name)
with open(obs_file_path) as fo:
obs_dict = json.loads(fo.read())
# if debug:
# print('obs_dict:', json.dumps(obs_dict, indent=4, sort_keys=True))

print("--- start mean climate metrics calculation ---")

Expand Down Expand Up @@ -353,26 +291,35 @@
print("region:", region)

# land/sea mask -- apply masking only to the variable's data array, not to the entire dataset
if ("land" in region.split("_")) or (
"ocean" in region.split("_")
if any(
keyword in region.split("_")
for keyword in ["land", "ocean"]
):
ds_test_tmp = ds_test.copy(deep=True)
ds_ref_tmp = ds_ref.copy(deep=True)
if "land" in region.split("_"):
ds_test_tmp[varname] = ds_test[varname].where(
t_grid["sftlf"] != 0.0
ds_test_tmp[varname] = apply_landmask(
ds_test[varname],
landfrac=t_grid["sftlf"],
keep_over="land",
)
ds_ref_tmp[varname] = ds_ref[varname].where(
t_grid["sftlf"] != 0.0
ds_ref_tmp[varname] = apply_landmask(
ds_ref[varname],
landfrac=t_grid["sftlf"],
keep_over="land",
)
elif "ocean" in region.split("_"):
ds_test_tmp[varname] = ds_test[varname].where(
t_grid["sftlf"] == 0.0
ds_test_tmp[varname] = apply_landmask(
ds_test[varname],
landfrac=t_grid["sftlf"],
keep_over="ocean",
)
ds_ref_tmp[varname] = ds_ref[varname].where(
t_grid["sftlf"] == 0.0
ds_ref_tmp[varname] = apply_landmask(
ds_ref[varname],
landfrac=t_grid["sftlf"],
keep_over="ocean",
)
print("mask done")
print("mask done")
else:
ds_test_tmp = ds_test
ds_ref_tmp = ds_ref
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import sys

import cdutil
import numpy as np
import pandas as pd
import xarray as xr
Expand All @@ -12,7 +11,8 @@
from scipy.stats import chi2
from xcdat.regridder import grid

import pcmdi_metrics
from pcmdi_metrics.io.base import Base
from pcmdi_metrics.utils import create_land_sea_mask


# ==================================================================================
Expand Down Expand Up @@ -94,9 +94,7 @@ def precip_variability_across_timescale(
outfilename = (
"PS_pr." + str(dfrq) + "_regrid.180x90_area.freq.mean_" + dat + ".json"
)
JSON = pcmdi_metrics.io.base.Base(
outdir.replace("%(output_type)", "metrics_results"), outfilename
)
JSON = Base(outdir.replace("%(output_type)", "metrics_results"), outfilename)
JSON.write(
psdmfm,
json_structure=["model+realization", "variability type", "domain", "frequency"],
Expand Down Expand Up @@ -389,9 +387,8 @@ def Avg_PS_DomFrq(d, frequency, ntd, dat, mip, frc):
else:
sys.exit("ERROR: frc " + frc + " is not defined!")

d_cdms = xr.DataArray.to_cdms2(d[0])
mask = cdutil.generateLandSeaMask(d_cdms)
mask = xr.DataArray.from_cdms2(mask)
# generate land sea mask
mask = create_land_sea_mask(d[0])

psdmfm = {}
for dom in domains:
Expand All @@ -405,8 +402,8 @@ def Avg_PS_DomFrq(d, frequency, ntd, dat, mip, frc):
dmask = d

dmask = dmask.to_dataset(name="ps")
dmask = dmask.bounds.add_bounds(axis="X", width=0.5)
dmask = dmask.bounds.add_bounds(axis="Y", width=0.5)
dmask = dmask.bounds.add_bounds(axis="X")
dmask = dmask.bounds.add_bounds(axis="Y")

if "50S50N" in dom:
am = dmask.sel(lat=slice(-50, 50)).spatial.average(
Expand Down
3 changes: 3 additions & 0 deletions pcmdi_metrics/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .create_land_sea_mask import apply_landmask, create_land_sea_mask
from .create_target_grid import create_target_grid
from .sort_human import sort_human
Loading
Loading