Skip to content

Commit

Permalink
Added num_threads keyword arg to interpolate functions.
Browse files Browse the repository at this point in the history
Added logic to the top-level module interpolate and
interpolate_to_stations functions to handble keyword argument
num_threads (default value of 1). If grib2io_interp was built with
OpenMP support, then num_threads will be used to set OpenMP number
of threads. The interpolate functions will set the number of threads
back to the original value prior to function call.

Grib2Message interp methods and xarray backend accessor methods have
be updated to use num_threads.
  • Loading branch information
EricEngle-NOAA committed Apr 7, 2024
1 parent 0aa2fc4 commit b0d2245
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 17 deletions.
55 changes: 52 additions & 3 deletions src/grib2io/_grib2io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,8 @@ def to_bytes(self, validate: bool=True):
return None


def interpolate(self, method, grid_def_out, method_options=None, drtn=None):
def interpolate(self, method, grid_def_out, method_options=None, drtn=None,
num_threads=1):
"""
Grib2Message Interpolator
Expand Down Expand Up @@ -1276,6 +1277,10 @@ def interpolate(self, method, grid_def_out, method_options=None, drtn=None):
template of the source GRIB2 message is used. Once again, it is the
user's responsibility to properly set the Data Representation
Template attributes.
num_threads : int, optional
Number of OpenMP threads to use for interpolation. The default
value is 1. If grib2io_interp was not build with OpenMP, then
this keyword argument and value will have no impact.
Returns
-------
Expand Down Expand Up @@ -1481,7 +1486,8 @@ def set_auto_nans(value: bool):
raise TypeError(f"Argument must be bool")


def interpolate(a, method: Union[int, str], grid_def_in, grid_def_out, method_options=None):
def interpolate(a, method: Union[int, str], grid_def_in, grid_def_out,
method_options=None, num_threads=1):
"""
This is the module-level interpolation function.
Expand Down Expand Up @@ -1521,6 +1527,10 @@ def interpolate(a, method: Union[int, str], grid_def_in, grid_def_out, method_op
method_options : list of ints, optional
Interpolation options. See the NCEPLIBS-ip documentation for
more information on how these are used.
num_threads : int, optional
Number of OpenMP threads to use for interpolation. The default
value is 1. If grib2io_interp was not build with OpenMP, then
this keyword argument and value will have no impact.
Returns
-------
Expand All @@ -1530,8 +1540,18 @@ def interpolate(a, method: Union[int, str], grid_def_in, grid_def_out, method_op
the assumptions that 0-index is the interpolated u and 1-index is the
interpolated v.
"""
import grib2io_interp
from grib2io_interp import interpolate

prev_num_threads = 1
try:
import grib2io_interp
if grib2io_interp.has_openmp_support:
prev_num_threads = grib2io_interp.get_openmp_threads()
grib2io_interp.set_openmp_threads(num_threads)
except(AttributeError):
pass

if isinstance(method,int) and method not in _interp_schemes.values():
raise ValueError('Invalid interpolation method.')
elif isinstance(method,str):
Expand Down Expand Up @@ -1604,10 +1624,18 @@ def interpolate(a, method: Union[int, str], grid_def_in, grid_def_out, method_op

del rlat
del rlon

try:
if grib2io_interp.has_openmp_support:
grib2io_interp.set_openmp_threads(prev_num_threads)
except(AttributeError):
pass

return out


def interpolate_to_stations(a, method, grid_def_in, lats, lons, method_options=None):
def interpolate_to_stations(a, method, grid_def_in, lats, lons,
method_options=None, num_threads=1):
"""
Module-level interpolation function for interpolation to stations.
Expand Down Expand Up @@ -1651,6 +1679,10 @@ def interpolate_to_stations(a, method, grid_def_in, lats, lons, method_options=N
method_options : list of ints, optional
Interpolation options. See the NCEPLIBS-ip documentation for
more information on how these are used.
num_threads : int, optional
Number of OpenMP threads to use for interpolation. The default
value is 1. If grib2io_interp was not build with OpenMP, then
this keyword argument and value will have no impact.
Returns
-------
Expand All @@ -1660,8 +1692,18 @@ def interpolate_to_stations(a, method, grid_def_in, lats, lons, method_options=N
the assumptions that 0-index is the interpolated u and 1-index is the
interpolated v.
"""
import grib2io_interp
from grib2io_interp import interpolate

prev_num_threads = 1
try:
import grib2io_interp
if grib2io_interp.has_openmp_support:
prev_num_threads = grib2io_interp.get_openmp_threads()
grib2io_interp.set_openmp_threads(num_threads)
except(AttributeError):
pass

if isinstance(method,int) and method not in _interp_schemes.values():
raise ValueError('Invalid interpolation method.')
elif isinstance(method,str):
Expand Down Expand Up @@ -1735,6 +1777,13 @@ def interpolate_to_stations(a, method, grid_def_in, lats, lons, method_options=N

del rlat
del rlon

try:
if grib2io_interp.has_openmp_support:
grib2io_interp.set_openmp_threads(prev_num_threads)
except(AttributeError):
pass

return out


Expand Down
36 changes: 23 additions & 13 deletions src/grib2io/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,18 +670,20 @@ def __eq__(self, other):
return ordered_frames, cube, extra_geo


def interp_nd(a,*, method, grid_def_in, grid_def_out, method_options=None):
def interp_nd(a,*, method, grid_def_in, grid_def_out, method_options=None, num_threads=1):
front_shape = a.shape[:-2]
a = a.reshape(-1,a.shape[-2],a.shape[-1])
a = grib2io.interpolate(a, method, grid_def_in, grid_def_out, method_options=method_options)
a = grib2io.interpolate(a, method, grid_def_in, grid_def_out, method_options=method_options,
num_threads=num_threads)
a = a.reshape(front_shape + (a.shape[-2], a.shape[-1]))
return a


def interp_nd_stations(a,*, method, grid_def_in, lats, lons, method_options=None):
def interp_nd_stations(a,*, method, grid_def_in, lats, lons, method_options=None, num_threads=1):
front_shape = a.shape[:-2]
a = a.reshape(-1,a.shape[-2],a.shape[-1])
a = grib2io.interpolate_to_stations(a, method, grid_def_in, lats, lons, method_options=method_options)
a = grib2io.interpolate_to_stations(a, method, grid_def_in, lats, lons, method_options=method_options,
num_threads=num_threads)
a = a.reshape(front_shape + (len(lats),))
return a

Expand All @@ -697,20 +699,22 @@ def griddef(self):
return Grib2GridDef.from_section3(self._obj[list(self._obj.data_vars)[0]].attrs['GRIB2IO_section3'])


def interp(self, method, grid_def_out, method_options=None) -> xr.Dataset:
def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.Dataset:
# see interp method of class Grib2ioDataArray
da = self._obj.to_array()
da.attrs['GRIB2IO_section3'] = self._obj[list(self._obj.data_vars)[0]].attrs['GRIB2IO_section3']
da = da.grib2io.interp(method, grid_def_out, method_options=method_options)
da = da.grib2io.interp(method, grid_def_out, method_options=method_options,
num_threads=num_threads)
ds = da.to_dataset(dim='variable')
return ds


def interp_to_stations(self, method, calls, lats, lons, method_options=None) -> xr.Dataset:
def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.Dataset:
# see interp_to_stations method of class Grib2ioDataArray
da = self._obj.to_array()
da.attrs['GRIB2IO_section3'] = self._obj[list(self._obj.data_vars)[0]].attrs['GRIB2IO_section3']
da = da.grib2io.interp_to_stations(method, calls, lats, lons, method_options=method_options)
da = da.grib2io.interp_to_stations(method, calls, lats, lons, method_options=method_options,
num_threads=num_threads)
ds = da.to_dataset(dim='variable')
return ds

Expand Down Expand Up @@ -743,7 +747,7 @@ def griddef(self):
return Grib2GridDef.from_section3(self._obj.attrs['GRIB2IO_section3'])


def interp(self, method, grid_def_out, method_options=None) -> xr.DataArray:
def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.DataArray:
"""
Perform grid spatial interpolation.
Expand All @@ -763,9 +767,15 @@ def interp(self, method, grid_def_out, method_options=None) -> xr.DataArray:
| 'budget' | 3 |
| 'spectral' | 4 |
| 'neighbor-budget' | 6 |
grid_def_out
Grib2GridDef object of the output grid.
method_options : list of ints, optional
Interpolation options. See the NCEPLIBS-ip documentation for
more information on how these are used.
num_threads : int, optional
Number of OpenMP threads to use for interpolation. The default
value is 1. If grib2io_interp was not build with OpenMP, then
this keyword argument and value will have no impact.
Returns
-------
Expand Down Expand Up @@ -800,7 +810,7 @@ def interp(self, method, grid_def_out, method_options=None) -> xr.DataArray:
if da.chunks is None:
data = interp_nd(da.data, method=method, grid_def_in=grid_def_in,
grid_def_out=grid_def_out,
method_options=method_options)
method_options=method_options,num_threads=num_threads)
else:
import dask
front_shape = da.shape[:-2]
Expand All @@ -815,7 +825,7 @@ def interp(self, method, grid_def_out, method_options=None) -> xr.DataArray:
return new_da


def interp_to_stations(self, method, calls, lats, lons, method_options=None) -> xr.DataArray:
def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.DataArray:
"""
Perform spatial interpolation to station points.
Expand Down Expand Up @@ -872,7 +882,7 @@ def interp_to_stations(self, method, calls, lats, lons, method_options=None) ->

if da.chunks is None:
data = interp_nd_stations(da.data, method=method, grid_def_in=grid_def_in, lats=lats,
lons=lons, method_options=method_options)
lons=lons, method_options=method_options, num_threads=num_threads)
else:
import dask
front_shape = da.shape[:-1]
Expand Down
17 changes: 16 additions & 1 deletion tests/test_xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def test_multi_lead(request):
da = xr.open_mfdataset([data / 'gfs.t00z.pgrb2.1p00.f009_subset', data / 'gfs.t00z.pgrb2.1p00.f012_subset'], engine='grib2io', filters=filters, combine='nested', concat_dim='leadTime').to_array()
assert da.shape == (1, 2, 181, 360)


def test_interp(request):
try:
from grib2io._grib2io import Grib2GridDef
Expand All @@ -31,3 +30,19 @@ def test_interp(request):
assert da.shape == (1, 1597, 2345)
except(ModuleNotFoundError):
pytest.skip()

def test_interp_with_openmp_threads(request):
try:
from grib2io._grib2io import Grib2GridDef
gdtn_nbm = 30
gdt_nbm = [1, 0, 6371200, 255, 255, 255, 255, 2345, 1597, 19229000, 233723400,
48, 25000000, 265000000, 2539703, 2539703, 0, 64, 25000000,
25000000, -90000000, 0]
nbm_grid_def = Grib2GridDef(gdtn_nbm, gdt_nbm)
data = request.config.rootdir / 'tests' / 'data' / 'gfs_20221107'
filters = dict(productDefinitionTemplateNumber=0, typeOfFirstFixedSurface=1)
ds = xr.open_dataset(data / 'gfs.t00z.pgrb2.1p00.f012_subset', engine='grib2io', filters=filters)
da = ds.grib2io.interp('neighbor', nbm_grid_def, num_threads=2).to_array()
assert da.shape == (1, 1597, 2345)
except(ModuleNotFoundError):
pytest.skip()

0 comments on commit b0d2245

Please sign in to comment.