diff --git a/.github/workflows/install-test-conda-forge.yml b/.github/workflows/install-test-conda-forge.yml index 92f8467..945b603 100644 --- a/.github/workflows/install-test-conda-forge.yml +++ b/.github/workflows/install-test-conda-forge.yml @@ -30,7 +30,7 @@ jobs: - name: Install package run: | conda info - conda install homonim>=0.4.0 + conda install homonim>=0.4.1 conda list - name: Run CLI fusion test diff --git a/homonim/__init__.py b/homonim/__init__.py index f0723ff..a727d1d 100644 --- a/homonim/__init__.py +++ b/homonim/__init__.py @@ -19,18 +19,13 @@ import os import pathlib import logging -import warnings -from rasterio.errors import NotGeoreferencedWarning from homonim.compare import RasterCompare from homonim.enums import Model, ProcCrs from homonim.fuse import RasterFuse from homonim.kernel_model import KernelModel from homonim.stats import ParamStats -# suppress NotGeoreferencedWarning which rasterio can raise incorrectly -warnings.simplefilter('ignore', category=NotGeoreferencedWarning) - # Add a NullHandler to the package logger to hide logs by default. Applications can then add # their own handler(s). log = logging.getLogger(__name__) diff --git a/homonim/cli.py b/homonim/cli.py index 77369e2..1c1d071 100644 --- a/homonim/cli.py +++ b/homonim/cli.py @@ -32,6 +32,7 @@ import yaml from click.core import ParameterSource from rasterio.warp import SUPPORTED_RESAMPLING +from rasterio.errors import NotGeoreferencedWarning from homonim import utils, version, RasterFuse, RasterCompare, ParamStats, ProcCrs, Model from homonim.errors import ImageFormatError @@ -145,6 +146,9 @@ def showwarning(message, category, filename, lineno, file=None, line=None): logger = logging.getLogger(module_name) logger.warning(str(message)) + # suppress NotGeoreferencedWarning which rasterio can raise incorrectly + warnings.simplefilter('ignore', category=NotGeoreferencedWarning) + # redirect orthority warnings to module logger orig_show_warning = warnings.showwarning warnings.showwarning = showwarning diff --git a/homonim/kernel_model.py b/homonim/kernel_model.py index eb0b374..a07cf79 100644 --- a/homonim/kernel_model.py +++ b/homonim/kernel_model.py @@ -188,7 +188,7 @@ def _r2_array( # The above can be expanded and expressed in terms of cv.boxFilter kernel sums as: ss_res_array = ( ((param_array[0] ** 2) * src2_sum) + - (2 * np.product(param_array[:2], axis=0) * src_sum) - + (2 * np.prod(param_array[:2], axis=0) * src_sum) - (2 * param_array[0] * src_ref_sum) - (2 * param_array[1] * ref_sum) + ref2_sum + (mask_sum * (param_array[1] ** 2)) diff --git a/homonim/matched_pair.py b/homonim/matched_pair.py index f663080..045dfc6 100644 --- a/homonim/matched_pair.py +++ b/homonim/matched_pair.py @@ -248,6 +248,7 @@ def _match_pair_bands( # absolute & relative distance matrix between src and ref center wavelengths abs_dist = np.abs(src_wavelengths[:, np.newaxis] - ref_wavelengths[np.newaxis, :]) rel_dist = abs_dist / src_wavelengths[:, np.newaxis] + def greedy_match(dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ Greedy matching of src to ref bands based on the provided center wavelength distance matrix, @@ -257,30 +258,29 @@ def greedy_match(dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ # match_idx[i] is the index of the ref band that matches with the ith src band match_idx = np.array([np.nan] * dist.shape[0]) - match_dist = np.array([np.nan] * dist.shape[0]) # distances corresponding to the above matches - - # suppress runtime warning on all-Nan slice, as it is expected in normal operation - with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=RuntimeWarning) - # repeat until all src or ref bands have been matched - while not all(np.isnan(np.nanmin(dist, axis=1))) or not all(np.isnan(np.nanmin(dist, axis=0))): - # find the row with the smallest distance in it - min_dist = np.nanmin(dist, axis=1) - min_dist_row_idx = np.nanargmin(min_dist) - min_dist_row = dist[min_dist_row_idx, :] - # store match idx and distance for this row - match_idx[min_dist_row_idx] = np.nanargmin(min_dist_row) - match_dist[min_dist_row_idx] = min_dist[min_dist_row_idx] - # set the matched row and col to nan, so that it is not used again - dist[:, int(match_idx[min_dist_row_idx])] = np.nan - dist[min_dist_row_idx, :] = np.nan + match_dist = np.array([np.nan] * dist.shape[0]) # distances corresponding to the above matches + + # use masked array rather than nan pass-through to avoid all-nan slice warnings + dist = np.ma.array(dist, mask=np.isnan(dist)) + + # repeat until all src or ref bands have been matched + while not dist.mask.all(): + # find the row with the smallest distance in it + min_dist = dist.min(axis=1) + min_dist_row_idx = np.ma.argmin(min_dist) + min_dist_row = dist[min_dist_row_idx, :] + # store match idx and distance for this row + match_idx[min_dist_row_idx] = np.ma.argmin(min_dist_row) + match_dist[min_dist_row_idx] = min_dist[min_dist_row_idx] + # set the matched row and col to nan, so that it is not used again + dist[:, int(match_idx[min_dist_row_idx])] = np.ma.masked + dist[min_dist_row_idx, :] = np.ma.masked return match_dist, match_idx match_dist, match_idx = greedy_match(rel_dist) - # if any of the matched distances are greater than a threshold, raise an informative error, - # or log a warning, depending on `self._force` + # if any of the matched distances are greater than a threshold, raise an informative error if any(match_dist > MatchedPairReader._max_rel_wavelength_diff): err_idx = match_dist > MatchedPairReader._max_rel_wavelength_diff src_err_band_names = list(src_band_names[err_idx]) diff --git a/homonim/raster_array.py b/homonim/raster_array.py index 678fb66..0afdf61 100644 --- a/homonim/raster_array.py +++ b/homonim/raster_array.py @@ -478,13 +478,6 @@ def to_rio_dataset( f'The length of indexes ({len(indexes)}) exceeds the number of bands in the ' f'RasterArray ({self.count})' ) - if rio_dataset.nodata is not None and ( - self.nodata is None or not utils.nan_equals(self.nodata, rio_dataset.nodata) - ): - warnings.warn( - f"The dataset nodata: {rio_dataset.nodata} does not match the RasterArray nodata: {self.nodata}", - category=ImageFormatWarning - ) if window is None: # a window defining the region in the dataset corresponding to the RasterArray extents diff --git a/homonim/raster_pair.py b/homonim/raster_pair.py index e74d2ce..fd70c29 100644 --- a/homonim/raster_pair.py +++ b/homonim/raster_pair.py @@ -230,8 +230,8 @@ def _auto_block_shape(self, max_block_mem: float = np.inf) -> Tuple[int, int]: proc_win = self._ref_win if self.proc_crs == ProcCrs.ref else self._src_win # adjust max_block_mem to represent the size of a block in the highest resolution image, but scaled to the # equivalent in proc_crs. - src_pix_area = np.product(np.abs(self._src_im.res)) - ref_pix_area = np.product(np.abs(self._ref_im.res)) + src_pix_area = np.prod(np.abs(self._src_im.res)) + ref_pix_area = np.prod(np.abs(self._ref_im.res)) if self.proc_crs == ProcCrs.ref: mem_scale = src_pix_area / ref_pix_area if ref_pix_area > src_pix_area else 1. elif self.proc_crs == ProcCrs.src: @@ -247,7 +247,7 @@ def _auto_block_shape(self, max_block_mem: float = np.inf) -> Tuple[int, int]: block_shape = np.array((proc_win.height, proc_win.width)).astype('float') # keep halving the block_shape along the longest dimension until it satisfies max_block_mem - while (np.product(block_shape) * dtype_size) > max_block_mem: + while (np.prod(block_shape) * dtype_size) > max_block_mem: div_dim = np.argmax(block_shape) block_shape[div_dim] /= 2 diff --git a/homonim/utils.py b/homonim/utils.py index e6602d1..4179dd6 100644 --- a/homonim/utils.py +++ b/homonim/utils.py @@ -121,9 +121,9 @@ def validate_kernel_shape(kernel_shape: Tuple[int, int], model: Model = Model.ga if not np.all(np.mod(kernel_shape, 2) == 1): raise ValueError('`kernel_shape` must be odd in both dimensions.') if model == Model.gain_offset: - if np.product(kernel_shape) < 2: + if np.prod(kernel_shape) < 2: raise ValueError('`kernel_shape` area should contain at least 2 elements for the gain-offset model.') - elif np.product(kernel_shape) < 25: + elif np.prod(kernel_shape) < 25: warnings.warn( 'A `kernel_shape` of at least 25 elements is recommended for the gain-offset model.', category=ConfigWarning @@ -155,6 +155,8 @@ def overlap_for_kernel(kernel_shape: Tuple[int, int]) -> Tuple[int, int]: def validate_threads(threads: int) -> int: """ Parse number of threads parameter. """ + # TODO: Memory increases ~linearly with number of threads, but does processing speed? The bottleneck is often + # file IO & I am not sure >2 threads as a default is justified. _cpu_count = cpu_count() threads = _cpu_count if threads == 0 else threads if threads > _cpu_count: diff --git a/homonim/version.py b/homonim/version.py index abeeedb..f0ede3d 100644 --- a/homonim/version.py +++ b/homonim/version.py @@ -1 +1 @@ -__version__ = '0.4.0' +__version__ = '0.4.1' diff --git a/tests/test_matched_pair.py b/tests/test_matched_pair.py index 05bc740..83cf6ae 100644 --- a/tests/test_matched_pair.py +++ b/tests/test_matched_pair.py @@ -18,16 +18,16 @@ """ from pathlib import Path -from typing import Tuple, List, Dict +from typing import Tuple, List +import warnings import numpy as np import pytest import rasterio as rio -from rasterio import Affine -from rasterio.windows import Window from homonim import utils from homonim.matched_pair import MatchedPairReader +from homonim.errors import HomonimWarning @pytest.mark.parametrize(['file', 'bands', 'exp_bands', 'exp_band_names', 'exp_wavelengths'], [ @@ -104,9 +104,15 @@ def test_match( """ Test matching of different source and reference files. """ src_file: Path = request.getfixturevalue(src_file) ref_file: Path = request.getfixturevalue(ref_file) - with MatchedPairReader(src_file, ref_file, src_bands=src_bands, ref_bands=ref_bands, force=force) as matched_pair: - assert all(np.array(matched_pair.src_bands) == exp_src_bands) - assert all(np.array(matched_pair.ref_bands) == exp_ref_bands) + + with warnings.catch_warnings(): + # test there are no all-nan warnings by turning them RuntimeWarning into an error, while allowing + # HomonimWarning which sub-classes RuntimeWarning + warnings.simplefilter("error", category=RuntimeWarning) + warnings.simplefilter("default", category=HomonimWarning) + with MatchedPairReader(src_file, ref_file, src_bands=src_bands, ref_bands=ref_bands, force=force) as matched_pair: + assert all(np.array(matched_pair.src_bands) == exp_src_bands) + assert all(np.array(matched_pair.ref_bands) == exp_ref_bands) def test_match_fewer_ref_bands_error(s2_ref_file, landsat_ref_file):