Skip to content

Commit

Permalink
refactor: nodata handling #162
Browse files Browse the repository at this point in the history
- adding new types
  - `MaybeAutoNodata` -- `None|int|float|str|"auto"`
  - `Nodata` is now `None|float|int`
  - `MaybeNodata` is now `None|float|int|str`
  - "auto" replaces what used to be `None`
  - `None` now means "no nodata value"
- `resolve_nodata()` is used to handle nodata
   options consistently across the library
- default nodata for float is `nan`
- fixes in reproject/overview generation for float
  data. It is assumed to have `nan` values, and
  GDAL needs `nodata=nan` to handle it correctly.
  • Loading branch information
Kirill888 committed Jun 18, 2024
1 parent 04d2f11 commit fc2ea18
Show file tree
Hide file tree
Showing 15 changed files with 279 additions and 99 deletions.
10 changes: 3 additions & 7 deletions odc/geo/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
from ._blocks import BlockAssembler
from .gcp import GCPGeoBox
from .geobox import GeoBox, GeoboxTiles
from .warp import (
Nodata,
Resampling,
_rio_reproject,
resampling_s2rio,
resolve_fill_value,
)
from .math import resolve_fill_value
from .types import Nodata
from .warp import Resampling, _rio_reproject, resampling_s2rio


