diff --git a/docs/reference/raster.rst b/docs/reference/raster.rst index 6f1057d..d807547 100644 --- a/docs/reference/raster.rst +++ b/docs/reference/raster.rst @@ -122,6 +122,7 @@ Null Data and Remapping Raster.replace_null Raster.remap_range Raster.set_null_value + Raster.set_null Raster.to_null_mask Raster.where diff --git a/raster_tools/general.py b/raster_tools/general.py index 9b11931..8271cd1 100644 --- a/raster_tools/general.py +++ b/raster_tools/general.py @@ -949,10 +949,10 @@ def band_concat(rasters): @nb.jit(nopython=True, nogil=True) -def _remap_values(x, mask, mappings, inclusivity): +def _remap_values(x, mask, start_end_values, new_values, inclusivity): outx = np.zeros_like(x) bands, rows, columns = x.shape - rngs = mappings.shape[0] + rngs = start_end_values.shape[0] for bnd in range(bands): for rw in range(rows): for cl in range(columns): @@ -961,7 +961,8 @@ def _remap_values(x, mask, mappings, inclusivity): vl = x[bnd, rw, cl] remap = False for imap in range(rngs): - left, right, new = mappings[imap] + left, right = start_end_values[imap] + new = new_values[imap] if inclusivity == 0: remap = left <= vl < right elif inclusivity == 1: @@ -986,7 +987,7 @@ def _normalize_mappings(mappings): ) if not len(mappings): raise ValueError("No mappings provided") - if len(mappings) and is_scalar(mappings[0]): + if is_scalar(mappings[0]): mappings = [mappings] try: mappings = [list(m) for m in mappings] @@ -995,14 +996,17 @@ def _normalize_mappings(mappings): "Mappings must be either single 3-tuple or list of 3-tuples of " "scalars" ) from None + starts = [] + ends = [] + news = [] for m in mappings: if len(m) != 3: raise ValueError( "Mappings must be either single 3-tuple or list of 3-tuples of" " scalars" ) - if not all(is_scalar(mi) for mi in m): - raise TypeError("Mappings values must be scalars") + if not all(is_scalar(mi) for mi in m[:2]): + raise TypeError("Mapping min and max values must be scalars") if any(np.isnan(mi) for mi in m[:2]): raise ValueError("Mapping min and max values cannot be NaN") if m[0] >= m[1]: @@ -1010,7 +1014,15 @@ def _normalize_mappings(mappings): "Mapping min value must be strictly less than max value:" f" {m[0]}, {m[1]}" ) - return mappings + if not is_scalar(m[2]) and m[2] is not None: + raise ValueError( + "The new value in a mapping must be a scalar or None. " + f"Got {m[2]!r}." + ) + starts.append(m[0]) + ends.append(m[1]) + news.append(m[2]) + return starts, ends, news def remap_range(raster, mapping, inclusivity="left"): @@ -1026,7 +1038,8 @@ def remap_range(raster, mapping, inclusivity="left"): mapping : 3-tuple of scalars or list of 3-tuples of scalars A tuple or list of tuples containing ``(min, max, new_value)`` scalars. The mappiing(s) map values between the min and max to the - ``new_value``. If `mapping` is a list and there are mappings that + ``new_value``. If ``new_value`` is ``None``, the matching pixels will + be marked as null. If `mapping` is a list and there are mappings that conflict or overlap, earlier mappings take precedence. `inclusivity` determines which sides of the range are inclusive and exclusive. inclusivity : str, optional @@ -1049,7 +1062,7 @@ def remap_range(raster, mapping, inclusivity="left"): """ raster = get_raster(raster) - mappings = _normalize_mappings(mapping) + map_starts, map_ends, map_news = _normalize_mappings(mapping) if not is_str(inclusivity): raise TypeError( f"inclusivity must be a str. Got type: {type(inclusivity)}" @@ -1057,11 +1070,21 @@ def remap_range(raster, mapping, inclusivity="left"): inc_map = dict(zip(("left", "right", "both", "none"), range(4))) if inclusivity not in inc_map: raise ValueError(f"Invalid inclusivity value. Got: {inclusivity!r}") - mappings_common_dtype = get_common_dtype([m[-1] for m in mappings]) + if not all(m is None for m in map_news): + mappings_common_dtype = get_common_dtype( + [m for m in map_news if m is not None] + ) + else: + mappings_common_dtype = raster.dtype out_dtype = np.promote_types(raster.dtype, mappings_common_dtype) # numba doesn't understand f16 so use f32 and then downcast f16_workaround = out_dtype == F16 - mappings = np.atleast_2d(mappings) + if raster._masked: + nv = raster.null_value + else: + nv = get_default_null_value(out_dtype) + start_end_values = np.array([[s, e] for s, e in zip(map_starts, map_ends)]) + new_values = np.array([v if v is not None else nv for v in map_news]) outrs = raster.copy() if out_dtype != outrs.dtype: @@ -1072,15 +1095,16 @@ def remap_range(raster, mapping, inclusivity="left"): elif f16_workaround: outrs = outrs.astype(F32) data = outrs.data - func = partial( - _remap_values, mappings=mappings, inclusivity=inc_map[inclusivity] - ) outrs.xdata.data = data.map_blocks( - func, + _remap_values, raster.mask, + start_end_values=start_end_values, + new_values=new_values, + inclusivity=inc_map[inclusivity], dtype=data.dtype, meta=np.array((), dtype=data.dtype), ) + outrs.xmask.data = outrs.data == nv if f16_workaround: outrs = outrs.astype(F16) return outrs @@ -1209,7 +1233,8 @@ class RemapFileParseError(Exception): r"(?:[-+]?(?:(?:0|[1-9]\d*)(?:\.\d*)?|\.\d+)?)(?:[eE][-+]?\d+)?" ) _REMAPPING_LINE_PATTERN = re.compile( - rf"^\s*(?P{_INT_OR_FLOAT_RE})\s*:\s*(?P{_INT_OR_FLOAT_RE})\s*$" + rf"^\s*(?P{_INT_OR_FLOAT_RE})\s*:" + rf"\s*(?P{_INT_OR_FLOAT_RE}|NoData)\s*$" ) @@ -1235,7 +1260,7 @@ def _parse_ascii_remap_file(fd): from_str = m.group("from") to_str = m.group("to") k = _str_to_float_or_int(from_str) - v = _str_to_float_or_int(to_str) + v = _str_to_float_or_int(to_str) if to_str != "NoData" else None if k in mapping: raise ValueError(f"Found duplicate mapping: '{k}:{v}'.") mapping[k] = v @@ -1267,12 +1292,14 @@ def reclassify(raster, remapping, unmapped_to_null=False): The input raster to reclassify. Can be a path string or Raster object. remapping : str, dict Can be either a ``dict`` or a path string. If a ``dict`` is provided, - the keys will be reclassified to the corresponding values. If a path - string, it is treated as an ASCII remap file where each line looks like - ``a:b`` and ``a`` and ``b`` are integers. The output values of the - mapping can cause type promotion. If the input raster has integer data - and one of the outputs in the mapping is a float, the result will be a - float raster. + the keys will be reclassified to the corresponding values. It is + possible to map values to the null value by providing ``None`` in the + mapping. If a path string, it is treated as an ASCII remap file where + each line looks like ``a:b`` and ``a`` and ``b`` are scalars. ``b`` can + also be "NoData". This indicates that ``a`` will be mapped to the null + value. The output values of the mapping can cause type promotion. If + the input raster has integer data and one of the outputs in the mapping + is a float, the result will be a float raster. unmapped_to_null : bool, optional If ``True``, values not included in the mapping are instead mapped to the null value. Default is ``False``. @@ -1287,21 +1314,18 @@ def reclassify(raster, remapping, unmapped_to_null=False): remapping = _get_remapping(remapping) out_dtype = raster.dtype - mapping_out_values = list(remapping.values()) - out_dtype = np.promote_types( - get_common_dtype(mapping_out_values), out_dtype - ) - if unmapped_to_null: - if raster._masked: - nv = raster.null_value - else: - nv = get_default_null_value(out_dtype) + mapping_out_values = [v for v in remapping.values() if v is not None] + if len(mapping_out_values): + out_dtype = np.promote_types( + get_common_dtype(mapping_out_values), out_dtype + ) + if raster._masked: + nv = raster.null_value else: - # 0 is a placeholder value. If input raster is not masked and - # unmapped_to_null is False, the null value will not be used in the - # dask function. nv still needs to be passed to dask so it needs to be - # a value with valid type. - nv = raster.null_value if raster._masked else 0 + nv = get_default_null_value(out_dtype) + for k in remapping: + if remapping[k] is None: + remapping[k] = nv mapping_from = np.array(list(remapping)) mapping_to = np.array(list(remapping.values())).astype(out_dtype) diff --git a/raster_tools/raster.py b/raster_tools/raster.py index 65b38d2..c471dcb 100644 --- a/raster_tools/raster.py +++ b/raster_tools/raster.py @@ -17,6 +17,7 @@ from raster_tools.dask_utils import dask_nanmax, dask_nanmin from raster_tools.dtypes import ( + BOOL, DTYPE_INPUT_TO_DTYPE, F16, F32, @@ -1093,6 +1094,60 @@ def set_null_value(self, value): xrs = xrs.rio.write_nodata(value) return Raster(make_raster_ds(xrs, mask), _fast_path=True).burn_mask() + def set_null(self, mask_raster): + """Update the raster's null pixels using the provided mask + + Parameters + ---------- + mask_raster : str, Raster + Raster or path to a raster that is used to update the masked out + pixels. This raster updates the mask. It does not replace the mask. + Pixels that were already marked as null stay null and pixels that + are marked as null in `mask_raster` become marked as null in the + resulting raster. This is a logical "OR" operation. `mask_raster` + must have data type of boolean, int8, or uint8. `mask_raster` must + have either 1 band or the same number of bands as the raster it is + being applied to. A single band `mask_raster` is broadcast across + all bands of the raster being modified. + + Returns + ------- + Raster + The resulting raster with updated mask. + + """ + mask_raster = get_raster(mask_raster) + if mask_raster.nbands > 1 and mask_raster.nbands != self.nbands: + raise ValueError( + "The number of bands in mask_raster must be 1 or match" + f" this raster. Got {mask_raster.nbands}" + ) + if mask_raster.shape[1:] != self.shape[1:]: + raise ValueError( + "x and y dims for mask_raster do not match this raster." + f" {mask_raster.shape[1:]} vs {self.shape[1:]}" + ) + dtype = mask_raster.dtype + if dtype not in {BOOL, I8, U8}: + raise TypeError("mask_raster must be boolean, int8, or uint8") + elif not is_bool(dtype): + mask_raster = mask_raster.astype(bool) + + out_raster = self.copy() + new_mask_data = out_raster._ds.mask.data + # Rely on numpy broadcasting when applying the new mask data + if mask_raster._masked: + new_mask_data |= mask_raster.data & (~mask_raster.mask) + else: + new_mask_data |= mask_raster.data + out_raster._ds.mask.data = new_mask_data + if not self._masked: + out_raster._ds["raster"] = out_raster._ds.raster.rio.write_nodata( + get_default_null_value(self.dtype) + ) + # Burn mask to set null values in newly masked regions + return out_raster.burn_mask() + def burn_mask(self): """Fill null-masked cells with null value. @@ -1101,15 +1156,17 @@ def burn_mask(self): promoting to fit the null value. """ if not self._masked: - return self + return self.copy() nv = self.null_value if is_bool(self.dtype): + # Sanity check to make sure that boolean rasters get a boolean null + # value nv = get_default_null_value(self.dtype) - # call write_nodata because xr.where drops the nodata info - xrs = xr.where(self._ds.mask, nv, self._ds.raster).rio.write_nodata(nv) - if self.crs is not None: - xrs = xrs.rio.write_crs(self.crs) - return Raster(make_raster_ds(xrs, self._ds.mask), _fast_path=True) + out_raster = self.copy() + # Work with .data to avoid dropping attributes caused by using xarray + # and rioxarrays' APIs + out_raster._ds.raster.data = da.where(self.mask, nv, self.data) + return out_raster def to_null_mask(self): """ @@ -1187,10 +1244,11 @@ def remap_range(self, mapping, inclusivity="left"): mapping : 3-tuple of scalars or list of 3-tuples of scalars A tuple or list of tuples containing ``(min, max, new_value)`` scalars. The mappiing(s) map values between the min and max to the - ``new_value``. If `mapping` is a list and there are mappings that - conflict or overlap, earlier mappings take precedence. - `inclusivity` determines which sides of the range are inclusive and - exclusive. + ``new_value``. If ``new_value`` is ``None``, the matching pixels + will be marked as null. If `mapping` is a list and there are + mappings that conflict or overlap, earlier mappings take + precedence. `inclusivity` determines which sides of the range are + inclusive and exclusive. inclusivity : str, optional Determines whether to be inclusive or exclusive on either end of the range. Default is `'left'`. @@ -1267,9 +1325,14 @@ def reclassify(self, remapping, unmapped_to_null=False): remapping : str, dict Can be either a ``dict`` or a path string. If a ``dict`` is provided, the keys will be reclassified to the corresponding - values. If a path string, it is treated as an ASCII remap file - where each line looks like ``a:b`` and ``a`` and ``b`` are - integers. All remap values (both from and to) must be integers. + values. It is possible to map values to the null value by providing + ``None`` in the mapping. If a path string, it is treated as an + ASCII remap file where each line looks like ``a:b`` and ``a`` and + ``b`` are scalars. ``b`` can also be "NoData". This indicates that + ``a`` will be mapped to the null value. The output values of the + mapping can cause type promotion. If the input raster has integer + data and one of the outputs in the mapping is a float, the result + will be a float raster. unmapped_to_null : bool, optional If ``True``, values not included in the mapping are instead mapped to the null value. Default is ``False``. diff --git a/tests/test_general.py b/tests/test_general.py index 559e9bc..a1d7db1 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -714,6 +714,34 @@ def test_remap_range_f16(): assert np.allclose(result, truth) +def test_remap_range_null_mapping(): + raster = testdata.raster.dem_small + mapping = (0, 1500, None) + truth_rast = xr.where(raster.xdata > 1500, raster.xdata, raster.null_value) + truth_mask = truth_rast == raster.null_value + + result = general.remap_range(raster, mapping) + assert_valid_raster(result) + assert result.dtype == raster.dtype + assert result.null_value == raster.null_value + assert np.allclose(result.xdata, truth_rast) + assert np.allclose(result.xmask, truth_mask) + + mapping = [(0, 1500, None), (1500, 2000, 1)] + truth_rast = xr.where(raster.xdata > 1500, raster.xdata, raster.null_value) + truth_rast = xr.where( + (truth_rast >= 1500) & (truth_rast < 2000), 1, truth_rast + ) + truth_mask = truth_rast == raster.null_value + + result = general.remap_range(raster, mapping) + assert_valid_raster(result) + assert result.dtype == raster.dtype + assert result.null_value == raster.null_value + assert np.allclose(result.xdata, truth_rast) + assert np.allclose(result.xmask, truth_mask) + + def test_remap_range_errors(): rs = testdata.raster.dem_small # TypeError if not scalars @@ -721,10 +749,6 @@ def test_remap_range_errors(): general.remap_range(rs, (None, 2, 4)) with pytest.raises(TypeError): general.remap_range(rs, (0, "2", 4)) - with pytest.raises(TypeError): - general.remap_range(rs, (0, 2, None)) - with pytest.raises(TypeError): - general.remap_range(rs, [(0, 2, 1), (2, 3, None)]) # ValueError if nan with pytest.raises(ValueError): general.remap_range(rs, (np.nan, 2, 4)) @@ -854,7 +878,7 @@ def test_where(cond, x, y): .set_null_value(99) .set_null_value(100) .astype("int16"), - {0: -1, 1: -2, 2: -3, 120: -120, 150: -150}, + {0: -1, 1: -2, 2: -3, 120: -120, 150: None}, np.dtype("int16"), ), # Check promotion from uint16 to int32 @@ -902,20 +926,26 @@ def test_reclassify(raster, mapping, unmapped_to_null, expected_out_dtype): ) for f, t in mapping.items(): mapped |= tdata == f + if t is None: + t = nv tdata[tdata == f] = t if unmapped_to_null: tdata[~mapped] = nv + tmask = tdata == nv result = general.reclassify(raster, mapping, unmapped_to_null) assert_valid_raster(result) assert_rasters_similar(raster, result) assert result.dtype == expected_out_dtype - if unmapped_to_null: + if ( + raster._masked + or unmapped_to_null + or any(v is None for v in mapping.values()) + ): assert result._masked assert result.null_value == nv - if raster._masked: - assert result._masked assert np.allclose(tdata, result.data.compute()) + assert np.allclose(tmask, result.mask.compute()) @pytest.mark.parametrize( @@ -931,6 +961,7 @@ def test_reclassify(raster, mapping, unmapped_to_null, expected_out_dtype): {1e23: 2, -1e-23: 2, 1.2e23: 3}, ), (io.StringIO("1:\t 2\n 3 : 4 \n"), {1: 2, 3: 4}), + (io.StringIO("1:NoData"), {1: None}), ], ) def test_reclassify_mapping_file_parsing(fd, expected_mapping): @@ -946,6 +977,7 @@ def test_reclassify_mapping_file_parsing(fd, expected_mapping): (io.StringIO("12"), general.RemapFileParseError), (io.StringIO("12 34"), general.RemapFileParseError), (io.StringIO("1:2\n1:3"), ValueError), + (io.StringIO("1:ND"), general.RemapFileParseError), ], ) def test_reclassify__parse_ascii_remap_file_errors(fd, error): diff --git a/tests/test_raster.py b/tests/test_raster.py index f5fc124..3350228 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -1431,6 +1431,9 @@ def test_copy(): assert rs is not copy assert rs._ds is not copy._ds assert rs._ds.equals(copy._ds) + assert copy._ds.attrs == rs._ds.attrs + assert copy._ds.raster.attrs == rs._ds.raster.attrs + assert copy._ds.mask.attrs == rs._ds.mask.attrs assert np.allclose(rs, copy) # make sure a deep copy has occurred copy._ds.raster.data[0, -1, -1] = 0 @@ -1506,6 +1509,46 @@ def test_replace_null(): rs.replace_null(None) +def test_set_null(): + raster = testdata.raster.dem_small + truth = raster.to_numpy() + truth_mask = truth < 1500 + truth[truth_mask] = raster.null_value + result = raster.set_null(raster < 1500) + assert_valid_raster(result) + assert result.null_value == raster.null_value + assert np.allclose(result, truth) + assert np.allclose(result.mask.compute(), truth_mask) + assert result._ds.raster.attrs == raster._ds.raster.attrs + + # Make sure broadcasting works + raster = band_concat([testdata.raster.dem_small] * 3) + assert raster.shape == (3, 100, 100) + mask = (testdata.raster.dem_small < 1500).to_numpy() + truth[:, mask[0]] = raster.null_value + result = raster.set_null(testdata.raster.dem_small < 1500) + assert_valid_raster(result) + assert result.null_value == raster.null_value + assert np.allclose(result, truth) + assert np.allclose(result.mask.compute(), truth_mask) + assert result._ds.raster.attrs == raster._ds.raster.attrs + + # Make sure that a null value is added if not already present + raster = testdata.raster.dem_small.set_null_value(None) + nv = get_default_null_value(raster.dtype) + truth = raster.to_numpy() + truth_mask = truth < 1500 + truth[truth_mask] = nv + result = raster.set_null(raster < 1500) + assert_valid_raster(result) + assert result.null_value == nv + assert np.allclose(result, truth) + assert np.allclose(result.mask.compute(), truth_mask) + attrs = raster._ds.raster.attrs + attrs["_FillValue"] = nv + assert result._ds.raster.attrs == attrs + + @pytest.mark.filterwarnings("ignore:The null value") def test_where(): rs = testdata.raster.dem_small @@ -1630,41 +1673,86 @@ def test_get_bands(): def test_burn_mask(): - x = arange_nd((1, 5, 5)) - rs = Raster(x) - rs._ds["raster"] = xr.where( - (rs._ds.raster >= 0) & (rs._ds.raster < 10), -999, rs._ds.raster + raster = ( + arange_raster((1, 5, 5)) + .set_crs("EPSG:3857") + .set_null_value(4) + .set_null_value(-1) ) - rs = rs.set_null_value(-999).set_crs("EPSG:3857") - data = rs.to_numpy() - assert rs.null_value == -999 - assert rs._masked - assert rs.crs == "EPSG:3857" - true_mask = data < 10 - true_state = data.copy() - true_state[true_mask] = -999 - assert np.allclose(true_mask, rs._ds.mask) - assert np.allclose(true_state, rs) - - rs._ds.raster.data = data - assert np.allclose(rs, data) - assert_valid_raster(rs.burn_mask()) - assert np.allclose(rs.burn_mask(), true_state) - assert rs.burn_mask().crs == rs.crs - - data = arange_nd((1, 5, 5)) - rs = Raster(data) - rs._ds["raster"] = xr.where( - (rs._ds.raster >= 20) & (rs._ds.raster < 26), 999, rs._ds.raster + raster._ds.mask.data[raster.data > 20] = True + truth = raster.copy() + truth._ds.raster.data = da.where(truth.mask, truth.null_value, truth.data) + # Confirm state + assert_valid_raster(truth) + assert truth.crs == 3857 + assert truth.null_value == -1 + assert np.allclose( + truth.mask.compute(), + np.array( + [ + [ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 1, 1, 1, 1], + ] + ] + ).astype(bool), ) - rs = rs.set_null_value(999).set_crs("EPSG:3857") - rs = rs > 15 - assert rs.dtype == np.dtype(bool) - nv = get_default_null_value(bool) - assert rs.null_value == nv - true_state = data > 15 - true_state = np.where(data >= 20, nv, true_state) - assert np.allclose(rs.burn_mask(), true_state) + # Confirm raster is invalid because the data does not match the mask + with pytest.raises(AssertionError): + assert_valid_raster(raster) + + result = raster.burn_mask() + assert_valid_raster(result) + assert_rasters_similar(result, truth) + assert result is not raster + assert result.null_value == truth.null_value + # Make sure nothing else changed on the raster + assert result.xdata.attrs == raster.xdata.attrs + assert np.allclose(result, truth) + assert np.allclose(result.mask.compute(), truth.mask.compute()) + + # Make sure a copy is returned if there is nothing to mask + raster = arange_raster((1, 5, 5)).set_crs("EPSG:3857") + assert raster.null_value is None + result = raster.burn_mask() + assert_valid_raster(result) + assert_rasters_similar(result, truth) + assert result is not raster + assert result.null_value is None + # Make sure nothing else changed on the raster + assert result.xdata.attrs == raster.xdata.attrs + assert np.allclose(result, raster) + assert np.allclose(result.mask.compute(), raster.mask.compute()) + + # Boolean rasters + raster = ( + arange_raster((1, 5, 5)) + .set_crs("EPSG:3857") + .set_null_value(0) + .set_null_value(-1) + ) + raster = raster > 20 + raster._ds.mask.data[raster.data <= 12] = True + assert raster.dtype == np.dtype(bool) + assert raster.null_value == get_default_null_value(bool) + truth = raster.copy() + truth._ds.raster.data = da.where( + truth.mask, get_default_null_value(bool), truth.data + ) + truth._ds["raster"] = truth.xdata.rio.write_nodata( + get_default_null_value(bool) + ) + result = raster.burn_mask() + assert_valid_raster(result) + assert_rasters_similar(result, truth) + assert result is not raster + # Make sure nothing else changed on the raster + assert result.xdata.attrs == raster.xdata.attrs + assert np.allclose(result, truth) + assert np.allclose(result.mask.compute(), truth.mask.compute()) @pytest.mark.parametrize("index", [(0, 0), (0, 1), (1, 0), (-1, -1), (1, -1)])