diff --git a/xesmf/backend.py b/xesmf/backend.py index 2fa8d96..7e257bc 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -52,7 +52,7 @@ def warn_lat_range(lat): warnings.warn("Latitude is outside of [-90, 90]") -def esmf_grid(lon, lat, periodic=False): +def esmf_grid(lon, lat, periodic=False, mask=None): ''' Create an ESMF.Grid object, for contrusting ESMF.Field and ESMF.Regrid @@ -70,6 +70,10 @@ def esmf_grid(lon, lat, periodic=False): Periodic in longitude? Default to False. Only useful for source grid. + mask : 2D numpy array, optional + Grid mask. Follows SCRIP convention where 1 is unmasked and 0 is + masked. + Returns ------- grid : ESMF.Grid object @@ -111,6 +115,20 @@ def esmf_grid(lon, lat, periodic=False): lon_pointer[...] = lon lat_pointer[...] = lat + # Follows SCRIP convention where 1 is unmasked and 0 is masked. + # See https://github.com/NCPP/ocgis/blob/61d88c60e9070215f28c1317221c2e074f8fb145/src/ocgis/regrid/base.py#L391-L404 + if mask is not None: + grid_mask = np.swapaxes(mask.astype(np.int32), 0, 1) + grid_mask = np.where(grid_mask == 0, 0, 1) + if not (grid_mask.shape == lon.shape): + raise ValueError( + "mask must have the same shape as the latitude/longitude" + "coordinates, got: mask.shape = %s, lon.shape = %s" % + (mask.shape, lon.shape)) + grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, + from_file=False) + grid.mask[0][:] = grid_mask + return grid @@ -175,6 +193,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method, - 'bilinear' - 'conservative', **need grid corner information** + - 'conservative_normed', **need grid corner information** - 'patch' - 'nearest_s2d' - 'nearest_d2s' @@ -204,6 +223,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method, # use shorter, clearer names for options in ESMF.RegridMethod method_dict = {'bilinear': ESMF.RegridMethod.BILINEAR, 'conservative': ESMF.RegridMethod.CONSERVE, + 'conservative_normed': ESMF.RegridMethod.CONSERVE, 'patch': ESMF.RegridMethod.PATCH, 'nearest_s2d': ESMF.RegridMethod.NEAREST_STOD, 'nearest_d2s': ESMF.RegridMethod.NEAREST_DTOS @@ -215,7 +235,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method, '{}'.format(list(method_dict.keys()))) # conservative regridding needs cell corner information - if method == 'conservative': + if method in ['conservative', 'conservative_normed']: if not sourcegrid.has_corners: raise ValueError('source grid has no corner information. ' 'cannot use conservative regridding.') @@ -235,12 +255,21 @@ def esmf_regrid_build(sourcegrid, destgrid, method, assert not os.path.exists(filename), ( 'Weight file already exists! Please remove it or use a new name.') + # re-normalize conservative regridding results + # https://github.com/JiaweiZhuang/xESMF/issues/17 + if method == 'conservative_normed': + norm_type = ESMF.NormType.FRACAREA + else: + norm_type = ESMF.NormType.DSTAREA + # Calculate regridding weights. # Must set unmapped_action to IGNORE, otherwise the function will fail, # if the destination grid is larger than the source grid. regrid = ESMF.Regrid(sourcefield, destfield, filename=filename, regrid_method=esmf_regrid_method, - unmapped_action=ESMF.UnmappedAction.IGNORE) + unmapped_action=ESMF.UnmappedAction.IGNORE, + norm_type=norm_type, + src_mask_values=[0], dst_mask_values=[0]) return regrid diff --git a/xesmf/frontend.py b/xesmf/frontend.py index e2ced7d..aee87d0 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -54,8 +54,14 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): lat = np.asarray(ds['lat']) lon, lat = as_2d_mesh(lon, lat) + if 'mask' in ds: + mask = np.asarray(ds['mask']) + print(mask.shape) + else: + mask = None + # tranpose the arrays so they become Fortran-ordered - grid = esmf_grid(lon.T, lat.T, periodic=periodic) + grid = esmf_grid(lon.T, lat.T, periodic=periodic, mask=mask) if need_bounds: lon_b = np.asarray(ds['lon_b']) @@ -77,17 +83,21 @@ def __init__(self, ds_in, ds_out, method, periodic=False, ds_in, ds_out : xarray DataSet, or dictionary Contain input and output grid coordinates. Look for variables ``lon``, ``lat``, and optionally ``lon_b``, ``lat_b`` for - conservative method. + add method. Shape can be 1D (Nlon,) and (Nlat,) for rectilinear grids, or 2D (Ny, Nx) for general curvilinear grids. Shape of bounds should be (N+1,) or (Ny+1, Nx+1). + If either dataset includes a 2d mask variable, that will also be + used to inform the regridding. + method : str Regridding method. Options are - 'bilinear' - 'conservative', **need grid corner information** + - 'conservative_normed', **need grid corner information** - 'patch' - 'nearest_s2d' - 'nearest_d2s' @@ -115,7 +125,7 @@ def __init__(self, ds_in, ds_out, method, periodic=False, """ # record basic switches - if method == 'conservative': + if method in ['conservative', 'conservative_normed']: self.need_bounds = True periodic = False # bound shape will not be N+1 for periodic grid else: diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 5b9d151..1609ebc 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -170,3 +170,20 @@ def test_regrid_with_1d_grid(): # clean-up regridder.clean_weight_file() + + +def test_build_regridder_with_masks(): + ds_in['mask'] = xr.DataArray( + np.random.randint(2, size=ds_in['data'].shape), + dims=('y', 'x')) + print(ds_in) + # 'patch' is too slow to test + for method in ['bilinear', 'conservative', 'nearest_s2d', 'nearest_d2s']: + regridder = xe.Regridder(ds_in, ds_out, method) + + # check screen output + assert repr(regridder) == str(regridder) + assert 'xESMF Regridder' in str(regridder) + assert method in str(regridder) + + regridder.clean_weight_file()