def _do_chunked_reproject(
Expand Down
24 changes: 18 additions & 6 deletions odc/geo/_rgba.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import xarray as xr

from ._interop import is_dask_collection
from .types import Nodata

# pylint: disable=import-outside-toplevel

Expand Down Expand Up @@ -59,15 +60,15 @@ def _np_to_rgba(
r: np.ndarray,
g: np.ndarray,
b: np.ndarray,
nodata: Optional[float],
nodata: Nodata,
vmin: float,
vmax: float,
) -> np.ndarray:
rgba = np.zeros((*r.shape, 4), dtype="uint8")

if r.dtype.kind == "f":
valid = ~np.isnan(r)
if nodata is not None:
if nodata is not None and not np.isnan(nodata):
valid = valid * (r != nodata)
elif nodata is not None:
valid = r != nodata
Expand Down Expand Up @@ -130,7 +131,7 @@ def to_rgba(
assert vmax is not None

_b = ds[bands[0]]
nodata = getattr(_b, "nodata", None)
nodata = _b.odc.nodata
dims = (*_b.dims, "band")

r, g, b = (ds[name].data for name in bands)
Expand Down Expand Up @@ -171,12 +172,20 @@ def _np_colorize(x, cmap, clip):
return cmap[x]


def _matplotlib_colorize(x, cmap, vmin=None, vmax=None, nodata=None, robust=False):
def _matplotlib_colorize(
x,
cmap,
vmin=None,
vmax=None,
nodata: Nodata = None,
robust=False,
):
from matplotlib import colormaps
from matplotlib.colors import Normalize

if cmap is None or isinstance(cmap, str):
cmap = colormaps.get_cmap(cmap)
# None is a valid input, maps to default cmap
cmap = colormaps.get_cmap(cmap) # type: ignore

if nodata is not None:
x = np.where(x == nodata, np.float32("nan"), x)
Expand Down Expand Up @@ -234,8 +243,11 @@ def colorize(
:param clip: If ``True`` clip values from ``x`` to be in the safe range for ``cmap``.
"""
# pylint: disable=too-many-locals
from ._xr_interop import ODCExtensionDa

assert isinstance(x, xr.DataArray)
assert isinstance(x.odc, ODCExtensionDa)

_is_dask = is_dask_collection(x.data)

if isinstance(cmap, np.ndarray):
Expand Down Expand Up @@ -263,7 +275,7 @@ def colorize(
_matplotlib_colorize,
vmin=vmin,
vmax=vmax,
nodata=getattr(x, "nodata", None),
nodata=x.odc.nodata,
robust=robust,
)
nc, cmap_dtype = 4, "uint8"
Expand Down
33 changes: 17 additions & 16 deletions odc/geo/_xr_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
affine_from_axis,
approx_equal_affine,
is_affine_st,
is_nodata_empty,
maybe_int,
resolution_from_affine,
resolve_fill_value,
resolve_nodata,
)
from .overlap import compute_output_geobox
from .roi import roi_is_empty
from .types import Resolution, SomeResolution, SomeShape, xy_
from .warp import resolve_fill_value
from .types import MaybeAutoNodata, Nodata, Resolution, SomeResolution, SomeShape, xy_

# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-lines
Expand Down Expand Up @@ -656,7 +658,7 @@ def xr_reproject(
how: Union[SomeCRS, GeoBox],
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dst_nodata: MaybeAutoNodata = "auto",
dtype=None,
resolution: Union[SomeResolution, Literal["auto", "fit", "same"]] = "auto",
shape: Union[SomeShape, int, None] = None,
Expand Down Expand Up @@ -753,7 +755,7 @@ def _xr_reproject_ds(
how: Union[SomeCRS, GeoBox],
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dst_nodata: MaybeAutoNodata = "auto",
dtype=None,
**kw,
) -> xarray.Dataset:
Expand Down Expand Up @@ -797,7 +799,7 @@ def _xr_reproject_da(
how: Union[SomeCRS, GeoBox],
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dst_nodata: MaybeAutoNodata = "auto",
dtype=None,
**kw,
) -> xarray.DataArray:
Expand Down Expand Up @@ -828,9 +830,8 @@ def _xr_reproject_da(
assert ydim + 1 == src.odc.xdim
dst_shape = (*src.shape[:ydim], *dst_geobox.shape, *src.shape[ydim + 2 :])

src_nodata = kw.pop("src_nodata", None)
if src_nodata is None:
src_nodata = src.odc.nodata
src_nodata = resolve_nodata(kw.pop("src_nodata", "auto"), src.dtype, src.odc.nodata)
dst_nodata = resolve_nodata(dst_nodata, dtype, src_nodata)

fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype)

Expand Down Expand Up @@ -865,10 +866,9 @@ def _xr_reproject_da(
)

attrs = {k: v for k, v in src.attrs.items() if k not in REPROJECT_SKIP_ATTRS}
if numpy.isfinite(fill_value) and (
dst_nodata is not None or src_nodata is not None
):
attrs.update({k: maybe_int(float(fill_value), 1e-6) for k in NODATA_ATTRIBUTES})
if not is_nodata_empty(dst_nodata):
assert dst_nodata is not None
attrs.update({k: maybe_int(float(dst_nodata), 1e-6) for k in NODATA_ATTRIBUTES})

# new set of coords (replace x,y dims)
# discard all coords that reference spatial dimensions
Expand Down Expand Up @@ -997,7 +997,7 @@ def assign_crs(
return assign_crs(self._xx, crs=crs, crs_coord_name=crs_coord_name)

@property
def nodata(self) -> Optional[float]:
def nodata(self) -> Nodata:
"""Extract ``nodata/_FillValue`` attribute if set."""
attrs = self._xx.attrs
for k in ["nodata", "_FillValue"]:
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def wrap_xr(
gbox: SomeGeoBox,
*,
time=None,
nodata=None,
nodata: MaybeAutoNodata = "auto",
crs_coord_name: Optional[str] = _DEFAULT_CRS_COORD_NAME,
always_yx: bool = False,
dims: Optional[Tuple[str, ...]] = None,
Expand Down Expand Up @@ -1159,8 +1159,9 @@ def _postfix_dims(n):
[f"b{i}" for i in range(nb)], dims=(dim,), name=dim
)

if nodata is not None:
attrs = {"nodata": nodata, **attrs}
_nodata = resolve_nodata(nodata, im.dtype)
if not is_nodata_empty(_nodata) or nodata != "auto":
attrs = {"nodata": _nodata, **attrs}

out = xarray.DataArray(im, coords=coords, dims=dims, attrs=attrs)
if crs_coord_name is not None:
Expand Down
19 changes: 11 additions & 8 deletions odc/geo/cog/_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from rasterio.shutil import copy as rio_copy # pylint: disable=no-name-in-module

from ..geobox import GeoBox
from ..types import MaybeNodata, SomeShape, shape_, wh_
from ..math import resolve_nodata
from ..types import MaybeAutoNodata, Nodata, SomeShape, shape_, wh_
from ..warp import resampling_s2rio
from ._shared import adjust_blocksize

Expand Down Expand Up @@ -123,7 +124,7 @@ def _write_cog(
pix: np.ndarray,
geobox: GeoBox,
fname: Union[Path, str],
nodata: MaybeNodata = None,
nodata: Nodata = None,
overwrite: bool = False,
blocksize: Optional[int] = None,
overview_resampling: Optional[str] = None,
Expand Down Expand Up @@ -283,6 +284,7 @@ def write_cog(
use_windowed_writes: bool = False,
intermediate_compression: Union[bool, str, Dict[str, Any]] = False,
tags: Optional[Dict[str, Any]] = None,
nodata: MaybeAutoNodata = "auto",
**extra_rio_opts,
) -> Union[Path, bytes]:
"""
Expand All @@ -298,7 +300,7 @@ def write_cog(
:param overview_levels: List of shrink factors to compute overiews for: [2,4,8,16,32],
to disable overviews supply empty list ``[]``
:param nodata: Set ``nodata`` flag to this value if supplied, by default ``nodata`` is
read from the attributes of the input array (``geo_im.attrs['nodata']``).
read from the attributes of the input array (``geo_im.odc.nodata``).
:param use_windowed_writes: Write image block by block (might need this for large images)
:param intermediate_compression: Configure compression settings for first pass write
, default is no compression
Expand All @@ -322,6 +324,8 @@ def write_cog(
This means that this function will use about 1.5 to 2 times memory taken by ``geo_im``.
"""
nodata = resolve_nodata(nodata, geo_im.dtype, geo_im.odc.nodata)

if overviews is not None:
layers = [geo_im, *overviews]
result = write_cog_layers(
Expand All @@ -333,16 +337,14 @@ def write_cog(
use_windowed_writes=use_windowed_writes,
intermediate_compression=intermediate_compression,
tags=tags,
nodata=nodata,
**extra_rio_opts,
)
assert result is not None
return result

pix = geo_im.data
geobox = geo_im.odc.geobox
nodata = extra_rio_opts.pop("nodata", None)
if nodata is None:
nodata = geo_im.attrs.get("nodata", None)

if geobox is None:
raise ValueError("Need geo-registered array on input")
Expand Down Expand Up @@ -448,6 +450,7 @@ def write_cog_layers(
intermediate_compression: Union[bool, str, Dict[str, Any]] = False,
use_windowed_writes: bool = False,
tags: Optional[Dict[str, Any]] = None,
nodata: Nodata = None,
**extra_rio_opts,
) -> Union[Path, bytes, None]:
"""
Expand Down Expand Up @@ -475,14 +478,14 @@ def write_cog_layers(
blocksize=blocksize,
shape=gbox.shape,
is_float=pix.dtype.kind == "f",
nodata=pix.attrs.get("nodata", None),
nodata=nodata,
)
rio_opts.update(extra_rio_opts)

first_pass_cfg: Dict[str, Any] = {
"num_threads": "ALL_CPUS",
"blocksize": blocksize,
"nodata": rio_opts.get("nodata", None),
"nodata": nodata,
"use_windowed_writes": use_windowed_writes,
"gdal_metadata": _get_gdal_metadata(xx, tags),
**_norm_compression_opts(intermediate_compression),
Expand Down
4 changes: 2 additions & 2 deletions odc/geo/cog/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ..geobox import GeoBox
from ..math import align_down_pow2, align_up
from ..types import MaybeNodata, Shape2d, SomeShape, shape_, wh_
from ..types import Nodata, Shape2d, SomeShape, shape_, wh_

# pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements,too-many-instance-attributes

Expand Down Expand Up @@ -63,7 +63,7 @@ class CogMeta:
compressionargs: Dict[str, Any] = field(default_factory=dict, repr=False)
gbox: Optional[GeoBox] = None
overviews: Tuple["CogMeta", ...] = field(default=(), repr=False)
nodata: MaybeNodata = None
nodata: Nodata = None

def _pix_shape(self, shape: Shape2d) -> Tuple[int, ...]:
if self.axis == "YX":
Expand Down
9 changes: 6 additions & 3 deletions odc/geo/cog/_tifffile.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from .._interop import have
from ..geobox import GeoBox
from ..types import MaybeNodata, Shape2d, Unset, shape_
from ..math import resolve_nodata
from ..types import MaybeAutoNodata, Shape2d, Unset, shape_
from ._mpu import mpu_write
from ._mpu_fs import MPUFileSink
from ._s3 import MultiPartUpload, s3_parse_url
Expand Down Expand Up @@ -120,7 +121,7 @@ def _make_empty_cog(
dtype: Any,
gbox: Optional[GeoBox] = None,
*,
nodata: MaybeNodata = None,
nodata: MaybeAutoNodata = "auto",
gdal_metadata: Optional[str] = None,
compression: Union[str, Unset] = Unset(),
compressionargs: Any = None,
Expand All @@ -140,6 +141,8 @@ def _make_empty_cog(
enumarg,
)

nodata = resolve_nodata(nodata, dtype)

predictor, compression, compressionargs = _norm_compression_tifffile(
dtype,
predictor,
Expand Down Expand Up @@ -773,7 +776,7 @@ def save_cog_with_dask(

def geotiff_metadata(
geobox: GeoBox,
nodata: MaybeNodata = None,
nodata: MaybeAutoNodata = "auto",
gdal_metadata: Optional[str] = None,
) -> Tuple[List[Tuple[int, int, int, Any]], Dict[str, Any]]:
"""
Expand Down
46 changes: 46 additions & 0 deletions odc/geo/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from .types import (
XY,
AnchorEnum,
FillValue,
MaybeAutoNodata,
Nodata,
Resolution,
SomeResolution,
SomeShape,
Expand Down Expand Up @@ -181,6 +184,49 @@ def is_almost_int(x: float, tol: float) -> bool:
return x < tol


def resolve_fill_value(dst_nodata: Nodata, src_nodata: Nodata, dtype) -> FillValue:
    """
    Pick the value used to pre-fill an output array of the given ``dtype``.

    Candidates are tried in order, the first applicable one wins:

    1. ``dst_nodata`` when it is not ``None``
    2. ``nan`` when ``dtype`` is a floating point type
    3. ``src_nodata`` when it is not ``None``
    4. ``0``

    :param dst_nodata: Nodata value requested for the destination, or ``None``.
    :param src_nodata: Nodata value of the source, or ``None``.
    :param dtype: Anything accepted by :py:class:`numpy.dtype`.
    :return: The chosen value converted to a numpy scalar of ``dtype``.
    """
    np_dtype = np.dtype(dtype)
    as_scalar = np_dtype.type

    if dst_nodata is not None:
        return as_scalar(dst_nodata)

    # floating point outputs default to nan even when source nodata is known
    if np.issubdtype(np_dtype, np.floating):
        return as_scalar("nan")

    return as_scalar(0 if src_nodata is None else src_nodata)


def resolve_nodata(
    nodata: MaybeAutoNodata,
    dtype=None,
    xr_nodata=None,
) -> Nodata:
    """
    Turn a user supplied ``nodata`` option into a concrete nodata value.

    - ``None`` means "no nodata value" and is returned unchanged.
    - ``"auto"`` resolves to ``xr_nodata`` when that is not ``None``, otherwise
      to ``nan`` for floating point ``dtype``, otherwise to ``None``.
    - Any other string is parsed with :py:func:`float`.
    - Any other value is returned as is.

    :param nodata: Requested nodata: ``None``, a number, a string or ``"auto"``.
    :param dtype: Data type of the pixels, only consulted for ``"auto"``.
    :param xr_nodata: Nodata value already attached to the data, only consulted
                      for ``"auto"``.
    :return: Resolved nodata value or ``None``.
    """
    if nodata is None:
        return None

    if nodata == "auto":
        # value recorded on the data takes priority over dtype based default
        if xr_nodata is not None:
            return xr_nodata
        floating = dtype is not None and np.issubdtype(dtype, np.floating)
        return np.nan if floating else None

    return float(nodata) if isinstance(nodata, str) else nodata


def is_nodata_empty(nodata: Nodata) -> bool:
    """
    Check whether ``nodata`` carries no usable marker value.

    :param nodata: Resolved nodata value or ``None``.
    :return: ``True`` when ``nodata`` is ``None`` or a floating point ``nan``,
             ``False`` for every other value.
    """
    if nodata is None:
        return True
    # ``np.float32``/``np.float16`` do not derive from python ``float``, so
    # check ``np.floating`` too, otherwise their nan would look like real nodata.
    if isinstance(nodata, (float, np.floating)):
        return bool(np.isnan(nodata))
    return False


def _snap_edge_pos(x0: float, x1: float, res: float, tol: float) -> Tuple[float, int]:
assert res > 0
assert x1 >= x0
Expand Down
Loading

0 comments on commit fc2ea18

Please sign in to comment.