Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
fbunt committed Feb 10, 2024
2 parents fe19886 + 01ee4ac commit 6f8bba2
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 91 deletions.
1 change: 1 addition & 0 deletions docs/reference/raster.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ Null Data and Remapping
Raster.replace_null
Raster.remap_range
Raster.set_null_value
Raster.set_null
Raster.to_null_mask
Raster.where

Expand Down
98 changes: 61 additions & 37 deletions raster_tools/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,10 +949,10 @@ def band_concat(rasters):


@nb.jit(nopython=True, nogil=True)
def _remap_values(x, mask, mappings, inclusivity):
def _remap_values(x, mask, start_end_values, new_values, inclusivity):
outx = np.zeros_like(x)
bands, rows, columns = x.shape
rngs = mappings.shape[0]
rngs = start_end_values.shape[0]
for bnd in range(bands):
for rw in range(rows):
for cl in range(columns):
Expand All @@ -961,7 +961,8 @@ def _remap_values(x, mask, mappings, inclusivity):
vl = x[bnd, rw, cl]
remap = False
for imap in range(rngs):
left, right, new = mappings[imap]
left, right = start_end_values[imap]
new = new_values[imap]
if inclusivity == 0:
remap = left <= vl < right
elif inclusivity == 1:
Expand All @@ -986,7 +987,7 @@ def _normalize_mappings(mappings):
)
if not len(mappings):
raise ValueError("No mappings provided")
if len(mappings) and is_scalar(mappings[0]):
if is_scalar(mappings[0]):
mappings = [mappings]
try:
mappings = [list(m) for m in mappings]
Expand All @@ -995,22 +996,33 @@ def _normalize_mappings(mappings):
"Mappings must be either single 3-tuple or list of 3-tuples of "
"scalars"
) from None
starts = []
ends = []
news = []
for m in mappings:
if len(m) != 3:
raise ValueError(
"Mappings must be either single 3-tuple or list of 3-tuples of"
" scalars"
)
if not all(is_scalar(mi) for mi in m):
raise TypeError("Mappings values must be scalars")
if not all(is_scalar(mi) for mi in m[:2]):
raise TypeError("Mapping min and max values must be scalars")
if any(np.isnan(mi) for mi in m[:2]):
raise ValueError("Mapping min and max values cannot be NaN")
if m[0] >= m[1]:
raise ValueError(
"Mapping min value must be strictly less than max value:"
f" {m[0]}, {m[1]}"
)
return mappings
if not is_scalar(m[2]) and m[2] is not None:
raise ValueError(
"The new value in a mapping must be a scalar or None. "
f"Got {m[2]!r}."
)
starts.append(m[0])
ends.append(m[1])
news.append(m[2])
return starts, ends, news


