From fc2ea183b4a4bc4cf5677f175fa3fd194d22032f Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Tue, 18 Jun 2024 16:10:24 +1000 Subject: [PATCH] refactor: nodata handling #162 - adding new types - `MaybeAutoNodata` -- `None|int|float|str|"auto"` - `Nodata` is now `None|float|int` - `MaybeNodata` is now `None|float|int|str` - "auto" replaces what used to be `None` - `None` now means "no nodata value" - `resolve_nodata()` is used to handle nodata options consistently across the library - default nodata for float is `nan` - fixes in reproject/overview generation for float data. Float data is assumed to contain `nan` values, and GDAL needs `nodata=nan` to handle them correctly. --- odc/geo/_dask.py | 10 ++---- odc/geo/_rgba.py | 24 +++++++++---- odc/geo/_xr_interop.py | 33 ++++++++--------- odc/geo/cog/_rio.py | 19 +++++----- odc/geo/cog/_shared.py | 4 +-- odc/geo/cog/_tifffile.py | 9 +++-- odc/geo/math.py | 46 ++++++++++++++++++++++++ odc/geo/types.py | 6 ++-- odc/geo/warp.py | 77 +++++++++++++--------------------------- tests/conftest.py | 19 +++++++++- tests/test_cog.py | 2 +- tests/test_map.py | 2 ++ tests/test_math.py | 52 +++++++++++++++++++++++++++ tests/test_warp.py | 68 +++++++++++++++++++++++++++++++++++ tests/test_xr_interop.py | 7 +++- 15 files changed, 279 insertions(+), 99 deletions(-) create mode 100644 tests/test_warp.py diff --git a/odc/geo/_dask.py b/odc/geo/_dask.py index 004583c1..5b69c685 100644 --- a/odc/geo/_dask.py +++ b/odc/geo/_dask.py @@ -9,13 +9,9 @@ from ._blocks import BlockAssembler from .gcp import GCPGeoBox from .geobox import GeoBox, GeoboxTiles -from .warp import ( - Nodata, - Resampling, - _rio_reproject, - resampling_s2rio, - resolve_fill_value, -) +from .math import resolve_fill_value +from .types import Nodata +from .warp import Resampling, _rio_reproject, resampling_s2rio def _do_chunked_reproject( diff --git a/odc/geo/_rgba.py b/odc/geo/_rgba.py index f99efb2e..5d7c4865 100644 --- a/odc/geo/_rgba.py +++ b/odc/geo/_rgba.py @@ -8,6 +8,7 @@ 
import xarray as xr from ._interop import is_dask_collection +from .types import Nodata # pylint: disable=import-outside-toplevel @@ -59,7 +60,7 @@ def _np_to_rgba( r: np.ndarray, g: np.ndarray, b: np.ndarray, - nodata: Optional[float], + nodata: Nodata, vmin: float, vmax: float, ) -> np.ndarray: @@ -67,7 +68,7 @@ def _np_to_rgba( if r.dtype.kind == "f": valid = ~np.isnan(r) - if nodata is not None: + if nodata is not None and not np.isnan(nodata): valid = valid * (r != nodata) elif nodata is not None: valid = r != nodata @@ -130,7 +131,7 @@ def to_rgba( assert vmax is not None _b = ds[bands[0]] - nodata = getattr(_b, "nodata", None) + nodata = _b.odc.nodata dims = (*_b.dims, "band") r, g, b = (ds[name].data for name in bands) @@ -171,12 +172,20 @@ def _np_colorize(x, cmap, clip): return cmap[x] -def _matplotlib_colorize(x, cmap, vmin=None, vmax=None, nodata=None, robust=False): +def _matplotlib_colorize( + x, + cmap, + vmin=None, + vmax=None, + nodata: Nodata = None, + robust=False, +): from matplotlib import colormaps from matplotlib.colors import Normalize if cmap is None or isinstance(cmap, str): - cmap = colormaps.get_cmap(cmap) + # None is a valid input, maps to default cmap + cmap = colormaps.get_cmap(cmap) # type: ignore if nodata is not None: x = np.where(x == nodata, np.float32("nan"), x) @@ -234,8 +243,11 @@ def colorize( :param clip: If ``True`` clip values from ``x`` to be in the safe range for ``cmap``. 
""" # pylint: disable=too-many-locals + from ._xr_interop import ODCExtensionDa assert isinstance(x, xr.DataArray) + assert isinstance(x.odc, ODCExtensionDa) + _is_dask = is_dask_collection(x.data) if isinstance(cmap, np.ndarray): @@ -263,7 +275,7 @@ def colorize( _matplotlib_colorize, vmin=vmin, vmax=vmax, - nodata=getattr(x, "nodata", None), + nodata=x.odc.nodata, robust=robust, ) nc, cmap_dtype = 4, "uint8" diff --git a/odc/geo/_xr_interop.py b/odc/geo/_xr_interop.py index 45ffeec9..c696d63f 100644 --- a/odc/geo/_xr_interop.py +++ b/odc/geo/_xr_interop.py @@ -39,13 +39,15 @@ affine_from_axis, approx_equal_affine, is_affine_st, + is_nodata_empty, maybe_int, resolution_from_affine, + resolve_fill_value, + resolve_nodata, ) from .overlap import compute_output_geobox from .roi import roi_is_empty -from .types import Resolution, SomeResolution, SomeShape, xy_ -from .warp import resolve_fill_value +from .types import MaybeAutoNodata, Nodata, Resolution, SomeResolution, SomeShape, xy_ # pylint: disable=import-outside-toplevel # pylint: disable=too-many-lines @@ -656,7 +658,7 @@ def xr_reproject( how: Union[SomeCRS, GeoBox], *, resampling: Union[str, int] = "nearest", - dst_nodata: Optional[float] = None, + dst_nodata: MaybeAutoNodata = "auto", dtype=None, resolution: Union[SomeResolution, Literal["auto", "fit", "same"]] = "auto", shape: Union[SomeShape, int, None] = None, @@ -753,7 +755,7 @@ def _xr_reproject_ds( how: Union[SomeCRS, GeoBox], *, resampling: Union[str, int] = "nearest", - dst_nodata: Optional[float] = None, + dst_nodata: MaybeAutoNodata = "auto", dtype=None, **kw, ) -> xarray.Dataset: @@ -797,7 +799,7 @@ def _xr_reproject_da( how: Union[SomeCRS, GeoBox], *, resampling: Union[str, int] = "nearest", - dst_nodata: Optional[float] = None, + dst_nodata: MaybeAutoNodata = "auto", dtype=None, **kw, ) -> xarray.DataArray: @@ -828,9 +830,8 @@ def _xr_reproject_da( assert ydim + 1 == src.odc.xdim dst_shape = (*src.shape[:ydim], *dst_geobox.shape, *src.shape[ydim + 
2 :]) - src_nodata = kw.pop("src_nodata", None) - if src_nodata is None: - src_nodata = src.odc.nodata + src_nodata = resolve_nodata(kw.pop("src_nodata", "auto"), src.dtype, src.odc.nodata) + dst_nodata = resolve_nodata(dst_nodata, dtype, src_nodata) fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype) @@ -865,10 +866,9 @@ def _xr_reproject_da( ) attrs = {k: v for k, v in src.attrs.items() if k not in REPROJECT_SKIP_ATTRS} - if numpy.isfinite(fill_value) and ( - dst_nodata is not None or src_nodata is not None - ): - attrs.update({k: maybe_int(float(fill_value), 1e-6) for k in NODATA_ATTRIBUTES}) + if not is_nodata_empty(dst_nodata): + assert dst_nodata is not None + attrs.update({k: maybe_int(float(dst_nodata), 1e-6) for k in NODATA_ATTRIBUTES}) # new set of coords (replace x,y dims) # discard all coords that reference spatial dimensions @@ -997,7 +997,7 @@ def assign_crs( return assign_crs(self._xx, crs=crs, crs_coord_name=crs_coord_name) @property - def nodata(self) -> Optional[float]: + def nodata(self) -> Nodata: """Extract ``nodata/_FillValue`` attribute if set.""" attrs = self._xx.attrs for k in ["nodata", "_FillValue"]: @@ -1076,7 +1076,7 @@ def wrap_xr( gbox: SomeGeoBox, *, time=None, - nodata=None, + nodata: MaybeAutoNodata = "auto", crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME, always_yx: bool = False, dims: Optional[Tuple[str, ...]] = None, @@ -1159,8 +1159,9 @@ def _postfix_dims(n): [f"b{i}" for i in range(nb)], dims=(dim,), name=dim ) - if nodata is not None: - attrs = {"nodata": nodata, **attrs} + _nodata = resolve_nodata(nodata, im.dtype) + if not is_nodata_empty(_nodata) or nodata != "auto": + attrs = {"nodata": _nodata, **attrs} out = xarray.DataArray(im, coords=coords, dims=dims, attrs=attrs) if crs_coord_name is not None: diff --git a/odc/geo/cog/_rio.py b/odc/geo/cog/_rio.py index b8b6dd78..b91efecb 100644 --- a/odc/geo/cog/_rio.py +++ b/odc/geo/cog/_rio.py @@ -19,7 +19,8 @@ from rasterio.shutil import copy as rio_copy # 
pylint: disable=no-name-in-module from ..geobox import GeoBox -from ..types import MaybeNodata, SomeShape, shape_, wh_ +from ..math import resolve_nodata +from ..types import MaybeAutoNodata, Nodata, SomeShape, shape_, wh_ from ..warp import resampling_s2rio from ._shared import adjust_blocksize @@ -123,7 +124,7 @@ def _write_cog( pix: np.ndarray, geobox: GeoBox, fname: Union[Path, str], - nodata: MaybeNodata = None, + nodata: Nodata = None, overwrite: bool = False, blocksize: Optional[int] = None, overview_resampling: Optional[str] = None, @@ -283,6 +284,7 @@ def write_cog( use_windowed_writes: bool = False, intermediate_compression: Union[bool, str, Dict[str, Any]] = False, tags: Optional[Dict[str, Any]] = None, + nodata: MaybeAutoNodata = "auto", **extra_rio_opts, ) -> Union[Path, bytes]: """ @@ -298,7 +300,7 @@ def write_cog( :param overview_levels: List of shrink factors to compute overiews for: [2,4,8,16,32], to disable overviews supply empty list ``[]`` :param nodata: Set ``nodata`` flag to this value if supplied, by default ``nodata`` is - read from the attributes of the input array (``geo_im.attrs['nodata']``). + read from the attributes of the input array (``geo_im.odc.nodata``). :param use_windowed_writes: Write image block by block (might need this for large images) :param intermediate_compression: Configure compression settings for first pass write , default is no compression @@ -322,6 +324,8 @@ def write_cog( This means that this function will use about 1.5 to 2 times memory taken by ``geo_im``. 
""" + nodata = resolve_nodata(nodata, geo_im.dtype, geo_im.odc.nodata) + if overviews is not None: layers = [geo_im, *overviews] result = write_cog_layers( @@ -333,6 +337,7 @@ def write_cog( use_windowed_writes=use_windowed_writes, intermediate_compression=intermediate_compression, tags=tags, + nodata=nodata, **extra_rio_opts, ) assert result is not None @@ -340,9 +345,6 @@ def write_cog( pix = geo_im.data geobox = geo_im.odc.geobox - nodata = extra_rio_opts.pop("nodata", None) - if nodata is None: - nodata = geo_im.attrs.get("nodata", None) if geobox is None: raise ValueError("Need geo-registered array on input") @@ -448,6 +450,7 @@ def write_cog_layers( intermediate_compression: Union[bool, str, Dict[str, Any]] = False, use_windowed_writes: bool = False, tags: Optional[Dict[str, Any]] = None, + nodata: Nodata = None, **extra_rio_opts, ) -> Union[Path, bytes, None]: """ @@ -475,14 +478,14 @@ def write_cog_layers( blocksize=blocksize, shape=gbox.shape, is_float=pix.dtype.kind == "f", - nodata=pix.attrs.get("nodata", None), + nodata=nodata, ) rio_opts.update(extra_rio_opts) first_pass_cfg: Dict[str, Any] = { "num_threads": "ALL_CPUS", "blocksize": blocksize, - "nodata": rio_opts.get("nodata", None), + "nodata": nodata, "use_windowed_writes": use_windowed_writes, "gdal_metadata": _get_gdal_metadata(xx, tags), **_norm_compression_opts(intermediate_compression), diff --git a/odc/geo/cog/_shared.py b/odc/geo/cog/_shared.py index 7f6622dd..c495913f 100644 --- a/odc/geo/cog/_shared.py +++ b/odc/geo/cog/_shared.py @@ -12,7 +12,7 @@ from ..geobox import GeoBox from ..math import align_down_pow2, align_up -from ..types import MaybeNodata, Shape2d, SomeShape, shape_, wh_ +from ..types import Nodata, Shape2d, SomeShape, shape_, wh_ # pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements,too-many-instance-attributes @@ -63,7 +63,7 @@ class CogMeta: compressionargs: Dict[str, Any] = field(default_factory=dict, repr=False) gbox: Optional[GeoBox] 
= None overviews: Tuple["CogMeta", ...] = field(default=(), repr=False) - nodata: MaybeNodata = None + nodata: Nodata = None def _pix_shape(self, shape: Shape2d) -> Tuple[int, ...]: if self.axis == "YX": diff --git a/odc/geo/cog/_tifffile.py b/odc/geo/cog/_tifffile.py index c35aec51..8d2f00ff 100644 --- a/odc/geo/cog/_tifffile.py +++ b/odc/geo/cog/_tifffile.py @@ -18,7 +18,8 @@ from .._interop import have from ..geobox import GeoBox -from ..types import MaybeNodata, Shape2d, Unset, shape_ +from ..math import resolve_nodata +from ..types import MaybeAutoNodata, Shape2d, Unset, shape_ from ._mpu import mpu_write from ._mpu_fs import MPUFileSink from ._s3 import MultiPartUpload, s3_parse_url @@ -120,7 +121,7 @@ def _make_empty_cog( dtype: Any, gbox: Optional[GeoBox] = None, *, - nodata: MaybeNodata = None, + nodata: MaybeAutoNodata = "auto", gdal_metadata: Optional[str] = None, compression: Union[str, Unset] = Unset(), compressionargs: Any = None, @@ -140,6 +141,8 @@ def _make_empty_cog( enumarg, ) + nodata = resolve_nodata(nodata, dtype) + predictor, compression, compressionargs = _norm_compression_tifffile( dtype, predictor, @@ -773,7 +776,7 @@ def save_cog_with_dask( def geotiff_metadata( geobox: GeoBox, - nodata: MaybeNodata = None, + nodata: MaybeAutoNodata = "auto", gdal_metadata: Optional[str] = None, ) -> Tuple[List[Tuple[int, int, int, Any]], Dict[str, Any]]: """ diff --git a/odc/geo/math.py b/odc/geo/math.py index 60c04842..53d0d59d 100644 --- a/odc/geo/math.py +++ b/odc/geo/math.py @@ -29,6 +29,9 @@ from .types import ( XY, AnchorEnum, + FillValue, + MaybeAutoNodata, + Nodata, Resolution, SomeResolution, SomeShape, @@ -181,6 +184,49 @@ def is_almost_int(x: float, tol: float) -> bool: return x < tol +def resolve_fill_value(dst_nodata: Nodata, src_nodata: Nodata, dtype) -> FillValue: + dtype = np.dtype(dtype) + + if dst_nodata is not None: + return dtype.type(dst_nodata) + if np.issubdtype(dtype, np.floating): + return dtype.type("nan") + if src_nodata is not 
None: + return dtype.type(src_nodata) + return dtype.type(0) + + +def resolve_nodata( + nodata: MaybeAutoNodata, + dtype=None, + xr_nodata=None, +) -> Nodata: + # pylint: disable=too-many-return-statements + if nodata is None: + return None + + if nodata == "auto": + if xr_nodata is not None: + return xr_nodata + if dtype is None: + return None + if np.issubdtype(dtype, np.floating): + return np.nan + return None + + if isinstance(nodata, str): + return float(nodata) + return nodata + + +def is_nodata_empty(nodata: Nodata) -> bool: + if nodata is None: + return True + if isinstance(nodata, float) and np.isnan(nodata): + return True + return False + + def _snap_edge_pos(x0: float, x1: float, res: float, tol: float) -> Tuple[float, int]: assert res > 0 assert x1 >= x0 diff --git a/odc/geo/types.py b/odc/geo/types.py index 773f32a0..0cfe51d9 100644 --- a/odc/geo/types.py +++ b/odc/geo/types.py @@ -21,8 +21,10 @@ MaybeInt = Optional[int] MaybeFloat = Optional[float] -Nodata = Union[float, int, str] -MaybeNodata = Optional[Nodata] +FillValue = Union[float, int] +Nodata = Union[float, int, None] +MaybeNodata = Union[float, int, str, None] +MaybeAutoNodata = Union[float, int, str, None, Literal["auto"]] T = TypeVar("T") T1 = TypeVar("T1") T2 = TypeVar("T2") diff --git a/odc/geo/warp.py b/odc/geo/warp.py index b441d01b..91a13ec7 100644 --- a/odc/geo/warp.py +++ b/odc/geo/warp.py @@ -10,19 +10,18 @@ from .gcp import GCPGeoBox from .geobox import GeoBox -from .types import wh_ +from .math import resolve_fill_value, resolve_nodata +from .types import MaybeAutoNodata, Nodata, wh_ # pylint: disable=invalid-name, too-many-arguments Resampling = Union[str, int, rasterio.warp.Resampling] -Nodata = Optional[Union[int, float]] _WRP_CRS = "epsg:3857" __all__ = [ "resampling_s2rio", "is_resampling_nn", "resolve_fill_value", - "warp_affine", - "warp_affine_rio", + "rio_warp_affine", "rio_reproject", ] @@ -47,25 +46,13 @@ def is_resampling_nn(resampling: Resampling) -> bool: return 
resampling == rasterio.warp.Resampling.nearest -def resolve_fill_value(dst_nodata, src_nodata, dtype): - dtype = np.dtype(dtype) - - if dst_nodata is not None: - return dtype.type(dst_nodata) - if np.issubdtype(dtype, np.floating): - return dtype.type("nan") - if src_nodata is not None: - return dtype.type(src_nodata) - return dtype.type(0) - - -def warp_affine_rio( +def rio_warp_affine( src: np.ndarray, dst: np.ndarray, A: Affine, resampling: Resampling, - src_nodata: Nodata = None, - dst_nodata: Nodata = None, + src_nodata: MaybeAutoNodata = "auto", + dst_nodata: MaybeAutoNodata = "auto", **kwargs, ) -> np.ndarray: """ @@ -89,36 +76,13 @@ def warp_affine_rio( s_gbox = GeoBox(wh_(sw, sh), Affine.identity(), _WRP_CRS) d_gbox = GeoBox(wh_(dw, dh), A, _WRP_CRS) - return _rio_reproject( - src, dst, s_gbox, d_gbox, resampling, src_nodata, dst_nodata, **kwargs - ) - - -def warp_affine( - src: np.ndarray, - dst: np.ndarray, - A: Affine, - resampling: Resampling, - src_nodata: Nodata = None, - dst_nodata: Nodata = None, - **kwargs, -) -> np.ndarray: - """ - Perform Affine warp using best available backend (GDAL via rasterio is the only one so far). 
- :param src: image as ndarray - :param dst: image as ndarray - :param A: Affine transformm, maps from dst_coords to src_coords - :param resampling: str resampling strategy - :param src_nodata: Value representing "no data" in the source image - :param dst_nodata: Value to represent "no data" in the destination image - - :param kwargs: any other args to pass to implementation + src_nodata = resolve_nodata(src_nodata, src.dtype) + dst_nodata = resolve_nodata(dst_nodata, dst.dtype) + fill_value = resolve_fill_value(dst_nodata, src_nodata, dst.dtype) - :returns: dst - """ - return warp_affine_rio( - src, dst, A, resampling, src_nodata=src_nodata, dst_nodata=dst_nodata, **kwargs + return _rio_reproject( + src, dst, s_gbox, d_gbox, resampling, src_nodata, fill_value, **kwargs ) @@ -128,8 +92,8 @@ def rio_reproject( s_gbox: Union[GeoBox, GCPGeoBox], d_gbox: GeoBox, resampling: Resampling, - src_nodata: Nodata = None, - dst_nodata: Nodata = None, + src_nodata: MaybeAutoNodata = "auto", + dst_nodata: MaybeAutoNodata = "auto", ydim: Optional[int] = None, **kwargs, ) -> np.ndarray: @@ -151,13 +115,22 @@ def rio_reproject( """ assert src.ndim == dst.ndim + src_nodata = resolve_nodata(src_nodata, src.dtype) + dst_nodata = resolve_nodata(dst_nodata, dst.dtype) + fill_value = resolve_fill_value(dst_nodata, src_nodata, dst.dtype) + if src.ndim == 2: return _rio_reproject( - src, dst, s_gbox, d_gbox, resampling, src_nodata, dst_nodata, **kwargs + src, + dst, + s_gbox, + d_gbox, + resampling=resampling, + src_nodata=src_nodata, + dst_nodata=fill_value, + **kwargs, ) - fill_value = resolve_fill_value(dst_nodata, src_nodata, dst.dtype) - if ydim is None: # Assume last two dimensions are Y/X ydim = src.ndim - 2 diff --git a/tests/conftest.py b/tests/conftest.py index aecb750d..0e623893 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from pathlib import Path +import numpy as np import pytest import xarray as xr @@ -7,6 +8,8 @@ from odc.geo.geobox import GeoBox 
from odc.geo.xr import rasterize +# pylint: disable=protected-access,import-outside-toplevel,redefined-outer-name + @pytest.fixture(scope="session") def data_dir(): @@ -45,4 +48,18 @@ def crs(): @pytest.fixture() def country(iso3, crs): - return country_geom(iso3, crs=crs) + yield country_geom(iso3, crs=crs) + + +@pytest.fixture() +def country_raster(country, resolution): + geobox = GeoBox.from_geopolygon(country, resolution=resolution, tight=True) + yield rasterize(country, geobox) + + +@pytest.fixture() +def country_raster_f32(country, resolution): + geobox = GeoBox.from_geopolygon(country, resolution=resolution, tight=True) + xx = rasterize(country, geobox) + xx = xr.where(xx, np.random.uniform(0, 100, xx.shape).astype("float32"), 0) + yield xx diff --git a/tests/test_cog.py b/tests/test_cog.py index 4f2a23e9..1d0991c3 100644 --- a/tests/test_cog.py +++ b/tests/test_cog.py @@ -345,7 +345,7 @@ def test_norm_compress(): _gbox.center_pixel.pad(3), ], ) -@pytest.mark.parametrize("nodata", [None, float("nan"), 0, -999]) +@pytest.mark.parametrize("nodata", ["auto", None, float("nan"), 0, -999]) @pytest.mark.parametrize("gdal_metadata", [None, ""]) def test_geotiff_metadata(gbox: GeoBox, nodata, gdal_metadata: Optional[str]): assert gbox.crs is not None diff --git a/tests/test_map.py b/tests/test_map.py index 57e8e9df..8b2ad9e5 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -5,6 +5,8 @@ from odc.geo._interop import have from odc.geo.xr import ODCExtensionDa +# pylint: disable=protected-access,import-outside-toplevel + cmap = np.asarray( [ [153, 153, 102, 255], diff --git a/tests/test_math.py b/tests/test_math.py index 7e5829c4..15d958d3 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -2,6 +2,7 @@ from typing import Tuple import numpy as np +import numpy.testing as npt import pytest from affine import Affine @@ -17,10 +18,13 @@ apply_affine, data_resolution_and_offset, is_almost_int, + is_nodata_empty, maybe_int, maybe_zero, quasi_random_r2, 
resolution_from_affine, + resolve_fill_value, + resolve_nodata, snap_affine, snap_grid, snap_scale, @@ -29,6 +33,8 @@ ) from odc.geo.testutils import mkA +NaN = float("nan") + def test_math_ops(): assert align_up(32, 16) == 32 @@ -391,3 +397,49 @@ def test_align_down_pow2(x: int): assert isinstance(y, int) assert y <= x assert 2 ** int(math.log2(y)) == y + + +@pytest.mark.parametrize( + "nodata, dtype, expect", + [ + (None, None, None), + (None, "float32", None), + ("auto", None, None), + ("auto", "float32", NaN), + ("auto", np.dtype("float64"), NaN), + ("auto", "uint32", None), + ("auto", "int16", None), + ("auto", "uint8", None), + ("auto", "bool", None), + ], +) +def test_resolve_nodata(nodata, dtype, expect): + npt.assert_equal(resolve_nodata(nodata, dtype), expect) + assert resolve_nodata("auto", dtype, 13) == 13 + + +@pytest.mark.parametrize( + "dst_nodata, src_nodata, dtype, expect", + [ + (None, None, "uint16", 0), + (None, 3, "uint16", 3), + (4, 3, "uint16", 4), + (10, 0, "uint16", 10), + (NaN, 0, "float32", NaN), + (None, 0, "float32", NaN), + (None, None, "bool", False), + (None, None, "uint8", 0), + ], +) +def test_resolve_fill_value(dst_nodata, src_nodata, dtype, expect): + npt.assert_equal(resolve_fill_value(dst_nodata, src_nodata, dtype), expect) + + +def test_empty(): + assert is_nodata_empty(None) is True + assert is_nodata_empty(NaN) is True + assert is_nodata_empty(np.nan) is True + assert is_nodata_empty(0) is False + assert is_nodata_empty(-1) is False + assert is_nodata_empty(-1.0) is False + assert is_nodata_empty(np.uint8(0)) is False diff --git a/tests/test_warp.py b/tests/test_warp.py new file mode 100644 index 00000000..4d5ee5ed --- /dev/null +++ b/tests/test_warp.py @@ -0,0 +1,68 @@ +import numpy as np +import numpy.testing as npt +import pytest +import xarray as xr +from affine import Affine + +from odc.geo import MaybeCRS +from odc.geo.warp import resampling_s2rio, rio_reproject, rio_warp_affine + +NaN = float("nan") + + 
+@pytest.mark.parametrize( + "iso3, crs, resolution", + [ + ("AUS", "epsg:4326", 0.1), + ("AUS", "epsg:3577", 10_000), + ("AUS", "epsg:3857", 10_000), + ("NZL", "epsg:3857", 5_000), + ], +) +@pytest.mark.parametrize("resampling", ["nearest", "bilinear", "average", "sum"]) +def test_warp_nan(country_raster_f32: xr.DataArray, crs: MaybeCRS, resampling: str): + xx = country_raster_f32 + assert isinstance(xx, xr.DataArray) + assert xx.odc.crs == crs + assert xx.odc.nodata is None + assert xx.dtype == "float32" + + mid = xx.shape[0] // 2 + xx.data[mid, :] = NaN + xx.data[:, -10] = NaN + + assert resampling_s2rio(resampling) is not None + assert np.isnan(xx.data).sum() > 0 + + src_gbox = xx.odc.geobox + dst_gbox = src_gbox.zoom_to(shape=100).pad(10) + + yy1 = np.full(dst_gbox.shape, -333, dtype=xx.dtype) + yy2 = np.full(dst_gbox.shape, -333, dtype=xx.dtype) + + assert rio_reproject(xx.data, yy1, src_gbox, dst_gbox, resampling=resampling) is yy1 + assert ( + rio_reproject( + xx.data, + yy2, + src_gbox, + dst_gbox, + resampling=resampling, + src_nodata=NaN, + dst_nodata=NaN, + ) + is yy2 + ) + + npt.assert_array_equal(yy1, yy2) + + # make sure all pixels were replaced + assert (yy1 == -333).sum() == 0 + + # expect to see NaNs in the output + assert np.isnan(yy2).sum() > 0 + + A = Affine.identity() + xx_ = xx.data.copy() * 0 + assert rio_warp_affine(xx.data, xx_, A, resampling) is xx_ + npt.assert_array_equal(xx.data, xx_) diff --git a/tests/test_xr_interop.py b/tests/test_xr_interop.py index 1b8272b0..24601809 100644 --- a/tests/test_xr_interop.py +++ b/tests/test_xr_interop.py @@ -32,6 +32,8 @@ GeoBox.from_bbox((-10, -2, 5, 4), "epsg:3857", tight=True, resolution=resxy_(1, 2)), ] +NaN = float("nan") + @pytest.fixture def geobox_epsg4326(): @@ -361,7 +363,7 @@ def test_wrap_xr(): assert wrap_xr(data, gbox, always_yx=True).dims == ("y", "x") assert wrap_xr(data, gbox, dims=("Y", "X")).dims == ("Y", "X") - xx = wrap_xr(data, gbox, nodata=None) + xx = wrap_xr(data, gbox) 
assert xx.attrs == {} xx = wrap_xr(data, gbox, nodata=10, some_flag=3) @@ -385,6 +387,9 @@ def test_wrap_xr(): assert xx.shape == (*gbox.shape, 1) assert xx.band.data.tolist() == ["b0"] + assert wrap_xr(data.astype("float32"), gbox).odc.nodata is None + assert np.isnan(wrap_xr(data.astype("float32"), gbox, nodata=NaN).odc.nodata) + @pytest.mark.parametrize("gbox", TEST_GEOBOXES_SMALL_AXIS_ALIGNED) @pytest.mark.parametrize("nprefix", [0, 1, 2])