diff --git a/xesmf/backend.py b/xesmf/backend.py index 27de084..1066292 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -189,7 +189,9 @@ def add_corner(grid, lon_b, lat_b): def esmf_regrid_build(sourcegrid, destgrid, method, - filename=None, extra_dims=None, ignore_degenerate=None): + filename=None, extra_dims=None, + extrap=None, extrap_exp=None, extrap_num_pnts=None, + ignore_degenerate=None): ''' Create an ESMF.Regrid object, containing regridding weights. @@ -215,7 +217,8 @@ def esmf_regrid_build(sourcegrid, destgrid, method, Offline weight file. **Require ESMPy 7.1.0.dev38 or newer.** With the weights available, we can use Scipy's sparse matrix mulplication to apply weights, which is faster and more Pythonic - than ESMPy's online regridding. + than ESMPy's online regridding. If None, weights are stored in + memory only. extra_dims : a list of integers, optional Extra dimensions (e.g. time or levels) in the data field @@ -227,6 +230,20 @@ def esmf_regrid_build(sourcegrid, destgrid, method, For example, if extra_dims=[Nlev, Ntime], then the data field dimension will be [Nlon, Nlat, Nlev, Ntime] + extrap : str, optional + Extrapolation method. Options are + + - 'inverse_dist' + - 'nearest_s2d' + + extrap_exp : float, optional + The exponent to raise the distance to when calculating weights for the + extrapolation method. If none are specified, defaults to 2.0 + + extrap_num_pnts : int, optional + The number of source points to use for the extrapolation methods + that use more than one source point. If none are specified, defaults to 8 + ignore_degenerate : bool, optional If False (default), raise error if grids contain degenerated cells (i.e. triangles or lines, instead of quadrilaterals) @@ -250,6 +267,22 @@ def esmf_regrid_build(sourcegrid, destgrid, method, raise ValueError('method should be chosen from ' '{}'.format(list(method_dict.keys()))) + # use shorter, clearer names for options in ESMF.ExtrapMethod + extrap_dict = {'inverse_dist': ESMF.ExtrapMethod.NEAREST_IDAVG, + 'nearest_s2d': ESMF.ExtrapMethod.NEAREST_STOD, + None: None + } + try: + esmf_extrap_method = extrap_dict[extrap] + except: + raise ValueError('method should be chosen from ' + '{}'.format(list(extrap_dict.keys()))) + + # until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible + # if files are written on disk + if (extrap is not None) & (filename is not None): + raise ValueError('extrap cannot be used aongside a filename.') + # conservative regridding needs cell corner information if method == 'conservative': if not sourcegrid.has_corners: @@ -275,9 +308,9 @@ def esmf_regrid_build(sourcegrid, destgrid, method, # Must set unmapped_action to IGNORE, otherwise the function will fail, # if the destination grid is larger than the source grid. regrid = ESMF.Regrid(sourcefield, destfield, filename=filename, - regrid_method=esmf_regrid_method, - unmapped_action=ESMF.UnmappedAction.IGNORE, - ignore_degenerate=ignore_degenerate) + regrid_method=esmf_regrid_method, extrap_method=esmf_extrap_method, extrap_dist_exponent=extrap_exp, + extrap_num_src_pnts=extrap_num_pnts, src_mask_values=[0], dst_mask_values=[0], unmapped_action=ESMF.UnmappedAction.IGNORE, + ignore_degenerate=ignore_degenerate, factors=True) return regrid diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 096a1e4..c060956 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -62,6 +62,11 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): # tranpose the arrays so they become Fortran-ordered grid = esmf_grid(lon.T, lat.T, periodic=periodic) + # detect ds["mask"] and add it to the grid + if 'mask' in ds.data_vars: + import ESMF + grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER) + grid.mask[0][...] = np.asarray(ds['mask']).T if need_bounds: lon_b = np.asarray(ds['lon_b']) @@ -104,7 +109,8 @@ def ds_to_ESMFlocstream(ds): class Regridder(object): def __init__(self, ds_in, ds_out, method, periodic=False, - filename=None, reuse_weights=False, ignore_degenerate=None, + extrap=None, extrap_exp=None, extrap_num_pnts=None, + weights=None, ignore_degenerate=None, locstream_in=False, locstream_out=False): """ Make xESMF regridder @@ -113,8 +119,8 @@ def __init__(self, ds_in, ds_out, method, periodic=False, ---------- ds_in, ds_out : xarray DataSet, or dictionary Contain input and output grid coordinates. Look for variables - ``lon``, ``lat``, and optionally ``lon_b``, ``lat_b`` for - conservative method. + ``lon``, ``lat``, optionally ``lon_b``, ``lat_b`` for + conservative method, and ``mask``. Use 0 to identify cells to mask. Shape can be 1D (n_lon,) and (n_lat,) for rectilinear grids, or 2D (n_y, n_x) for general curvilinear grids. @@ -134,16 +140,27 @@ def __init__(self, ds_in, ds_out, method, periodic=False, Only useful for global grids with non-conservative regridding. Will be forced to False for conservative regridding. - filename : str, optional - Name for the weight file. The default naming scheme is:: + extrap : str, optional + Extrapolation method. Options are - {method}_{Ny_in}x{Nx_in}_{Ny_out}x{Nx_out}.nc + - 'inverse_dist' + - 'nearest_s2d' + + extrap_exp : float, optional + The exponent to raise the distance to when calculating weights for the + extrapolation method. If none are specified, defaults to 2.0 - e.g. bilinear_400x600_300x400.nc + extrap_num_pnts : int, optional + The number of source points to use for the extrapolation methods + that use more than one source point. If none are specified, defaults to 8 - reuse_weights : bool, optional - Whether to read existing weight file to save computing time. - False by default (i.e. re-compute, not reuse). + weights : None, coo_matrix, dict, str, Dataset, Path, + Regridding weights, stored as + - a scipy.sparse COO matrix, + - a dictionary with keys `row_dst`, `col_src` and `weights`, + - an xarray Dataset with data variables `col`, `row` and `S`, + - or a path to a netCDF file created by ESMF. + If None, compute the weights. ignore_degenerate : bool, optional If False (default), raise error if grids contain degenerated cells @@ -170,7 +187,9 @@ def __init__(self, ds_in, ds_out, method, periodic=False, self.method = method self.periodic = periodic - self.reuse_weights = reuse_weights + self.extrap = extrap + self.extrap_exp = extrap_exp + self.extrap_num_pnts = extrap_num_pnts self.ignore_degenerate = ignore_degenerate self.locstream_in = locstream_in self.locstream_out = locstream_out @@ -226,14 +245,11 @@ def __init__(self, ds_in, ds_out, method, periodic=False, self.n_in = shape_in[0] * shape_in[1] self.n_out = shape_out[0] * shape_out[1] - if filename is None: - self.filename = self._get_default_filename() - else: - self.filename = filename + if weights is None: + weights = self._compute_weights() # Dictionary of weights - # get weight matrix - self._write_weight_file() - self.weights = read_weights(self.filename, self.n_in, self.n_out) + # Convert weights, whatever their format, to a sparse coo matrix + self.weights = read_weights(weights, self.n_in, self.n_out) @property def A(self): @@ -261,49 +277,24 @@ def _get_default_filename(self): return filename - def _write_weight_file(self): - - if os.path.exists(self.filename): - if self.reuse_weights: - print('Reuse existing file: {}'.format(self.filename)) - return # do not compute it again, just read it - else: - print('Overwrite existing file: {} \n'.format(self.filename), - 'You can set reuse_weights=True to save computing time.') - os.remove(self.filename) - else: - print('Create weight file: {}'.format(self.filename)) - + def _compute_weights(self): regrid = esmf_regrid_build(self._grid_in, self._grid_out, self.method, - filename=self.filename, + extrap = self.extrap, extrap_exp = self.extrap_exp, + extrap_num_pnts = self.extrap_num_pnts, ignore_degenerate=self.ignore_degenerate) - esmf_regrid_finalize(regrid) # only need weights, not regrid object - def clean_weight_file(self): - """ - Remove the offline weight file on disk. - - To save the time on re-computing weights, you can just keep the file, - and set "reuse_weights=True" when initializing the regridder next time. - """ - if os.path.exists(self.filename): - print("Remove file {}".format(self.filename)) - os.remove(self.filename) - else: - print("File {} is already removed.".format(self.filename)) + w = regrid.get_weights_dict(deep_copy=True) + esmf_regrid_finalize(regrid) # only need weights, not regrid object + return w def __repr__(self): info = ('xESMF Regridder \n' 'Regridding algorithm: {} \n' - 'Weight filename: {} \n' - 'Reuse pre-computed weights? {} \n' 'Input grid shape: {} \n' 'Output grid shape: {} \n' 'Output grid dimension name: {} \n' 'Periodic in longitude? {}' .format(self.method, - self.filename, - self.reuse_weights, self.shape_in, self.shape_out, self.out_horiz_dims, @@ -509,3 +500,13 @@ def regrid_dataset(self, ds_in, keep_attrs=False): ds_out = ds_out.squeeze(dim='dummy') return ds_out + + def to_netcdf(self, filename=None): + '''Save weights to disk as a netCDF file.''' + if filename is None: + filename = self._get_default_filename() + w = self.weights + ds = xr.Dataset({"S": w.data, "col": w.col + 1, "row": w.row + 1}) + ds.to_netcdf(filename) + return filename + diff --git a/xesmf/smm.py b/xesmf/smm.py index 55a3c90..0a02280 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -5,9 +5,10 @@ import xarray as xr import scipy.sparse as sps import warnings +from pathlib import Path -def read_weights(filename, n_in, n_out): +def read_weights(weights, n_in, n_out): ''' Read regridding weights into a scipy sparse COO matrix. @@ -31,14 +32,22 @@ def read_weights(filename, n_in, n_out): A : scipy sparse COO matrix. ''' - ds_w = xr.open_dataset(filename) - - col = ds_w['col'].values - 1 # Python starts with 0 - row = ds_w['row'].values - 1 - S = ds_w['S'].values - - weights = sps.coo_matrix((S, (row, col)), shape=[n_out, n_in]) - return weights + if isinstance(weights, (str, Path, xr.Dataset)): + if not isinstance(weights, xr.Dataset): + ds_w = xr.open_dataset(weights) + col = ds_w['col'].values - 1 # Python starts with 0 + row = ds_w['row'].values - 1 + S = ds_w['S'].values + + elif isinstance(weights, dict): + col = weights['col_src'] - 1 + row = weights['row_dst'] - 1 + S = weights['weights'] + + elif isinstance(weights, sps.coo_matrix): + return weights + + return sps.coo_matrix((S, (row, col)), shape=[n_out, n_in]) def apply_weights(weights, indata, shape_in, shape_out): diff --git a/xesmf/tests/test_backend.py b/xesmf/tests/test_backend.py index c7cf848..1c4eaee 100644 --- a/xesmf/tests/test_backend.py +++ b/xesmf/tests/test_backend.py @@ -112,6 +112,22 @@ def test_esmf_build_bilinear(): esmf_regrid_finalize(regrid) +def test_esmf_extrapolation(): + + grid_in = esmf_grid(lon_in.T, lat_in.T) + grid_out = esmf_grid(lon_out.T, lat_out.T) + + regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear') + data_out_esmpy = esmf_regrid_apply(regrid, data_in.T).T + # without extrapolation, the first and last lines/columns = 0 + assert data_out_esmpy[0, 0] == 0 + + regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear', extrap='inverse_dist', extrap_num_pnts=3, extrap_exp=1) + data_out_esmpy = esmf_regrid_apply(regrid, data_in.T).T + # the 3 closest points in data_in are 2.010, 2.005, and 1.992. The result should be roughly equal to 2.0 + assert np.round(data_out_esmpy[0, 0], 1) == 2.0 + + def test_regrid(): # use conservative regridding as an example, diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 32d3b7d..386acba 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -50,7 +50,7 @@ def test_as_2d_mesh(): methods_list = ['bilinear', 'conservative', 'nearest_s2d', 'nearest_d2s'] @pytest.mark.parametrize("locstream_in,locstream_out,method", [ - (False, False, 'conservative'), + (False, False, 'conservative'), (False, False, 'bilinear'), (False, True, 'bilinear'), (False, False, 'nearest_s2d'), @@ -75,27 +75,22 @@ def test_build_regridder(method, locstream_in, locstream_out): assert 'xESMF Regridder' in str(regridder) assert method in str(regridder) - regridder.clean_weight_file() - def test_existing_weights(): # the first run method = 'bilinear' regridder = xe.Regridder(ds_in, ds_out, method) + fn = regridder.to_netcdf() # make sure we can reuse weights - assert os.path.exists(regridder.filename) + assert os.path.exists(fn) regridder_reuse = xe.Regridder(ds_in, ds_out, method, - reuse_weights=True) + weights=fn) assert regridder_reuse.A.shape == regridder.A.shape # or can also overwrite it xe.Regridder(ds_in, ds_out, method) - # clean-up - regridder.clean_weight_file() - assert not os.path.exists(regridder.filename) - def test_conservative_without_bounds(): with pytest.raises(KeyError): @@ -110,7 +105,6 @@ def test_build_regridder_from_dict(): regridder = xe.Regridder({'lon': lon_in, 'lat': lat_in}, {'lon': lon_out, 'lat': lat_out}, 'bilinear') - regridder.clean_weight_file() def test_regrid_periodic_wrong(): @@ -123,9 +117,6 @@ def test_regrid_periodic_wrong(): rel_err = (ds_out['data_ref'] - dr_out)/ds_out['data_ref'] assert np.max(np.abs(rel_err)) == 1.0 # some data will be missing - # clean-up - regridder.clean_weight_file() - def test_regrid_periodic_correct(): regridder = xe.Regridder(ds_in, ds_out, 'bilinear', periodic=True) @@ -136,9 +127,6 @@ def test_regrid_periodic_correct(): rel_err = (ds_out['data_ref'] - dr_out)/ds_out['data_ref'] assert np.max(np.abs(rel_err)) < 0.065 - # clean-up - regridder.clean_weight_file() - def ds_2d_to_1d(ds): ds_temp = ds.reset_coords() @@ -164,9 +152,6 @@ def test_regrid_with_1d_grid(): assert_equal(dr_out['lon'].values, ds_out_1d['lon'].values) assert_equal(dr_out['lat'].values, ds_out_1d['lat'].values) - # clean-up - regridder.clean_weight_file() - # TODO: consolidate (regrid method, input data types) combination # using pytest fixtures and parameterization @@ -202,9 +187,6 @@ def test_regrid_dataarray(): xr.testing.assert_identical(dr_out_4D['time'], ds_in['time']) xr.testing.assert_identical(dr_out_4D['lev'], ds_in['lev']) - # clean-up - regridder.clean_weight_file() - def test_regrid_dataarray_to_locstream(): # xarray.DataArray containing in-memory numpy array @@ -217,9 +199,6 @@ def test_regrid_dataarray_to_locstream(): # DataArray and numpy array should lead to the same result assert_equal(outdata.squeeze(), dr_out.values) - # clean-up - regridder.clean_weight_file() - with pytest.raises(ValueError): regridder = xe.Regridder(ds_in, ds_locs, 'conservative', locstream_out=True) @@ -235,9 +214,6 @@ def test_regrid_dataarray_from_locstream(): # DataArray and numpy array should lead to the same result assert_equal(outdata, dr_out.values) - # clean-up - regridder.clean_weight_file() - with pytest.raises(ValueError): regridder = xe.Regridder(ds_locs, ds_in, 'bilinear', locstream_in=True) with pytest.raises(ValueError): @@ -262,9 +238,6 @@ def test_regrid_dask(): rel_err = (outdata.compute() - outdata_ref) / outdata_ref assert np.max(np.abs(rel_err)) < 0.05 - # clean-up - regridder.clean_weight_file() - def test_regrid_dask_to_locstream(): # chunked dask array (no xarray metadata) @@ -274,19 +247,13 @@ def test_regrid_dask_to_locstream(): indata = ds_in_chunked['data4D'].data outdata = regridder(indata) - # clean-up - regridder.clean_weight_file() - def test_regrid_dask_from_locstream(): # chunked dask array (no xarray metadata) regridder = xe.Regridder(ds_locs, ds_in, 'nearest_s2d', locstream_in=True) - outdata = regridder(ds_locs['lat'].data) - - # clean-up - regridder.clean_weight_file() + outdata = regridder(ds_locs['lat'].data) def test_regrid_dataarray_dask(): @@ -311,9 +278,6 @@ def test_regrid_dataarray_dask(): assert_equal(dr_out['lat'].values, ds_out['lat'].values) assert_equal(dr_out['lon'].values, ds_out['lon'].values) - # clean-up - regridder.clean_weight_file() - def test_regrid_dataarray_dask_to_locstream(): # xarray.DataArray containing chunked dask array @@ -323,19 +287,13 @@ def test_regrid_dataarray_dask_to_locstream(): dr_in = ds_in_chunked['data4D'] dr_out = regridder(dr_in) - # clean-up - regridder.clean_weight_file() - def test_regrid_dataarray_dask_from_locstream(): # xarray.DataArray containing chunked dask array regridder = xe.Regridder(ds_locs, ds_in, 'nearest_s2d', locstream_in=True) - outdata = regridder(ds_locs['lat']) - - # clean-up - regridder.clean_weight_file() + outdata = regridder(ds_locs['lat']) def test_regrid_dataset(): @@ -365,17 +323,12 @@ def test_regrid_dataset(): assert_equal(ds_result['lat'].values, ds_out['lat'].values) assert_equal(ds_result['lon'].values, ds_out['lon'].values) - # clean-up - regridder.clean_weight_file() - def test_regrid_dataset_to_locstream(): # xarray.Dataset containing in-memory numpy array regridder = xe.Regridder(ds_in, ds_locs, 'bilinear', locstream_out=True) ds_result = regridder(ds_in) - # clean-up - regridder.clean_weight_file() def test_regrid_dataset_from_locstream(): @@ -383,8 +336,6 @@ def test_regrid_dataset_from_locstream(): regridder = xe.Regridder(ds_locs, ds_in, 'nearest_s2d', locstream_in=True) outdata = regridder(ds_locs) - # clean-up - regridder.clean_weight_file() def test_ds_to_ESMFlocstream():