Skip to content

Commit

Permalink
Merge branch 'main' into cookiecutter-update
Browse files Browse the repository at this point in the history
  • Loading branch information
aulemahal authored Jan 29, 2024
2 parents 768b68a + 449c399 commit 3306986
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 20 deletions.
2 changes: 1 addition & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies:
- netCDF4
- numcodecs
- numpy
- pandas >=2.0
- pandas >=2.0,<2.2
- parse
- pyyaml
- rechunker
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies:
- netCDF4
- numcodecs
- numpy
- pandas >=2.0
- pandas >=2.0,<2.2
- parse
- pyyaml
- rechunker
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies = [
"netCDF4",
"numcodecs",
"numpy",
"pandas>=2.0",
"pandas>=2.0,<2.2",
"parse",
# Used when opening catalogs.
"pyarrow",
Expand Down
71 changes: 71 additions & 0 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np

from xscen.regrid import create_bounds_rotated_pole, regrid_dataset
from xscen.testing import datablock_3d


def test_create_bounds_rotated_pole():
    """Corner bounds computed for a rotated-pole grid match known reference values."""
    data = np.zeros((20, 10, 10))
    # Synthetic rotated-pole dataset: rlon starts at -5, rlat at 80.5, 1° steps.
    ds = datablock_3d(
        data,
        "tas",
        "rlon",
        -5,
        "rlat",
        80.5,
        1,
        1,
        "2000-01-01",
        as_dataset=True,
    )
    bounds = create_bounds_rotated_pole(ds)
    # Spot-check the second corner of the last grid cell.
    np.testing.assert_allclose(bounds.lon_bounds[-1, -1, 1], 83)
    np.testing.assert_allclose(bounds.lat_bounds[-1, -1, 1], 42.5)


class TestRegridDataset:
    """Tests for ``xscen.regrid.regrid_dataset``."""

    def test_simple(self, tmp_path):
        """Regrid a regular lat/lon source onto a rotated-pole target grid.

        Checks the weight file on disk, the carried-over grid mapping,
        the history/processing-level attributes, and the output chunking.
        """
        # Destination rotated-pole grid; the non-ASCII domain attribute is
        # read by regrid_dataset (presumably for cataloguing — not visible here).
        ds_grid = datablock_3d(
            np.zeros((2, 10, 10)),
            "tas",
            "rlon",
            -5,
            "rlat",
            -5,
            1,
            1,
            "2000-01-01",
            as_dataset=True,
        )
        ds_grid.attrs["cat:domain"] = "Région d'essai"

        # Source data on a regular lat/lon grid, chunked along lon and time.
        ds_source = datablock_3d(
            np.zeros((10, 6, 6)),
            "tas",
            "lon",
            -142,
            "lat",
            0,
            2,
            2,
            "2000-01-01",
            as_dataset=True,
        ).chunk({"lon": 3, "time": 1})

        out = regrid_dataset(
            ds_source,
            ds_grid,
            tmp_path / "weights",
            regridder_kwargs={
                "method": "patch",
                "output_chunks": {"rlon": 5},
                "unmapped_to_nan": True,
            },
        )

        # The weight file was written under the requested directory.
        assert (tmp_path / "weights" / "weights_regrid0patch.nc").is_file()
        # The rotated-pole grid mapping is inherited from the target grid.
        assert out.tas.attrs["grid_mapping"] == "rotated_pole"
        assert out.rotated_pole.attrs == ds_grid.rotated_pole.attrs
        # Provenance: the regridding method is recorded in the history.
        assert "patch" in out.attrs["history"]
        assert out.attrs["cat:processing_level"] == "regridded"
        # output_chunks was passed through to the regridder call.
        assert out.chunks["rlon"] == (5, 5)
6 changes: 3 additions & 3 deletions xscen/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def spatial_mean( # noqa: C901
This is simply a shortcut for `{'name': 'global', 'method': 'bbox', 'lon_bnds' [-180, 180], 'lat_bnds': [-90, 90]}`.
kwargs : dict, optional
Arguments to send to either mean(), interp() or SpatialAverager().
For SpatialAverager, one can give `skipna` or `out_chunks` here, to be passed to the averager call itself.
For SpatialAverager, one can give `skipna` or `output_chunks` here, to be passed to the averager call itself.
simplify_tolerance : float, optional
Precision (in degree) used to simplify a shapefile before sending it to SpatialAverager().
The simpler the polygons, the faster the averaging, but it will lose some precision.
Expand Down Expand Up @@ -973,8 +973,8 @@ def spatial_mean( # noqa: C901

kwargs_copy = deepcopy(kwargs)
call_kwargs = {"skipna": kwargs_copy.pop("skipna", False)}
if "out_chunks" in kwargs:
call_kwargs["out_chunks"] = kwargs_copy.pop("out_chunks")
if "output_chunks" in kwargs:
call_kwargs["output_chunks"] = kwargs_copy.pop("output_chunks")

# Pre-emptive segmentization. Same threshold as xESMF, but there's not strong analysis behind this choice
geoms = shapely.segmentize(polygon.geometry, 1)
Expand Down
10 changes: 4 additions & 6 deletions xscen/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def regrid_dataset( # noqa: C901
Destination grid. The Dataset needs to have lat/lon coordinates.
Supports a 'mask' variable compatible with ESMF standards.
regridder_kwargs : dict, optional
Arguments to send xe.Regridder(). If it contains `skipna` or `out_chunks`, those
Arguments to send xe.Regridder(). If it contains `skipna` or `output_chunks`, those
are passed to the regridder call directly.
intermediate_grids : dict, optional
This argument is used to do a regridding in many steps, regridding to regular
Expand Down Expand Up @@ -114,8 +114,6 @@ def regrid_dataset( # noqa: C901
f"{'_'.join(kwargs[k] for k in kwargs if isinstance(kwargs[k], str))}.nc",
)

# TODO: Support for conservative regridding (use xESMF to add corner information), Locstreams, etc.

# Re-use existing weight file if possible
if os.path.isfile(weights_filename) and not (
("reuse_weights" in kwargs) and (kwargs["reuse_weights"] is False)
Expand All @@ -124,10 +122,10 @@ def regrid_dataset( # noqa: C901
kwargs["reuse_weights"] = True

# Extract args that are to be given at call time.
# out_chunks is only valid for xesmf >= 0.8, so don't add it be default to the call_kwargs
# output_chunks is only valid for xesmf >= 0.8, so don't add it by default to the call_kwargs
call_kwargs = {"skipna": regridder_kwargs.pop("skipna", False)}
if "out_chunks" in regridder_kwargs:
call_kwargs["out_chunks"] = regridder_kwargs.pop("out_chunks")
if "output_chunks" in regridder_kwargs:
call_kwargs["output_chunks"] = regridder_kwargs.pop("output_chunks")

regridder = _regridder(
ds_in=ds, ds_grid=ds_grid, filename=weights_filename, **regridder_kwargs
Expand Down
31 changes: 23 additions & 8 deletions xscen/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Optional, Union

import cartopy.crs as ccrs
import numpy as np
import pandas as pd
import xarray as xr
Expand Down Expand Up @@ -117,18 +118,31 @@ def datablock_3d(

# Support for rotated pole and oblique mercator grids
if x != "lon" and y != "lat":
lat, lon = np.meshgrid(
np.arange(45, 45 + values.shape[1] * y_step, y_step),
np.arange(-75, -75 + values.shape[2] * x_step, x_step),
)
da["lat"] = xr.DataArray(np.flipud(lat.T), dims=[y, x], attrs=attrs["lat"])
da["lon"] = xr.DataArray(lon.T, dims=[y, x], attrs=attrs["lon"])
da.attrs["grid_mapping"] = "rotated_pole" if x == "rlon" else "oblique_mercator"
PC = ccrs.PlateCarree()
if x == "rlon": # rotated pole
GM = ccrs.RotatedPole(
pole_longitude=42.5, pole_latitude=83.0, central_rotated_longitude=0.0
)
da.attrs["grid_mapping"] = "rotated_pole"
else:
GM = ccrs.ObliqueMercator(
azimuth=90,
central_latitude=46,
central_longitude=-63,
scale_factor=1,
false_easting=0,
false_northing=0,
)
da.attrs["grid_mapping"] = "oblique_mercator"

YY, XX = xr.broadcast(da[y], da[x])
pts = PC.transform_points(GM, XX.values, YY.values)
da["lon"] = xr.DataArray(pts[..., 0], dims=XX.dims, attrs=attrs["lon"])
da["lat"] = xr.DataArray(pts[..., 1], dims=YY.dims, attrs=attrs["lat"])

if as_dataset:
if "grid_mapping" in da.attrs:
da = da.to_dataset()
# These grid_mapping attributes are simply placeholders and won't match the data
if da[variable].attrs["grid_mapping"] == "rotated_pole":
da = da.assign_coords(
{
Expand All @@ -138,6 +152,7 @@ def datablock_3d(
"grid_mapping_name": "rotated_latitude_longitude",
"grid_north_pole_latitude": 42.5,
"grid_north_pole_longitude": 83.0,
"north_pole_grid_longitude": 0.0,
},
)
}
Expand Down

0 comments on commit 3306986

Please sign in to comment.