def remap_range(raster, mapping, inclusivity="left"):
Expand All @@ -1026,7 +1038,8 @@ def remap_range(raster, mapping, inclusivity="left"):
mapping : 3-tuple of scalars or list of 3-tuples of scalars
A tuple or list of tuples containing ``(min, max, new_value)``
scalars. The mapping(s) map values between the min and max to the
``new_value``. If `mapping` is a list and there are mappings that
``new_value``. If ``new_value`` is ``None``, the matching pixels will
be marked as null. If `mapping` is a list and there are mappings that
conflict or overlap, earlier mappings take precedence. `inclusivity`
determines which sides of the range are inclusive and exclusive.
inclusivity : str, optional
Expand All @@ -1049,19 +1062,29 @@ def remap_range(raster, mapping, inclusivity="left"):
"""
raster = get_raster(raster)
mappings = _normalize_mappings(mapping)
map_starts, map_ends, map_news = _normalize_mappings(mapping)
if not is_str(inclusivity):
raise TypeError(
f"inclusivity must be a str. Got type: {type(inclusivity)}"
)
inc_map = dict(zip(("left", "right", "both", "none"), range(4)))
if inclusivity not in inc_map:
raise ValueError(f"Invalid inclusivity value. Got: {inclusivity!r}")
mappings_common_dtype = get_common_dtype([m[-1] for m in mappings])
if not all(m is None for m in map_news):
mappings_common_dtype = get_common_dtype(
[m for m in map_news if m is not None]
)
else:
mappings_common_dtype = raster.dtype
out_dtype = np.promote_types(raster.dtype, mappings_common_dtype)
# numba doesn't understand f16 so use f32 and then downcast
f16_workaround = out_dtype == F16
mappings = np.atleast_2d(mappings)
if raster._masked:
nv = raster.null_value
else:
nv = get_default_null_value(out_dtype)
start_end_values = np.array([[s, e] for s, e in zip(map_starts, map_ends)])
new_values = np.array([v if v is not None else nv for v in map_news])

outrs = raster.copy()
if out_dtype != outrs.dtype:
Expand All @@ -1072,15 +1095,16 @@ def remap_range(raster, mapping, inclusivity="left"):
elif f16_workaround:
outrs = outrs.astype(F32)
data = outrs.data
func = partial(
_remap_values, mappings=mappings, inclusivity=inc_map[inclusivity]
)
outrs.xdata.data = data.map_blocks(
func,
_remap_values,
raster.mask,
start_end_values=start_end_values,
new_values=new_values,
inclusivity=inc_map[inclusivity],
dtype=data.dtype,
meta=np.array((), dtype=data.dtype),
)
outrs.xmask.data = outrs.data == nv
if f16_workaround:
outrs = outrs.astype(F16)
return outrs
Expand Down Expand Up @@ -1209,7 +1233,8 @@ class RemapFileParseError(Exception):
r"(?:[-+]?(?:(?:0|[1-9]\d*)(?:\.\d*)?|\.\d+)?)(?:[eE][-+]?\d+)?"
)
_REMAPPING_LINE_PATTERN = re.compile(
rf"^\s*(?P<from>{_INT_OR_FLOAT_RE})\s*:\s*(?P<to>{_INT_OR_FLOAT_RE})\s*$"
rf"^\s*(?P<from>{_INT_OR_FLOAT_RE})\s*:"
rf"\s*(?P<to>{_INT_OR_FLOAT_RE}|NoData)\s*$"
)


Expand All @@ -1235,7 +1260,7 @@ def _parse_ascii_remap_file(fd):
from_str = m.group("from")
to_str = m.group("to")
k = _str_to_float_or_int(from_str)
v = _str_to_float_or_int(to_str)
v = _str_to_float_or_int(to_str) if to_str != "NoData" else None
if k in mapping:
raise ValueError(f"Found duplicate mapping: '{k}:{v}'.")
mapping[k] = v
Expand Down Expand Up @@ -1267,12 +1292,14 @@ def reclassify(raster, remapping, unmapped_to_null=False):
The input raster to reclassify. Can be a path string or Raster object.
remapping : str, dict
Can be either a ``dict`` or a path string. If a ``dict`` is provided,
the keys will be reclassified to the corresponding values. If a path
string, it is treated as an ASCII remap file where each line looks like
``a:b`` and ``a`` and ``b`` are integers. The output values of the
mapping can cause type promotion. If the input raster has integer data
and one of the outputs in the mapping is a float, the result will be a
float raster.
the keys will be reclassified to the corresponding values. It is
possible to map values to the null value by providing ``None`` in the
mapping. If a path string, it is treated as an ASCII remap file where
each line looks like ``a:b`` and ``a`` and ``b`` are scalars. ``b`` can
also be "NoData". This indicates that ``a`` will be mapped to the null
value. The output values of the mapping can cause type promotion. If
the input raster has integer data and one of the outputs in the mapping
is a float, the result will be a float raster.
unmapped_to_null : bool, optional
If ``True``, values not included in the mapping are instead mapped to
the null value. Default is ``False``.
Expand All @@ -1287,21 +1314,18 @@ def reclassify(raster, remapping, unmapped_to_null=False):
remapping = _get_remapping(remapping)

out_dtype = raster.dtype
mapping_out_values = list(remapping.values())
out_dtype = np.promote_types(
get_common_dtype(mapping_out_values), out_dtype
)
if unmapped_to_null:
if raster._masked:
nv = raster.null_value
else:
nv = get_default_null_value(out_dtype)
mapping_out_values = [v for v in remapping.values() if v is not None]
if len(mapping_out_values):
out_dtype = np.promote_types(
get_common_dtype(mapping_out_values), out_dtype
)
if raster._masked:
nv = raster.null_value
else:
# 0 is a placeholder value. If input raster is not masked and
# unmapped_to_null is False, the null value will not be used in the
# dask function. nv still needs to be passed to dask so it needs to be
# a value with valid type.
nv = raster.null_value if raster._masked else 0
nv = get_default_null_value(out_dtype)
for k in remapping:
if remapping[k] is None:
remapping[k] = nv

mapping_from = np.array(list(remapping))
mapping_to = np.array(list(remapping.values())).astype(out_dtype)
Expand Down
89 changes: 76 additions & 13 deletions raster_tools/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from raster_tools.dask_utils import dask_nanmax, dask_nanmin
from raster_tools.dtypes import (
BOOL,
DTYPE_INPUT_TO_DTYPE,
F16,
F32,
Expand Down Expand Up @@ -1093,6 +1094,60 @@ def set_null_value(self, value):
xrs = xrs.rio.write_nodata(value)
return Raster(make_raster_ds(xrs, mask), _fast_path=True).burn_mask()

def set_null(self, mask_raster):
    """Mark additional pixels as null using a mask raster.

    The existing null mask is kept and extended: a pixel is null in the
    result if it was already null in this raster or if it is flagged in
    `mask_raster` (a logical "OR"). The mask is never replaced.

    Parameters
    ----------
    mask_raster : str, Raster
        Raster or path to a raster whose truthy pixels become null in the
        result. It must have a data type of boolean, int8, or uint8, and
        its x and y dimensions must match this raster. It must have either
        1 band or the same number of bands as this raster. A single band
        `mask_raster` is broadcast across all bands. If `mask_raster` is
        itself masked, its null pixels do not contribute to the new mask.

    Returns
    -------
    Raster
        The resulting raster with updated mask.

    Raises
    ------
    ValueError
        If the band count or the x and y dimensions of `mask_raster` are
        not compatible with this raster.
    TypeError
        If `mask_raster` is not boolean, int8, or uint8.

    """
    mask_raster = get_raster(mask_raster)

    # Validate band count, then spatial shape, then data type.
    bands_mismatch = mask_raster.nbands != self.nbands
    if mask_raster.nbands > 1 and bands_mismatch:
        raise ValueError(
            "The number of bands in mask_raster must be 1 or match"
            f" this raster. Got {mask_raster.nbands}"
        )
    if mask_raster.shape[1:] != self.shape[1:]:
        raise ValueError(
            "x and y dims for mask_raster do not match this raster."
            f" {mask_raster.shape[1:]} vs {self.shape[1:]}"
        )
    allowed_dtypes = {BOOL, I8, U8}
    mask_dtype = mask_raster.dtype
    if mask_dtype not in allowed_dtypes:
        raise TypeError("mask_raster must be boolean, int8, or uint8")
    if not is_bool(mask_dtype):
        # Integer masks are interpreted by truthiness
        mask_raster = mask_raster.astype(bool)

    result = self.copy()
    combined_mask = result._ds.mask.data
    # Ignore cells that are null in mask_raster itself; only its valid,
    # truthy cells add to the mask.
    if mask_raster._masked:
        additions = mask_raster.data & (~mask_raster.mask)
    else:
        additions = mask_raster.data
    # Broadcasting handles a single band mask applied to many bands
    combined_mask |= additions
    result._ds.mask.data = combined_mask

    if not self._masked:
        # An unmasked raster has no null value yet, so write the default
        # one for its dtype before burning the mask.
        default_nv = get_default_null_value(self.dtype)
        result._ds["raster"] = result._ds.raster.rio.write_nodata(default_nv)
    # Burn mask to set null values in newly masked regions
    return result.burn_mask()

def burn_mask(self):
"""Fill null-masked cells with null value.
Expand All @@ -1101,15 +1156,17 @@ def burn_mask(self):
promoting to fit the null value.
"""
if not self._masked:
return self
return self.copy()
nv = self.null_value
if is_bool(self.dtype):
# Sanity check to make sure that boolean rasters get a boolean null
# value
nv = get_default_null_value(self.dtype)
# call write_nodata because xr.where drops the nodata info
xrs = xr.where(self._ds.mask, nv, self._ds.raster).rio.write_nodata(nv)
if self.crs is not None:
xrs = xrs.rio.write_crs(self.crs)
return Raster(make_raster_ds(xrs, self._ds.mask), _fast_path=True)
out_raster = self.copy()
# Work with .data to avoid dropping attributes caused by using xarray
# and rioxarrays' APIs
out_raster._ds.raster.data = da.where(self.mask, nv, self.data)
return out_raster

def to_null_mask(self):
"""
Expand Down Expand Up @@ -1187,10 +1244,11 @@ def remap_range(self, mapping, inclusivity="left"):
mapping : 3-tuple of scalars or list of 3-tuples of scalars
A tuple or list of tuples containing ``(min, max, new_value)``
scalars. The mapping(s) map values between the min and max to the
``new_value``. If `mapping` is a list and there are mappings that
conflict or overlap, earlier mappings take precedence.
`inclusivity` determines which sides of the range are inclusive and
exclusive.
``new_value``. If ``new_value`` is ``None``, the matching pixels
will be marked as null. If `mapping` is a list and there are
mappings that conflict or overlap, earlier mappings take
precedence. `inclusivity` determines which sides of the range are
inclusive and exclusive.
inclusivity : str, optional
Determines whether to be inclusive or exclusive on either end of
the range. Default is `'left'`.
Expand Down Expand Up @@ -1267,9 +1325,14 @@ def reclassify(self, remapping, unmapped_to_null=False):
remapping : str, dict
Can be either a ``dict`` or a path string. If a ``dict`` is
provided, the keys will be reclassified to the corresponding
values. If a path string, it is treated as an ASCII remap file
where each line looks like ``a:b`` and ``a`` and ``b`` are
integers. All remap values (both from and to) must be integers.
values. It is possible to map values to the null value by providing
``None`` in the mapping. If a path string, it is treated as an
ASCII remap file where each line looks like ``a:b`` and ``a`` and
``b`` are scalars. ``b`` can also be "NoData". This indicates that
``a`` will be mapped to the null value. The output values of the
mapping can cause type promotion. If the input raster has integer
data and one of the outputs in the mapping is a float, the result
will be a float raster.
unmapped_to_null : bool, optional
If ``True``, values not included in the mapping are instead mapped
to the null value. Default is ``False``.
Expand Down
Loading

0 comments on commit 6f8bba2

Please sign in to comment.