Skip to content

Commit

Permalink
Merge pull request #4 from leftfield-geospatial/fix_warnings
Browse files Browse the repository at this point in the history
Fix warnings
  • Loading branch information
dugalh authored May 30, 2024
2 parents 92c204d + 4ef01b3 commit 29b03d1
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/install-test-conda-forge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- name: Install package
run: |
conda info
conda install homonim>=0.4.0
conda install homonim>=0.4.1
conda list
- name: Run CLI fusion test
Expand Down
5 changes: 0 additions & 5 deletions homonim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,13 @@
import os
import pathlib
import logging
import warnings

from rasterio.errors import NotGeoreferencedWarning
from homonim.compare import RasterCompare
from homonim.enums import Model, ProcCrs
from homonim.fuse import RasterFuse
from homonim.kernel_model import KernelModel
from homonim.stats import ParamStats

# suppress NotGeoreferencedWarning which rasterio can raise incorrectly
warnings.simplefilter('ignore', category=NotGeoreferencedWarning)

# Add a NullHandler to the package logger to hide logs by default. Applications can then add
# their own handler(s).
log = logging.getLogger(__name__)
Expand Down
4 changes: 4 additions & 0 deletions homonim/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import yaml
from click.core import ParameterSource
from rasterio.warp import SUPPORTED_RESAMPLING
from rasterio.errors import NotGeoreferencedWarning

from homonim import utils, version, RasterFuse, RasterCompare, ParamStats, ProcCrs, Model
from homonim.errors import ImageFormatError
Expand Down Expand Up @@ -145,6 +146,9 @@ def showwarning(message, category, filename, lineno, file=None, line=None):
logger = logging.getLogger(module_name)
logger.warning(str(message))

# suppress NotGeoreferencedWarning which rasterio can raise incorrectly
warnings.simplefilter('ignore', category=NotGeoreferencedWarning)

