Extrapolation and masking #97

Closed
wants to merge 3 commits
43 changes: 38 additions & 5 deletions xesmf/backend.py
@@ -189,7 +189,9 @@ def add_corner(grid, lon_b, lat_b):


def esmf_regrid_build(sourcegrid, destgrid, method,
filename=None, extra_dims=None, ignore_degenerate=None):
filename=None, extra_dims=None,
extrap=None, extrap_exp=None, extrap_num_pnts=None,
ignore_degenerate=None):
'''
Create an ESMF.Regrid object, containing regridding weights.

@@ -215,7 +217,8 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
Offline weight file. **Requires ESMPy 7.1.0.dev38 or newer.**
With the weights available, we can use Scipy's sparse matrix
multiplication to apply weights, which is faster and more Pythonic
than ESMPy's online regridding.
than ESMPy's online regridding. If None, weights are stored in
memory only.

extra_dims : a list of integers, optional
Extra dimensions (e.g. time or levels) in the data field
@@ -227,6 +230,20 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
For example, if extra_dims=[Nlev, Ntime], then the data field dimension
will be [Nlon, Nlat, Nlev, Ntime]

extrap : str, optional
Extrapolation method. Options are

- 'inverse_dist'
- 'nearest_s2d'

extrap_exp : float, optional
The exponent to which the distance is raised when computing weights for
the extrapolation method. Defaults to 2.0 if not specified.

extrap_num_pnts : int, optional
The number of source points to use for extrapolation methods that use
more than one source point. Defaults to 8 if not specified.

ignore_degenerate : bool, optional
If False (default), raise error if grids contain degenerated cells
(i.e. triangles or lines, instead of quadrilaterals)
@@ -250,6 +267,22 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
raise ValueError('method should be chosen from '
'{}'.format(list(method_dict.keys())))

# use shorter, clearer names for options in ESMF.ExtrapMethod
extrap_dict = {'inverse_dist': ESMF.ExtrapMethod.NEAREST_IDAVG,
'nearest_s2d': ESMF.ExtrapMethod.NEAREST_STOD,
None: None
}
try:
esmf_extrap_method = extrap_dict[extrap]
except KeyError:
raise ValueError('extrap should be chosen from '
'{}'.format(list(extrap_dict.keys())))

# until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible
# when weights are written to a file on disk
if (extrap is not None) and (filename is not None):
raise ValueError('extrap cannot be used alongside a filename.')

# conservative regridding needs cell corner information
if method == 'conservative':
if not sourcegrid.has_corners:
@@ -275,9 +308,9 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
# Must set unmapped_action to IGNORE, otherwise the function will fail,
# if the destination grid is larger than the source grid.
regrid = ESMF.Regrid(sourcefield, destfield, filename=filename,
regrid_method=esmf_regrid_method,
unmapped_action=ESMF.UnmappedAction.IGNORE,
ignore_degenerate=ignore_degenerate)
regrid_method=esmf_regrid_method,
extrap_method=esmf_extrap_method,
extrap_dist_exponent=extrap_exp,
extrap_num_src_pnts=extrap_num_pnts,
src_mask_values=[0], dst_mask_values=[0],
unmapped_action=ESMF.UnmappedAction.IGNORE,
ignore_degenerate=ignore_degenerate, factors=True)

return regrid

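To make the new backend options concrete, here is a minimal sketch (not part of the PR) of driving `esmf_regrid_build` with inverse-distance extrapolation. The grids and field values are invented for illustration; only the keyword names (`extrap`, `extrap_num_pnts`, `extrap_exp`) come from the signature added above.

```python
import numpy as np
from xesmf.backend import (esmf_grid, esmf_regrid_build,
                           esmf_regrid_apply, esmf_regrid_finalize)

# Illustrative rectilinear grids; the destination extends beyond the source,
# so some destination points are unmapped and require extrapolation.
lon_in, lat_in = np.meshgrid(np.linspace(0, 10, 21), np.linspace(40, 50, 16))
lon_out, lat_out = np.meshgrid(np.linspace(-2, 12, 29), np.linspace(38, 52, 22))
data_in = np.cos(np.deg2rad(lat_in)) * np.sin(np.deg2rad(lon_in))

# esmf_grid expects Fortran-ordered (transposed) coordinate arrays
grid_in = esmf_grid(lon_in.T, lat_in.T)
grid_out = esmf_grid(lon_out.T, lat_out.T)

# Inverse-distance extrapolation over the 4 nearest source points,
# with distances raised to the power 1.0 when forming the weights
regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear',
                           extrap='inverse_dist',
                           extrap_num_pnts=4, extrap_exp=1.0)
data_out = esmf_regrid_apply(regrid, data_in.T).T
esmf_regrid_finalize(regrid)
```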
99 changes: 50 additions & 49 deletions xesmf/frontend.py
@@ -62,6 +62,11 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):

# transpose the arrays so they become Fortran-ordered
grid = esmf_grid(lon.T, lat.T, periodic=periodic)
# detect ds["mask"] and add it to the grid
if 'mask' in ds.data_vars:
import ESMF
grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER)
grid.mask[0][...] = np.asarray(ds['mask']).T

if need_bounds:
lon_b = np.asarray(ds['lon_b'])
@@ -104,7 +109,8 @@ def ds_to_ESMFlocstream(ds):

class Regridder(object):
def __init__(self, ds_in, ds_out, method, periodic=False,
filename=None, reuse_weights=False, ignore_degenerate=None,
extrap=None, extrap_exp=None, extrap_num_pnts=None,
weights=None, ignore_degenerate=None,
locstream_in=False, locstream_out=False):
"""
Make xESMF regridder
@@ -113,8 +119,8 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
----------
ds_in, ds_out : xarray DataSet, or dictionary
Contain input and output grid coordinates. Look for variables
``lon``, ``lat``, and optionally ``lon_b``, ``lat_b`` for
conservative method.
``lon``, ``lat``, optionally ``lon_b``, ``lat_b`` for the
conservative method, and ``mask``, where 0 marks cells to mask.

Shape can be 1D (n_lon,) and (n_lat,) for rectilinear grids,
or 2D (n_y, n_x) for general curvilinear grids.
@@ -134,16 +140,27 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
Only useful for global grids with non-conservative regridding.
Will be forced to False for conservative regridding.

filename : str, optional
Name for the weight file. The default naming scheme is::
extrap : str, optional
Extrapolation method. Options are

{method}_{Ny_in}x{Nx_in}_{Ny_out}x{Nx_out}.nc
- 'inverse_dist'
- 'nearest_s2d'

extrap_exp : float, optional
The exponent to which the distance is raised when computing weights for
the extrapolation method. Defaults to 2.0 if not specified.

e.g. bilinear_400x600_300x400.nc
extrap_num_pnts : int, optional
The number of source points to use for extrapolation methods that use
more than one source point. Defaults to 8 if not specified.

reuse_weights : bool, optional
Whether to read existing weight file to save computing time.
False by default (i.e. re-compute, not reuse).
weights : None, coo_matrix, dict, str, Dataset or Path, optional
Regridding weights, stored as
- a scipy.sparse COO matrix,
- a dictionary with keys `row_dst`, `col_src` and `weights`,
- an xarray Dataset with data variables `col`, `row` and `S`,
- or a path to a netCDF file created by ESMF.
If None, the weights are computed on the fly.

ignore_degenerate : bool, optional
If False (default), raise error if grids contain degenerated cells
@@ -170,7 +187,9 @@ def __init__(self, ds_in, ds_out, method, periodic=False,

self.method = method
self.periodic = periodic
self.reuse_weights = reuse_weights
self.extrap = extrap
self.extrap_exp = extrap_exp
self.extrap_num_pnts = extrap_num_pnts
self.ignore_degenerate = ignore_degenerate
self.locstream_in = locstream_in
self.locstream_out = locstream_out
@@ -226,14 +245,11 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
self.n_in = shape_in[0] * shape_in[1]
self.n_out = shape_out[0] * shape_out[1]

if filename is None:
self.filename = self._get_default_filename()
else:
self.filename = filename
if weights is None:
weights = self._compute_weights() # Dictionary of weights

# get weight matrix
self._write_weight_file()
self.weights = read_weights(self.filename, self.n_in, self.n_out)
# Convert weights, whatever their format, to a sparse coo matrix
self.weights = read_weights(weights, self.n_in, self.n_out)

@property
def A(self):
@@ -261,49 +277,24 @@ def _get_default_filename(self):

return filename

def _write_weight_file(self):

if os.path.exists(self.filename):
if self.reuse_weights:
print('Reuse existing file: {}'.format(self.filename))
return # do not compute it again, just read it
else:
print('Overwrite existing file: {} \n'.format(self.filename),
'You can set reuse_weights=True to save computing time.')
os.remove(self.filename)
else:
print('Create weight file: {}'.format(self.filename))

def _compute_weights(self):
regrid = esmf_regrid_build(self._grid_in, self._grid_out, self.method,
filename=self.filename,
extrap=self.extrap, extrap_exp=self.extrap_exp,
extrap_num_pnts=self.extrap_num_pnts,
ignore_degenerate=self.ignore_degenerate)
esmf_regrid_finalize(regrid) # only need weights, not regrid object

def clean_weight_file(self):
"""
Remove the offline weight file on disk.

To save the time on re-computing weights, you can just keep the file,
and set "reuse_weights=True" when initializing the regridder next time.
"""
if os.path.exists(self.filename):
print("Remove file {}".format(self.filename))
os.remove(self.filename)
else:
print("File {} is already removed.".format(self.filename))
w = regrid.get_weights_dict(deep_copy=True)
esmf_regrid_finalize(regrid) # only need weights, not regrid object
return w

def __repr__(self):
info = ('xESMF Regridder \n'
'Regridding algorithm: {} \n'
'Weight filename: {} \n'
'Reuse pre-computed weights? {} \n'
'Input grid shape: {} \n'
'Output grid shape: {} \n'
'Output grid dimension name: {} \n'
'Periodic in longitude? {}'
.format(self.method,
self.filename,
self.reuse_weights,
self.shape_in,
self.shape_out,
self.out_horiz_dims,
@@ -509,3 +500,13 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
ds_out = ds_out.squeeze(dim='dummy')

return ds_out

def to_netcdf(self, filename=None):
'''Save weights to disk as a netCDF file.'''
if filename is None:
filename = self._get_default_filename()
w = self.weights
ds = xr.Dataset({"S": ("n_s", w.data),
"col": ("n_s", w.col + 1),  # back to ESMF's 1-based indexing
"row": ("n_s", w.row + 1)})
ds.to_netcdf(filename)
return filename
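
A usage sketch of the reworked frontend, assuming the PR's API as shown above; the grids come from xesmf's `util.grid_global` / `util.wave_smooth` helpers and the mask values are made up for the example. The point is that a data variable literally named ``mask`` (0 = masked) is picked up automatically, and weights now stay in memory unless explicitly saved.

```python
import xarray as xr
import xesmf as xe

ds_in = xe.util.grid_global(5, 4)    # illustrative 5° x 4° source grid
ds_out = xe.util.grid_global(4, 3)   # illustrative 4° x 3° destination grid
ds_in['data'] = xe.util.wave_smooth(ds_in['lon'], ds_in['lat'])

# A data variable named 'mask'; 0 marks cells to exclude from regridding
mask = xr.ones_like(ds_in['lon'], dtype=int)
mask[0, :] = 0
ds_in['mask'] = mask

# Weights are computed and kept in memory (no weight file is written);
# unmapped destination cells are filled by nearest source-to-destination extrapolation
regridder = xe.Regridder(ds_in, ds_out, 'bilinear', extrap='nearest_s2d')
ds_result = regridder.regrid_dataset(ds_in[['data']])
```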

27 changes: 18 additions & 9 deletions xesmf/smm.py
@@ -5,9 +5,10 @@
import xarray as xr
import scipy.sparse as sps
import warnings
from pathlib import Path


def read_weights(filename, n_in, n_out):
def read_weights(weights, n_in, n_out):
'''
Read regridding weights into a scipy sparse COO matrix.

@@ -31,14 +32,22 @@ def read_weights(weights, n_in, n_out):
A : scipy sparse COO matrix.

'''
ds_w = xr.open_dataset(filename)

col = ds_w['col'].values - 1 # Python starts with 0
row = ds_w['row'].values - 1
S = ds_w['S'].values

weights = sps.coo_matrix((S, (row, col)), shape=[n_out, n_in])
return weights
if isinstance(weights, (str, Path, xr.Dataset)):
ds_w = weights if isinstance(weights, xr.Dataset) else xr.open_dataset(weights)
col = ds_w['col'].values - 1 # Python starts with 0
row = ds_w['row'].values - 1
S = ds_w['S'].values

elif isinstance(weights, dict):
col = weights['col_src'] - 1
row = weights['row_dst'] - 1
S = weights['weights']

elif isinstance(weights, sps.coo_matrix):
return weights

else:
raise ValueError('weights must be a str, Path, xr.Dataset, dict '
'or scipy.sparse COO matrix; got {}'.format(type(weights)))

return sps.coo_matrix((S, (row, col)), shape=[n_out, n_in])


def apply_weights(weights, indata, shape_in, shape_out):
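For reference, a small sketch (not from the PR) of the dictionary form that `read_weights` now accepts, e.g. as produced by `ESMF.Regrid.get_weights_dict`; the toy indices and weights below are invented:

```python
import numpy as np
import scipy.sparse as sps
from xesmf.smm import read_weights

# Toy 2-source / 3-destination mapping, 1-based indices as in ESMF weight files
w_dict = {
    'row_dst': np.array([1, 2, 3]),
    'col_src': np.array([1, 1, 2]),
    'weights': np.array([1.0, 0.5, 0.5]),
}

A = read_weights(w_dict, n_in=2, n_out=3)
assert isinstance(A, sps.coo_matrix)
assert A.shape == (3, 2)
```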
16 changes: 16 additions & 0 deletions xesmf/tests/test_backend.py
@@ -112,6 +112,22 @@ def test_esmf_build_bilinear():
esmf_regrid_finalize(regrid)


def test_esmf_extrapolation():

grid_in = esmf_grid(lon_in.T, lat_in.T)
grid_out = esmf_grid(lon_out.T, lat_out.T)

regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear')
data_out_esmpy = esmf_regrid_apply(regrid, data_in.T).T
# without extrapolation, the first and last rows/columns are equal to 0
assert data_out_esmpy[0, 0] == 0

regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear',
extrap='inverse_dist', extrap_num_pnts=3, extrap_exp=1)
data_out_esmpy = esmf_regrid_apply(regrid, data_in.T).T
# the 3 closest points in data_in are 2.010, 2.005, and 1.992. The result should be roughly equal to 2.0
assert np.round(data_out_esmpy[0, 0], 1) == 2.0


def test_regrid():

# use conservative regridding as an example,
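The PR adds no frontend test for the new weight handling; a possible companion test (it would belong in test_frontend.py rather than in the backend tests above) could check the `to_netcdf` / `weights=` round-trip, assuming pytest's `tmp_path` fixture:

```python
import xesmf as xe


def test_weights_roundtrip(tmp_path):
    # Weights written by Regridder.to_netcdf should rebuild an equivalent
    # regridder when passed back through the new `weights` argument.
    ds_in = xe.util.grid_global(20, 12)
    ds_out = xe.util.grid_global(15, 9)

    regridder = xe.Regridder(ds_in, ds_out, 'bilinear')
    fn = regridder.to_netcdf(str(tmp_path / 'weights.nc'))

    regridder2 = xe.Regridder(ds_in, ds_out, 'bilinear', weights=fn)
    assert (regridder2.weights.toarray() == regridder.weights.toarray()).all()
```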