diff --git a/satpy/readers/generic_image.py b/satpy/readers/generic_image.py index c0e334302f..64f396de3d 100644 --- a/satpy/readers/generic_image.py +++ b/satpy/readers/generic_image.py @@ -54,11 +54,15 @@ class GenericImageFileHandler(BaseFileHandler): """Handle reading of generic image files.""" - def __init__(self, filename, filename_info, filetype_info): + def __init__(self, filename, filename_info, filetype_info, set_fill_value=None, nodata_handling=None): """Initialize filehandler.""" super(GenericImageFileHandler, self).__init__( filename, filename_info, filetype_info) self.finfo = filename_info + self.set_fill_value = set_fill_value + if self.set_fill_value is not None and not isinstance(self.set_fill_value, (list, tuple)): + self.set_fill_value = [self.set_fill_value] + self.nodata_handling = nodata_handling try: self.finfo["end_time"] = self.finfo["start_time"] except KeyError: @@ -82,6 +86,7 @@ def read(self): # however, error is not explicit enough (see https://github.com/pydata/xarray/issues/7831) data = xr.open_dataset(self.finfo["filename"], engine="rasterio", chunks={"band": 1, "y": CHUNK_SIZE, "x": CHUNK_SIZE}, mask_and_scale=False)["band_data"] + if hasattr(dataset, "nodatavals"): # The nodata values for the raster bands # copied from https://github.com/pydata/xarray/blob/v2023.03.0/xarray/backends/rasterio_.py#L322-L326 @@ -89,6 +94,8 @@ def read(self): np.nan if nodataval is None else nodataval for nodataval in dataset.nodatavals ) data.attrs["nodatavals"] = nodatavals + if self.set_fill_value: + data.attrs["nodatavals"] = self.set_fill_value attrs = data.attrs.copy() @@ -123,10 +130,12 @@ def get_dataset(self, key, info): ds_name = self.dataset_name if self.dataset_name else key["name"] logger.debug("Reading '%s.'", ds_name) data = self.file_content[ds_name] + if self.nodata_handling is not None: + info["nodata_handling"] = self.nodata_handling # Mask data if necessary try: - data = _mask_image_data(data, info) + data = _mask_image_data(data, info, self.set_fill_value) except ValueError as err: logger.warning(err) @@ -135,7 +144,7 @@ def get_dataset(self, key, info): return data -def _mask_image_data(data, info): +def _mask_image_data(data, info, set_fill_value): """Mask image data if necessary. Masking is done if alpha channel is present or @@ -143,7 +152,7 @@ def _mask_image_data(data, info): In the latter case even integer data is converted to float32 and masked with np.nan. """ - if data.bands.size in (2, 4): + if data.bands.size in (2, 4) and not set_fill_value: if not np.issubdtype(data.dtype, np.integer): raise ValueError("Only integer datatypes can be used as a mask.") mask = data.data[-1, :, :] == np.iinfo(data.dtype).min @@ -174,4 +183,6 @@ def _handle_nodatavals(data, nodata_handling): if np.issubdtype(data.dtype, np.integer): fill_value = int(fill_value) data.attrs["_FillValue"] = fill_value + if "A" in data.bands: + data = data.drop_sel(bands="A") return data diff --git a/satpy/tests/reader_tests/test_generic_image.py b/satpy/tests/reader_tests/test_generic_image.py index 0d5d647420..8e68c592a1 100644 --- a/satpy/tests/reader_tests/test_generic_image.py +++ b/satpy/tests/reader_tests/test_generic_image.py @@ -73,7 +73,7 @@ def test_image_l(tmp_path, random_image_channel_l): attrs={"name": "test_l", "start_time": DATA_DATE}) dset["bands"] = ["L"] fname = tmp_path / "test_l.png" - _save_image(dset, fname, "simple_image") + _save_image(dset, fname, "simple_image", 255) return fname @@ -162,7 +162,7 @@ def test_png_scene_l_mode(test_image_l): with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): scn = Scene(reader="generic_image", filenames=[test_image_l]) scn.load(["image"]) - _assert_image_common(scn, 1, None, None, np.float32) + _assert_image_common(scn, 1, None, None, np.uint8) assert "area" not in scn["image"].attrs @@ -182,13 +182,23 @@ def test_png_scene_la_mode(test_image_la): """Test reading a PNG image with LA mode via satpy.Scene().""" with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): scn = Scene(reader="generic_image", filenames=[test_image_la]) - scn.load(["image"]) + scn.load(["image"], nodata_handling="fill_value") data = da.compute(scn["image"].data) assert np.sum(np.isnan(data)) == 100 assert "area" not in scn["image"].attrs _assert_image_common(scn, 1, DATA_DATE, DATA_DATE, np.float32) +def test_png_scene_la_mode_set_fill(test_image_la): + """Test reading a PNG image with L mode via satpy.Scene() setting input fill value.""" + with pytest.warns(NotGeoreferencedWarning, match=r"Dataset has no geotransform"): + scn = Scene(reader="generic_image", filenames=[test_image_la], + reader_kwargs={"set_fill_value": 255, "nodata_handling": "fill_value"}) + scn.load(["image"]) + _assert_image_common(scn, 1, DATA_DATE, DATA_DATE, np.uint8) + assert "area" not in scn["image"].attrs + + def test_geotiff_scene_rgb(test_image_rgb): """Test reading geotiff image in RGB mode via satpy.Scene().""" scn = Scene(reader="generic_image", filenames=[test_image_rgb]) @@ -205,6 +215,15 @@ def test_geotiff_scene_rgba(test_image_rgba): assert scn["image"].area == AREA_DEFINITION +def test_png_scene_rgba_mode_set_fill(test_image_rgba): + """Test reading an image in RGBA mode via satpy.Scene() setting input fill value.""" + scn = Scene(reader="generic_image", filenames=[test_image_rgba], + reader_kwargs={"set_fill_value": 255, "nodata_handling": "fill_value"}) + scn.load(["image"]) + _assert_image_common(scn, 3, None, None, np.uint8) + assert scn["image"].area == AREA_DEFINITION + + def test_geotiff_scene_nan_fill_value(test_image_l_nan_fill_value): """Test reading geotiff image with fill value set via satpy.Scene().""" scn = Scene(reader="generic_image", filenames=[test_image_l_nan_fill_value]) @@ -252,6 +271,8 @@ def __init__(self, filename, filename_info, filetype_info, file_content, **kwarg super(GenericImageFileHandler, self).__init__(filename, filename_info, filetype_info) self.file_content = file_content self.dataset_name = None + self.nodata_handling = kwargs.pop("nodata_handling", None) + self.set_fill_value = kwargs.pop("set_fill_value", None) self.file_content.update(kwargs)