diff --git a/monet/monet_accessor.py b/monet/monet_accessor.py index 47b260e..4b1ab9e 100644 --- a/monet/monet_accessor.py +++ b/monet/monet_accessor.py @@ -62,7 +62,13 @@ def _monet_to_latlon(da): return dset -def _dataset_to_monet(dset, lat_name="latitude", lon_name="longitude", latlon2d=False): +def _dataset_to_monet( + dset, + lat_name="latitude", + lon_name="longitude", + latlon2d=None, + lon180=None, +): """Rename xarray DataArray or Dataset coordinate variables for use with monet functions, returning a new xarray object. @@ -74,73 +80,68 @@ def _dataset_to_monet(dset, lat_name="latitude", lon_name="longitude", latlon2d= Name of the latitude array. lon_name : str Name of the longitude array. - latlon2d : bool + latlon2d : bool, optional Whether the latitude and longitude data is two-dimensional. + If unset (``None``), guess based on dim count. + lon180 : bool, optional + Whether the longitude values are in the range [-180, 180) already. + If true, longitude wrapping/normalization, + which can introduce small floating point errors, will be skipped. + If unset (``None``), compute min/max to determine. """ - if "grid_xt" in dset.dims: - # GFS v16 file - try: - if isinstance(dset, xr.DataArray): - dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt") - elif isinstance(dset, xr.Dataset): - dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt") - else: - raise ValueError - except ValueError: - print("dset must be an xarray.DataArray or xarray.Dataset") + if not isinstance(dset, (xr.DataArray, xr.Dataset)): + raise TypeError("dset must be an xarray.DataArray or xarray.Dataset") + + if "grid_xt" in dset.dims: # GFS v16 file + if isinstance(dset, xr.DataArray): + dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt") + elif isinstance(dset, xr.Dataset): + dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt") if "south_north" in dset.dims: # WRF WPS file dset = dset.rename(dict(south_north="y", west_east="x")) - try: - if isinstance(dset, xr.Dataset): - if "XLAT_M" in dset.data_vars: - dset["XLAT_M"] = dset.XLAT_M.squeeze() - dset["XLONG_M"] = dset.XLONG_M.squeeze() - dset = dset.set_coords(["XLAT_M", "XLONG_M"]) - elif "XLAT" in dset.data_vars: - dset["XLAT"] = dset.XLAT.squeeze() - dset["XLONG"] = dset.XLONG.squeeze() - dset = dset.set_coords(["XLAT", "XLONG"]) - elif isinstance(dset, xr.DataArray): - if "XLAT_M" in dset.coords: - dset["XLAT_M"] = dset.XLAT_M.squeeze() - dset["XLONG_M"] = dset.XLONG_M.squeeze() - elif "XLAT" in dset.coords: - dset["XLAT"] = dset.XLAT.squeeze() - dset["XLONG"] = dset.XLONG.squeeze() - else: - raise ValueError - except ValueError: - print("dset must be an Xarray.DataArray or Xarray.Dataset") + if isinstance(dset, xr.Dataset): + if "XLAT_M" in dset.data_vars: + dset["XLAT_M"] = dset.XLAT_M.squeeze() + dset["XLONG_M"] = dset.XLONG_M.squeeze() + dset = dset.set_coords(["XLAT_M", "XLONG_M"]) + elif "XLAT" in dset.data_vars: + dset["XLAT"] = dset.XLAT.squeeze() + dset["XLONG"] = dset.XLONG.squeeze() + dset = dset.set_coords(["XLAT", "XLONG"]) + elif isinstance(dset, xr.DataArray): + if "XLAT_M" in dset.coords: + dset["XLAT_M"] = dset.XLAT_M.squeeze() + dset["XLONG_M"] = dset.XLONG_M.squeeze() + elif "XLAT" in dset.coords: + dset["XLAT"] = dset.XLAT.squeeze() + dset["XLONG"] = dset.XLONG.squeeze() + + # Rename lat/lon coordinates to 'latitude'/'longitude' + dset = _rename_to_monet_latlon(dset) # common cases + if (isinstance(dset, xr.Dataset) and not {"latitude", "longitude"} <= set(dset.variables)) or ( + isinstance(dset, xr.DataArray) and not {"latitude", "longitude"} <= set(dset.coords) + ): + dset = dset.rename({lat_name: "latitude", lon_name: "longitude"}) - # Unstructured Grid - # lat & lon are not coordinate variables in unstructured grid - if dset.attrs.get("mio_has_unstructured_grid", False): - # only call rename and wrap_longitudes - dset = _rename_to_monet_latlon(dset) + # Maybe wrap longitudes + if lon180 is None: + lon180 = dset["longitude"].min() >= -180 and dset["longitude"].max() < 180 + if not lon180: dset["longitude"] = wrap_longitudes(dset["longitude"]) - else: - dset = _rename_to_monet_latlon(dset) - latlon2d = True - # print(len(dset[lat_name].shape)) - # print(dset) - if len(dset[lat_name].shape) < 2: - # print(dset[lat_name].shape) - latlon2d = False - if latlon2d is False: - try: - if isinstance(dset, xr.DataArray): - dset = _dataarray_coards_to_netcdf(dset, lat_name=lat_name, lon_name=lon_name) - elif isinstance(dset, xr.Dataset): - dset = _coards_to_netcdf(dset, lat_name=lat_name, lon_name=lon_name) - else: - raise ValueError - except ValueError: - print("dset must be an Xarray.DataArray or Xarray.Dataset") - else: - dset = _rename_to_monet_latlon(dset) - dset["longitude"] = wrap_longitudes(dset["longitude"]) + # lat & lon are not coordinate variables in unstructured grid, so we're done + if dset.attrs.get("mio_has_unstructured_grid", False): + return dset + + # Maybe convert 1-D lat/lon coords to 2-D + if latlon2d is None: + latlon2d = dset["latitude"].ndim >= 2 + if not latlon2d: + if isinstance(dset, xr.DataArray): + dset = _dataarray_coards_to_netcdf(dset, lat_name="latitude", lon_name="longitude") + elif isinstance(dset, xr.Dataset): + dset = _coards_to_netcdf(dset, lat_name="latitude", lon_name="longitude") return dset @@ -171,7 +172,7 @@ def _rename_to_monet_latlon(ds): elif "XLAT" in check_list: return ds.rename({"XLAT": "latitude", "XLONG": "longitude"}) else: - return ds + return ds.copy() def _coards_to_netcdf(dset, lat_name="lat", lon_name="lon"): @@ -189,7 +190,7 @@ def _coards_to_netcdf(dset, lat_name="lat", lon_name="lon"): """ from numpy import arange, meshgrid - lon = wrap_longitudes(dset[lon_name]) + lon = dset[lon_name] lat = dset[lat_name] lons, lats = meshgrid(lon, lat) x = arange(len(lon)) @@ -218,7 +219,7 @@ def _dataarray_coards_to_netcdf(dset, lat_name="lat", lon_name="lon"): """ from numpy import arange, meshgrid - lon = wrap_longitudes(dset[lon_name]) + lon = dset[lon_name] lat = dset[lat_name] lons, lats = meshgrid(lon, lat) x = arange(len(lon)) @@ -1191,7 +1192,7 @@ def _get_CoordinateDefinition(self, data=None): g = geo.CoordinateDefinition(lats=self._obj.latitude, lons=self._obj.longitude) return g - def remap_nearest(self, data, **kwargs): + def remap_nearest(self, data, radius_of_influence=1e6, **kwargs): """Remap `data` from another grid to the current self grid using pyresample nearest-neighbor interpolation. @@ -1213,16 +1214,20 @@ def remap_nearest(self, data, **kwargs): # from .grids import get_generic_projection_from_proj4 # check to see if grid is supplied + source_data = _dataset_to_monet(data) target_data = _dataset_to_monet(self._obj) - source = self._get_CoordinateDefinition(data=source_data) - target = self._get_CoordinateDefinition(data=target_data) - r = kd_tree.XArrayResamplerNN(source, target, **kwargs) + source = self._get_CoordinateDefinition(source_data) + target = self._get_CoordinateDefinition(target_data) + r = kd_tree.XArrayResamplerNN( + source, target, radius_of_influence=radius_of_influence, **kwargs + ) r.get_neighbour_info() if isinstance(source_data, xr.DataArray): result = r.get_sample_from_neighbour_info(source_data) result.name = source_data.name result["latitude"] = target_data.latitude + result["longitude"] = target_data.longitude elif isinstance(source_data, xr.Dataset): results = {} @@ -1504,7 +1509,7 @@ def _get_CoordinateDefinition(self, data=None): g = geo.CoordinateDefinition(lats=self._obj.latitude, lons=self._obj.longitude) return g - def remap_nearest(self, data, radius_of_influence=1e6): + def remap_nearest(self, data, radius_of_influence=1e6, **kwargs): """Remap `data` from another grid to the current self grid using pyresample nearest-neighbor interpolation. @@ -1525,26 +1530,20 @@ def remap_nearest(self, data, radius_of_influence=1e6): # from .grids import get_generic_projection_from_proj4 # check to see if grid is supplied - try: - check_error = False - if isinstance(data, xr.DataArray) or isinstance(data, xr.Dataset): - check_error = False - else: - check_error = True - if check_error: - raise TypeError - except TypeError: - print("data must be either an Xarray.DataArray or Xarray.Dataset") + source_data = _dataset_to_monet(data) target_data = _dataset_to_monet(self._obj) source = self._get_CoordinateDefinition(source_data) target = self._get_CoordinateDefinition(target_data) - r = kd_tree.XArrayResamplerNN(source, target, radius_of_influence=radius_of_influence) + r = kd_tree.XArrayResamplerNN( + source, target, radius_of_influence=radius_of_influence, **kwargs + ) r.get_neighbour_info() if isinstance(source_data, xr.DataArray): result = r.get_sample_from_neighbour_info(source_data) result.name = source_data.name result["latitude"] = target_data.latitude + result["longitude"] = target_data.longitude elif isinstance(source_data, xr.Dataset): results = {} diff --git a/monet/util/combinetool.py b/monet/util/combinetool.py index f407279..7118863 100644 --- a/monet/util/combinetool.py +++ b/monet/util/combinetool.py @@ -71,7 +71,7 @@ def combine_da_to_da(source, target, *, merge=True, interp_time=False, **kwargs) ---------- source : xarray.DataArray or xarray.Dataset Gridded data. - target : xarray.DataArray + target : xarray.DataArray or xarray.Dataset Point observations. merge : bool If false, only return the interpolated source data. @@ -87,13 +87,14 @@ def combine_da_to_da(source, target, *, merge=True, interp_time=False, **kwargs) """ from ..monet_accessor import _dataset_to_monet - target_fixed = _dataset_to_monet(target) - source_fixed = _dataset_to_monet(source) - output = target_fixed.monet.remap_nearest(source_fixed, **kwargs) + output = target.monet.remap_nearest(source, **kwargs) + if interp_time: output = output.interp(time=target.time) + if merge: - output = xr.merge([target_fixed, output]) + output = xr.merge([_dataset_to_monet(target), output]) + return output diff --git a/tests/test_remap.py b/tests/test_remap.py index 7d2ac4f..a37ea73 100644 --- a/tests/test_remap.py +++ b/tests/test_remap.py @@ -90,6 +90,11 @@ def test_combine_da_da(): }, ) + # Longitude normalization introduces floating point error + x_ = (x + 180) % 360 - 180 + assert not (x_ == x).any() + assert np.abs(x_ - x).max() < 5e-14 + # Combine (find closest model grid cell to each obs point) # NOTE: to use `merge`, must have matching `level` dims new = combine_da_to_da(model, obs, merge=False, interp_time=False) @@ -100,8 +105,10 @@ def test_combine_da_da(): assert float(new.longitude.max()) == pytest.approx(0.9) assert float(new.latitude.min()) == pytest.approx(0.1) assert float(new.latitude.max()) == pytest.approx(0.9) - assert (new.latitude.isel(x=0).values == obs.latitude.values).all() - assert np.allclose(new.longitude.isel(y=0).values, obs.longitude.values) + + assert (obs.longitude.values == x).all(), "preserved" + assert (new.latitude.isel(x=0).values == obs.latitude.values).all(), "same as target" + assert (new.longitude.isel(y=0).values == obs.longitude.values).all(), "same as target" # Use orthogonal selection to get track a = new.data.values[:, new.y, new.x]