# redirect orthority warnings to module logger
orig_show_warning = warnings.showwarning
warnings.showwarning = showwarning
Expand Down
2 changes: 1 addition & 1 deletion homonim/kernel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _r2_array(
# The above can be expanded and expressed in terms of cv.boxFilter kernel sums as:
ss_res_array = (
((param_array[0] ** 2) * src2_sum) +
(2 * np.product(param_array[:2], axis=0) * src_sum) -
(2 * np.prod(param_array[:2], axis=0) * src_sum) -
(2 * param_array[0] * src_ref_sum) -
(2 * param_array[1] * ref_sum) +
ref2_sum + (mask_sum * (param_array[1] ** 2))
Expand Down
38 changes: 19 additions & 19 deletions homonim/matched_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def _match_pair_bands(
# absolute & relative distance matrix between src and ref center wavelengths
abs_dist = np.abs(src_wavelengths[:, np.newaxis] - ref_wavelengths[np.newaxis, :])
rel_dist = abs_dist / src_wavelengths[:, np.newaxis]

def greedy_match(dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Greedy matching of src to ref bands based on the provided center wavelength distance matrix,
Expand All @@ -257,30 +258,29 @@ def greedy_match(dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
# match_idx[i] is the index of the ref band that matches with the ith src band
match_idx = np.array([np.nan] * dist.shape[0])
match_dist = np.array([np.nan] * dist.shape[0]) # distances corresponding to the above matches

# suppress runtime warning on all-Nan slice, as it is expected in normal operation
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
# repeat until all src or ref bands have been matched
while not all(np.isnan(np.nanmin(dist, axis=1))) or not all(np.isnan(np.nanmin(dist, axis=0))):
# find the row with the smallest distance in it
min_dist = np.nanmin(dist, axis=1)
min_dist_row_idx = np.nanargmin(min_dist)
min_dist_row = dist[min_dist_row_idx, :]
# store match idx and distance for this row
match_idx[min_dist_row_idx] = np.nanargmin(min_dist_row)
match_dist[min_dist_row_idx] = min_dist[min_dist_row_idx]
# set the matched row and col to nan, so that it is not used again
dist[:, int(match_idx[min_dist_row_idx])] = np.nan
dist[min_dist_row_idx, :] = np.nan
match_dist = np.array([np.nan] * dist.shape[0]) # distances corresponding to the above matches

# use masked array rather than nan pass-through to avoid all-nan slice warnings
dist = np.ma.array(dist, mask=np.isnan(dist))

# repeat until all src or ref bands have been matched
while not dist.mask.all():
# find the row with the smallest distance in it
min_dist = dist.min(axis=1)
min_dist_row_idx = np.ma.argmin(min_dist)
min_dist_row = dist[min_dist_row_idx, :]
# store match idx and distance for this row
match_idx[min_dist_row_idx] = np.ma.argmin(min_dist_row)
match_dist[min_dist_row_idx] = min_dist[min_dist_row_idx]
# set the matched row and col to nan, so that it is not used again
dist[:, int(match_idx[min_dist_row_idx])] = np.ma.masked
dist[min_dist_row_idx, :] = np.ma.masked

return match_dist, match_idx

match_dist, match_idx = greedy_match(rel_dist)

# if any of the matched distances are greater than a threshold, raise an informative error,
# or log a warning, depending on `self._force`
# if any of the matched distances are greater than a threshold, raise an informative error
if any(match_dist > MatchedPairReader._max_rel_wavelength_diff):
err_idx = match_dist > MatchedPairReader._max_rel_wavelength_diff
src_err_band_names = list(src_band_names[err_idx])
Expand Down
7 changes: 0 additions & 7 deletions homonim/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,6 @@ def to_rio_dataset(
f'The length of indexes ({len(indexes)}) exceeds the number of bands in the '
f'RasterArray ({self.count})'
)
if rio_dataset.nodata is not None and (
self.nodata is None or not utils.nan_equals(self.nodata, rio_dataset.nodata)
):
warnings.warn(
f"The dataset nodata: {rio_dataset.nodata} does not match the RasterArray nodata: {self.nodata}",
category=ImageFormatWarning
)

if window is None:
# a window defining the region in the dataset corresponding to the RasterArray extents
Expand Down
6 changes: 3 additions & 3 deletions homonim/raster_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def _auto_block_shape(self, max_block_mem: float = np.inf) -> Tuple[int, int]:
proc_win = self._ref_win if self.proc_crs == ProcCrs.ref else self._src_win
# adjust max_block_mem to represent the size of a block in the highest resolution image, but scaled to the
# equivalent in proc_crs.
src_pix_area = np.product(np.abs(self._src_im.res))
ref_pix_area = np.product(np.abs(self._ref_im.res))
src_pix_area = np.prod(np.abs(self._src_im.res))
ref_pix_area = np.prod(np.abs(self._ref_im.res))
if self.proc_crs == ProcCrs.ref:
mem_scale = src_pix_area / ref_pix_area if ref_pix_area > src_pix_area else 1.
elif self.proc_crs == ProcCrs.src:
Expand All @@ -247,7 +247,7 @@ def _auto_block_shape(self, max_block_mem: float = np.inf) -> Tuple[int, int]:
block_shape = np.array((proc_win.height, proc_win.width)).astype('float')

# keep halving the block_shape along the longest dimension until it satisfies max_block_mem
while (np.product(block_shape) * dtype_size) > max_block_mem:
while (np.prod(block_shape) * dtype_size) > max_block_mem:
div_dim = np.argmax(block_shape)
block_shape[div_dim] /= 2

Expand Down
6 changes: 4 additions & 2 deletions homonim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def validate_kernel_shape(kernel_shape: Tuple[int, int], model: Model = Model.ga
if not np.all(np.mod(kernel_shape, 2) == 1):
raise ValueError('`kernel_shape` must be odd in both dimensions.')
if model == Model.gain_offset:
if np.product(kernel_shape) < 2:
if np.prod(kernel_shape) < 2:
raise ValueError('`kernel_shape` area should contain at least 2 elements for the gain-offset model.')
elif np.product(kernel_shape) < 25:
elif np.prod(kernel_shape) < 25:
warnings.warn(
'A `kernel_shape` of at least 25 elements is recommended for the gain-offset model.',
category=ConfigWarning
Expand Down Expand Up @@ -155,6 +155,8 @@ def overlap_for_kernel(kernel_shape: Tuple[int, int]) -> Tuple[int, int]:

def validate_threads(threads: int) -> int:
""" Parse number of threads parameter. """
# TODO: Memory increases ~linearly with number of threads, but does processing speed? The bottleneck is often
# file IO & I am not sure >2 threads as a default is justified.
_cpu_count = cpu_count()
threads = _cpu_count if threads == 0 else threads
if threads > _cpu_count:
Expand Down
2 changes: 1 addition & 1 deletion homonim/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.0'
__version__ = '0.4.1'
18 changes: 12 additions & 6 deletions tests/test_matched_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
"""

from pathlib import Path
from typing import Tuple, List, Dict
from typing import Tuple, List
import warnings

import numpy as np
import pytest
import rasterio as rio
from rasterio import Affine
from rasterio.windows import Window

from homonim import utils
from homonim.matched_pair import MatchedPairReader
from homonim.errors import HomonimWarning


@pytest.mark.parametrize(['file', 'bands', 'exp_bands', 'exp_band_names', 'exp_wavelengths'], [
Expand Down Expand Up @@ -104,9 +104,15 @@ def test_match(
""" Test matching of different source and reference files. """
src_file: Path = request.getfixturevalue(src_file)
ref_file: Path = request.getfixturevalue(ref_file)
with MatchedPairReader(src_file, ref_file, src_bands=src_bands, ref_bands=ref_bands, force=force) as matched_pair:
assert all(np.array(matched_pair.src_bands) == exp_src_bands)
assert all(np.array(matched_pair.ref_bands) == exp_ref_bands)

with warnings.catch_warnings():
# test there are no all-nan warnings by turning them RuntimeWarning into an error, while allowing
# HomonimWarning which sub-classes RuntimeWarning
warnings.simplefilter("error", category=RuntimeWarning)
warnings.simplefilter("default", category=HomonimWarning)
with MatchedPairReader(src_file, ref_file, src_bands=src_bands, ref_bands=ref_bands, force=force) as matched_pair:
assert all(np.array(matched_pair.src_bands) == exp_src_bands)
assert all(np.array(matched_pair.ref_bands) == exp_ref_bands)


def test_match_fewer_ref_bands_error(s2_ref_file, landsat_ref_file):
Expand Down

0 comments on commit 29b03d1

Please sign in to comment.