diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3999be8b04..3ce81859ae 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -28,7 +28,7 @@ repos:
           - types-pkg-resources
           - types-PyYAML
           - types-requests
-        args: ["--python-version", "3.8", "--ignore-missing-imports"]
+        args: ["--python-version", "3.9", "--ignore-missing-imports"]
   - repo: https://github.com/pycqa/isort
     rev: 5.12.0
     hooks:
diff --git a/satpy/_compat.py b/satpy/_compat.py
index b49b5a961b..aad2009db3 100644
--- a/satpy/_compat.py
+++ b/satpy/_compat.py
@@ -17,70 +17,7 @@
 # satpy. If not, see <http://www.gnu.org/licenses/>.
 """Backports and compatibility fixes for satpy."""
 
-from threading import RLock
-
-_NOT_FOUND = object()
-
-
-class CachedPropertyBackport:
-    """Backport of cached_property from Python-3.8.
-
-    Source: https://github.com/python/cpython/blob/v3.8.0/Lib/functools.py#L930
-    """
-
-    def __init__(self, func):  # noqa
-        self.func = func
-        self.attrname = None
-        self.__doc__ = func.__doc__
-        self.lock = RLock()
-
-    def __set_name__(self, owner, name):  # noqa
-        if self.attrname is None:
-            self.attrname = name
-        elif name != self.attrname:
-            raise TypeError(
-                "Cannot assign the same cached_property to two different names "
-                f"({self.attrname!r} and {name!r})."
-            )
-
-    def __get__(self, instance, owner=None):  # noqa
-        if instance is None:
-            return self
-        if self.attrname is None:
-            raise TypeError(
-                "Cannot use cached_property instance without calling __set_name__ on it.")
-        try:
-            cache = instance.__dict__  # noqa
-        except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
-            msg = (
-                f"No '__dict__' attribute on {type(instance).__name__!r} "
-                f"instance to cache {self.attrname!r} property."
-            )
-            raise TypeError(msg) from None
-        val = cache.get(self.attrname, _NOT_FOUND)
-        if val is _NOT_FOUND:
-            with self.lock:
-                # check if another thread filled cache while we awaited lock
-                val = cache.get(self.attrname, _NOT_FOUND)
-                if val is _NOT_FOUND:
-                    val = self.func(instance)
-                    try:
-                        cache[self.attrname] = val
-                    except TypeError:
-                        msg = (
-                            f"The '__dict__' attribute on {type(instance).__name__!r} instance "
-                            f"does not support item assignment for caching {self.attrname!r} property."
-                        )
-                        raise TypeError(msg) from None
-        return val
-
-
-try:
-    from functools import cached_property  # type: ignore
-except ImportError:
-    # for python < 3.8
-    cached_property = CachedPropertyBackport  # type: ignore
-
+from functools import cache, cached_property  # noqa
 
 try:
     from numpy.typing import ArrayLike, DTypeLike  # noqa
@@ -88,9 +25,3 @@ def __get__(self, instance, owner=None):  # noqa
     # numpy <1.20
     from numpy import dtype as DTypeLike  # noqa
     from numpy import ndarray as ArrayLike  # noqa
-
-
-try:
-    from functools import cache  # type: ignore
-except ImportError:
-    from functools import lru_cache as cache  # noqa
diff --git a/satpy/_config.py b/satpy/_config.py
index 4abc00aba2..7a0d7aaac3 100644
--- a/satpy/_config.py
+++ b/satpy/_config.py
@@ -26,20 +26,9 @@
 import tempfile
 from collections import OrderedDict
 from importlib.metadata import EntryPoint, entry_points
-from pathlib import Path
+from importlib.resources import files as impr_files
 from typing import Iterable
 
-try:
-    from importlib.resources import files as impr_files  # type: ignore
-except ImportError:
-    # Python 3.8
-    def impr_files(module_name: str) -> Path:
-        """Get path to module as a backport for Python 3.8."""
-        from importlib.resources import path as impr_path
-
-        with impr_path(module_name, "__init__.py") as pkg_init_path:
-            return pkg_init_path.parent
-
 import appdirs
 from donfig import Config
diff --git a/satpy/etc/readers/modis_l1b.yaml b/satpy/etc/readers/modis_l1b.yaml
index d39ffbb99c..17bdf134bf 100644
--- a/satpy/etc/readers/modis_l1b.yaml
+++ b/satpy/etc/readers/modis_l1b.yaml
@@ -9,15 +9,6 @@ reader:
     sensors: [modis]
     default_datasets: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36]
 
-navigations:
-  hdf_eos_geo:
-    description: MODIS navigation
-    file_type: hdf_eos_geo
-    latitude_key: Latitude
-    longitude_key: Longitude
-    nadir_resolution: [1000]
-    rows_per_scan: 10
-
 datasets:
   '1':
     name: '1'
diff --git a/satpy/readers/hdfeos_base.py b/satpy/readers/hdfeos_base.py
index f776256e89..56b15b626d 100644
--- a/satpy/readers/hdfeos_base.py
+++ b/satpy/readers/hdfeos_base.py
@@ -32,10 +32,9 @@
 
 from satpy import DataID
 from satpy.readers.file_handlers import BaseFileHandler
-from satpy.utils import get_legacy_chunk_size
+from satpy.utils import normalize_low_res_chunks
 
 logger = logging.getLogger(__name__)
-CHUNK_SIZE = get_legacy_chunk_size()
 
 
 def interpolate(clons, clats, csatz, src_resolution, dst_resolution):
@@ -215,7 +214,8 @@ def load_dataset(self, dataset_name, is_category=False):
         from satpy.readers.hdf4_utils import from_sds
 
         dataset = self._read_dataset_in_file(dataset_name)
-        dask_arr = from_sds(dataset, chunks=CHUNK_SIZE)
+        chunks = self._chunks_for_variable(dataset)
+        dask_arr = from_sds(dataset, chunks=chunks)
         dims = ('y', 'x') if dask_arr.ndim == 2 else None
         data = xr.DataArray(dask_arr, dims=dims,
                             attrs=dataset.attributes())
@@ -223,6 +223,32 @@
 
         return data
 
+    def _chunks_for_variable(self, hdf_dataset):
+        scan_length_250m = 40
+        var_shape = hdf_dataset.info()[2]
+        res_multiplier = self._get_res_multiplier(var_shape)
+        num_nonyx_dims = len(var_shape) - 2
+        return normalize_low_res_chunks(
+            (1,) * num_nonyx_dims + ("auto", -1),
+            var_shape,
+            (1,) * num_nonyx_dims + (scan_length_250m, -1),
+            (1,) * num_nonyx_dims + (res_multiplier, res_multiplier),
+            np.float32,
+        )
+
+    @staticmethod
+    def _get_res_multiplier(var_shape):
+        num_columns_to_multiplier = {
+            271: 20,  # 5km
+            1354: 4,  # 1km
+            2708: 2,  # 500m
+            5416: 1,  # 250m
+        }
+        for max_columns, res_multiplier in num_columns_to_multiplier.items():
+            if var_shape[-1] <= max_columns:
+                return res_multiplier
+        return 1
+
     def _scale_and_mask_data_array(self, data, is_category=False):
         """Unscale byte data and mask invalid/fill values.
diff --git a/satpy/readers/modis_l1b.py b/satpy/readers/modis_l1b.py
index 5f0627b95d..1d0e209d57 100644
--- a/satpy/readers/modis_l1b.py
+++ b/satpy/readers/modis_l1b.py
@@ -78,10 +78,8 @@
 
 from satpy.readers.hdf4_utils import from_sds
 from satpy.readers.hdfeos_base import HDFEOSBaseFileReader, HDFEOSGeoReader
-from satpy.utils import get_legacy_chunk_size
 
 logger = logging.getLogger(__name__)
-CHUNK_SIZE = get_legacy_chunk_size()
 
 
 class HDFEOSBandReader(HDFEOSBaseFileReader):
@@ -118,7 +116,8 @@ def get_dataset(self, key, info):
         subdata = self.sd.select(var_name)
         var_attrs = subdata.attributes()
         uncertainty = self.sd.select(var_name + "_Uncert_Indexes")
-        array = xr.DataArray(from_sds(subdata, chunks=CHUNK_SIZE)[band_index, :, :],
+        chunks = self._chunks_for_variable(subdata)
+        array = xr.DataArray(from_sds(subdata, chunks=chunks)[band_index, :, :],
                              dims=['y', 'x']).astype(np.float32)
         valid_range = var_attrs['valid_range']
         valid_min = np.float32(valid_range[0])
@@ -214,7 +213,8 @@ def _mask_invalid(self, array, valid_min, valid_max):
     def _mask_uncertain_pixels(self, array, uncertainty, band_index):
         if not self._mask_saturated:
             return array
-        band_uncertainty = from_sds(uncertainty, chunks=CHUNK_SIZE)[band_index, :, :]
+        uncertainty_chunks = self._chunks_for_variable(uncertainty)
+        band_uncertainty = from_sds(uncertainty, chunks=uncertainty_chunks)[band_index, :, :]
         array = array.where(band_uncertainty < 15)
         return array
diff --git a/satpy/resample.py b/satpy/resample.py
index b124c84933..289371d8cb 100644
--- a/satpy/resample.py
+++ b/satpy/resample.py
@@ -143,6 +143,7 @@
 import os
 import warnings
 from logging import getLogger
+from math import lcm  # type: ignore
 from weakref import WeakValueDictionary
 
 import dask
@@ -157,14 +158,6 @@
 
 from satpy.utils import PerformanceWarning, get_legacy_chunk_size
 
-try:
-    from math import lcm  # type: ignore
-except ImportError:
-    def lcm(a, b):
-        """Get 'Least Common Multiple' with Python 3.8 compatibility."""
-        from math import gcd
-        return abs(a * b) // gcd(a, b)
-
 try:
     from pyresample.resampler import BaseResampler as PRBaseResampler
 except ImportError:
diff --git a/satpy/tests/reader_tests/modis_tests/_modis_fixtures.py b/satpy/tests/reader_tests/modis_tests/_modis_fixtures.py
index dfc8f0aec6..49331f5421 100644
--- a/satpy/tests/reader_tests/modis_tests/_modis_fixtures.py
+++ b/satpy/tests/reader_tests/modis_tests/_modis_fixtures.py
@@ -62,7 +62,7 @@ def _shape_for_resolution(resolution: int) -> tuple[int, int]:
     return factor * shape_1km[0], factor * shape_1km[1]
 
 
-def _generate_lonlat_data(resolution: int) -> np.ndarray:
+def _generate_lonlat_data(resolution: int) -> tuple[np.ndarray, np.ndarray]:
     shape = _shape_for_resolution(resolution)
     lat = np.repeat(np.linspace(35., 45., shape[0])[:, None], shape[1], 1)
     lat *= np.linspace(0.9, 1.1, shape[1])
diff --git a/satpy/tests/reader_tests/modis_tests/test_modis_l1b.py b/satpy/tests/reader_tests/modis_tests/test_modis_l1b.py
index 56e8687844..53f0ca46ce 100644
--- a/satpy/tests/reader_tests/modis_tests/test_modis_l1b.py
+++ b/satpy/tests/reader_tests/modis_tests/test_modis_l1b.py
@@ -51,6 +51,18 @@ def _check_shared_metadata(data_arr):
     assert "rows_per_scan" in data_arr.attrs
     assert isinstance(data_arr.attrs["rows_per_scan"], int)
     assert data_arr.attrs['reader'] == 'modis_l1b'
+    assert "resolution" in data_arr.attrs
+    res = data_arr.attrs["resolution"]
+    if res == 5000:
+        assert data_arr.chunks == ((2, 2, 2), (data_arr.shape[1],))
+    elif res == 1000:
+        assert data_arr.chunks == ((10, 10, 10), (data_arr.shape[1],))
+    elif res == 500:
+        assert data_arr.chunks == ((20, 20, 20), (data_arr.shape[1],))
+    elif res == 250:
+        assert data_arr.chunks == ((40, 40, 40), (data_arr.shape[1],))
+    else:
+        raise ValueError(f"Unexpected resolution: {res}")
 
 
 def _load_and_check_geolocation(scene, resolution, exp_res, exp_shape, has_res,
@@ -147,7 +159,8 @@ def test_load_longitude_latitude(self, input_files, has_5km, has_500, has_250, d
         shape_500m = _shape_for_resolution(500)
         shape_250m = _shape_for_resolution(250)
         default_shape = _shape_for_resolution(default_res)
-        with dask.config.set(scheduler=CustomScheduler(max_computes=1 + has_5km + has_500 + has_250)):
+        scheduler = CustomScheduler(max_computes=1 + has_5km + has_500 + has_250)
+        with dask.config.set({'scheduler': scheduler, 'array.chunk-size': '1 MiB'}):
             _load_and_check_geolocation(scene, "*", default_res, default_shape, True)
             _load_and_check_geolocation(scene, 5000, 5000, shape_5km, has_5km)
             _load_and_check_geolocation(scene, 500, 500, shape_500m, has_500)
@@ -157,7 +170,8 @@ def test_load_sat_zenith_angle(self, modis_l1b_nasa_mod021km_file):
         """Test loading satellite zenith angle band."""
         scene = Scene(reader='modis_l1b', filenames=modis_l1b_nasa_mod021km_file)
         dataset_name = 'satellite_zenith_angle'
-        scene.load([dataset_name])
+        with dask.config.set({'array.chunk-size': '1 MiB'}):
+            scene.load([dataset_name])
         dataset = scene[dataset_name]
         assert dataset.shape == _shape_for_resolution(1000)
         assert dataset.attrs['resolution'] == 1000
@@ -167,7 +181,8 @@ def test_load_vis(self, modis_l1b_nasa_mod021km_file):
         """Test loading visible band."""
         scene = Scene(reader='modis_l1b', filenames=modis_l1b_nasa_mod021km_file)
         dataset_name = '1'
-        scene.load([dataset_name])
+        with dask.config.set({'array.chunk-size': '1 MiB'}):
+            scene.load([dataset_name])
         dataset = scene[dataset_name]
         assert dataset[0, 0] == 300.0
         assert dataset.shape == _shape_for_resolution(1000)
@@ -180,7 +195,8 @@ def test_load_vis_saturation(self, mask_saturated, modis_l1b_nasa_mod021km_file)
         scene = Scene(reader='modis_l1b', filenames=modis_l1b_nasa_mod021km_file,
                       reader_kwargs={"mask_saturated": mask_saturated})
         dataset_name = '2'
-        scene.load([dataset_name])
+        with dask.config.set({'array.chunk-size': '1 MiB'}):
+            scene.load([dataset_name])
         dataset = scene[dataset_name]
         assert dataset.shape == _shape_for_resolution(1000)
         assert dataset.attrs['resolution'] == 1000
diff --git a/satpy/tests/test_compat.py b/satpy/tests/test_compat.py
deleted file mode 100644
index f084f88e53..0000000000
--- a/satpy/tests/test_compat.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright (c) 2022 Satpy developers
-#
-# This file is part of satpy.
-#
-# satpy is free software: you can redistribute it and/or modify it under the
-# terms of the GNU General Public License as published by the Free Software
-# Foundation, either version 3 of the License, or (at your option) any later
-# version.
-#
-# satpy is distributed in the hope that it will be useful, but WITHOUT ANY
-# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
-# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License along with
-# satpy. If not, see <http://www.gnu.org/licenses/>.
-"""Test backports and compatibility fixes."""
-
-import gc
-
-from satpy._compat import CachedPropertyBackport
-
-
-class ClassWithCachedProperty:  # noqa
-    def __init__(self, x):  # noqa
-        self.x = x
-
-    @CachedPropertyBackport
-    def property(self):  # noqa
-        return 2 * self.x
-
-
-def test_cached_property_backport():
-    """Test cached property backport."""
-    c = ClassWithCachedProperty(1)
-    assert c.property == 2
-
-
-def test_cached_property_backport_releases_memory():
-    """Test that cached property backport releases memory."""
-    c1 = ClassWithCachedProperty(2)
-    del c1
-    instances = [
-        obj for obj in gc.get_objects()
-        if isinstance(obj, ClassWithCachedProperty)
-    ]
-    assert len(instances) == 0
diff --git a/satpy/tests/test_utils.py b/satpy/tests/test_utils.py
index 56dbe25324..ef6a359cdd 100644
--- a/satpy/tests/test_utils.py
+++ b/satpy/tests/test_utils.py
@@ -21,6 +21,7 @@
 import typing
 import unittest
 import warnings
+from math import sqrt
 from unittest import mock
 
 import dask.array as da
@@ -44,182 +45,100 @@
 # - caplog
 
 
-class TestUtils(unittest.TestCase):
-    """Testing utils."""
+class TestGeoUtils:
+    """Testing geo-related utility functions."""
 
-    def test_lonlat2xyz(self):
-        """Test the lonlat2xyz function."""
-        x__, y__, z__ = lonlat2xyz(0, 0)
-        self.assertAlmostEqual(x__, 1)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = lonlat2xyz(90, 0)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 1)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = lonlat2xyz(0, 90)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 1)
-
-        x__, y__, z__ = lonlat2xyz(180, 0)
-        self.assertAlmostEqual(x__, -1)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = lonlat2xyz(-90, 0)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, -1)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = lonlat2xyz(0, -90)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, -1)
-
-        x__, y__, z__ = lonlat2xyz(0, 45)
-        self.assertAlmostEqual(x__, np.sqrt(2) / 2)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, np.sqrt(2) / 2)
-
-        x__, y__, z__ = lonlat2xyz(0, 60)
-        self.assertAlmostEqual(x__, np.sqrt(1) / 2)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, np.sqrt(3) / 2)
-
-    def test_angle2xyz(self):
+    @pytest.mark.parametrize(
+        ("lonlat", "xyz"),
+        [
+            ((0, 0), (1, 0, 0)),
+            ((90, 0), (0, 1, 0)),
+            ((0, 90), (0, 0, 1)),
+            ((180, 0), (-1, 0, 0)),
+            ((-90, 0), (0, -1, 0)),
+            ((0, -90), (0, 0, -1)),
+            ((0, 45), (sqrt(2) / 2, 0, sqrt(2) / 2)),
+            ((0, 60), (sqrt(1) / 2, 0, sqrt(3) / 2)),
+        ],
+    )
+    def test_lonlat2xyz(self, lonlat, xyz):
         """Test the lonlat2xyz function."""
-        x__, y__, z__ = angle2xyz(0, 0)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 1)
-
-        x__, y__, z__ = angle2xyz(90, 0)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 1)
-
-        x__, y__, z__ = angle2xyz(0, 90)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 1)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = angle2xyz(180, 0)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 1)
-
-        x__, y__, z__ = angle2xyz(-90, 0)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 1)
-
-        x__, y__, z__ = angle2xyz(0, -90)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, -1)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = angle2xyz(90, 90)
-        self.assertAlmostEqual(x__, 1)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = angle2xyz(-90, 90)
-        self.assertAlmostEqual(x__, -1)
-        self.assertAlmostEqual(y__, 0)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = angle2xyz(180, 90)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, -1)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = angle2xyz(0, -90)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, -1)
-        self.assertAlmostEqual(z__, 0)
-
-        x__, y__, z__ = angle2xyz(0, 45)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, np.sqrt(2) / 2)
-        self.assertAlmostEqual(z__, np.sqrt(2) / 2)
-
-        x__, y__, z__ = angle2xyz(0, 60)
-        self.assertAlmostEqual(x__, 0)
-        self.assertAlmostEqual(y__, np.sqrt(3) / 2)
-        self.assertAlmostEqual(z__, np.sqrt(1) / 2)
-
-    def test_xyz2lonlat(self):
-        """Test xyz2lonlat."""
-        lon, lat = xyz2lonlat(1, 0, 0)
-        self.assertAlmostEqual(lon, 0)
-        self.assertAlmostEqual(lat, 0)
-
-        lon, lat = xyz2lonlat(0, 1, 0)
-        self.assertAlmostEqual(lon, 90)
-        self.assertAlmostEqual(lat, 0)
+        x__, y__, z__ = lonlat2xyz(*lonlat)
+        assert x__ == pytest.approx(xyz[0])
+        assert y__ == pytest.approx(xyz[1])
+        assert z__ == pytest.approx(xyz[2])
 
-        lon, lat = xyz2lonlat(0, 0, 1, asin=True)
-        self.assertAlmostEqual(lon, 0)
-        self.assertAlmostEqual(lat, 90)
-
-        lon, lat = xyz2lonlat(0, 0, 1)
-        self.assertAlmostEqual(lon, 0)
-        self.assertAlmostEqual(lat, 90)
+    @pytest.mark.parametrize(
+        ("azizen", "xyz"),
+        [
+            ((0, 0), (0, 0, 1)),
+            ((90, 0), (0, 0, 1)),
+            ((0, 90), (0, 1, 0)),
+            ((180, 0), (0, 0, 1)),
+            ((-90, 0), (0, 0, 1)),
+            ((0, -90), (0, -1, 0)),
+            ((90, 90), (1, 0, 0)),
+            ((-90, 90), (-1, 0, 0)),
+            ((180, 90), (0, -1, 0)),
+            ((0, -90), (0, -1, 0)),
+            ((0, 45), (0, sqrt(2) / 2, sqrt(2) / 2)),
+            ((0, 60), (0, sqrt(3) / 2, sqrt(1) / 2)),
+        ],
+    )
+    def test_angle2xyz(self, azizen, xyz):
+        """Test the angle2xyz function."""
+        x__, y__, z__ = angle2xyz(*azizen)
+        assert x__ == pytest.approx(xyz[0])
+        assert y__ == pytest.approx(xyz[1])
+        assert z__ == pytest.approx(xyz[2])
 
-        lon, lat = xyz2lonlat(np.sqrt(2) / 2, np.sqrt(2) / 2, 0)
-        self.assertAlmostEqual(lon, 45)
-        self.assertAlmostEqual(lat, 0)
+    @pytest.mark.parametrize(
+        ("xyz", "asin", "lonlat"),
+        [
+            ((1, 0, 0), False, (0, 0)),
+            ((0, 1, 0), False, (90, 0)),
+            ((0, 0, 1), True, (0, 90)),
+            ((0, 0, 1), False, (0, 90)),
+            ((sqrt(2) / 2, sqrt(2) / 2, 0), False, (45, 0)),
+        ],
+    )
+    def test_xyz2lonlat(self, xyz, asin, lonlat):
+        """Test xyz2lonlat."""
+        lon, lat = xyz2lonlat(*xyz, asin=asin)
+        assert lon == pytest.approx(lonlat[0])
+        assert lat == pytest.approx(lonlat[1])
 
-    def test_xyz2angle(self):
+    @pytest.mark.parametrize(
+        ("xyz", "acos", "azizen"),
+        [
+            ((1, 0, 0), False, (90, 90)),
+            ((0, 1, 0), False, (0, 90)),
+            ((0, 0, 1), False, (0, 0)),
+            ((0, 0, 1), True, (0, 0)),
+            ((sqrt(2) / 2, sqrt(2) / 2, 0), False, (45, 90)),
+            ((-1, 0, 0), False, (-90, 90)),
+            ((0, -1, 0), False, (180, 90)),
+        ],
+    )
+    def test_xyz2angle(self, xyz, acos, azizen):
         """Test xyz2angle."""
-        azi, zen = xyz2angle(1, 0, 0)
-        self.assertAlmostEqual(azi, 90)
-        self.assertAlmostEqual(zen, 90)
-
-        azi, zen = xyz2angle(0, 1, 0)
-        self.assertAlmostEqual(azi, 0)
-        self.assertAlmostEqual(zen, 90)
-
-        azi, zen = xyz2angle(0, 0, 1)
-        self.assertAlmostEqual(azi, 0)
-        self.assertAlmostEqual(zen, 0)
-
-        azi, zen = xyz2angle(0, 0, 1, acos=True)
-        self.assertAlmostEqual(azi, 0)
-        self.assertAlmostEqual(zen, 0)
-
-        azi, zen = xyz2angle(np.sqrt(2) / 2, np.sqrt(2) / 2, 0)
-        self.assertAlmostEqual(azi, 45)
-        self.assertAlmostEqual(zen, 90)
+        azi, zen = xyz2angle(*xyz, acos=acos)
+        assert azi == pytest.approx(azizen[0])
+        assert zen == pytest.approx(azizen[1])
 
-        azi, zen = xyz2angle(-1, 0, 0)
-        self.assertAlmostEqual(azi, -90)
-        self.assertAlmostEqual(zen, 90)
-
-        azi, zen = xyz2angle(0, -1, 0)
-        self.assertAlmostEqual(azi, 180)
-        self.assertAlmostEqual(zen, 90)
-
-    def test_proj_units_to_meters(self):
+    @pytest.mark.parametrize(
+        ("prj", "exp_prj"),
+        [
+            ("+asd=123123123123", "+asd=123123123123"),
+            ("+a=6378.137", "+a=6378137.000"),
+            ("+a=6378.137 +units=km", "+a=6378137.000"),
+            ("+a=6378.137 +b=6378.137", "+a=6378137.000 +b=6378137.000"),
+            ("+a=6378.137 +b=6378.137 +h=35785.863", "+a=6378137.000 +b=6378137.000 +h=35785863.000"),
+        ],
+    )
+    def test_proj_units_to_meters(self, prj, exp_prj):
         """Test proj units to meters conversion."""
-        prj = '+asd=123123123123'
-        res = proj_units_to_meters(prj)
-        self.assertEqual(res, prj)
-        prj = '+a=6378.137'
-        res = proj_units_to_meters(prj)
-        self.assertEqual(res, '+a=6378137.000')
-        prj = '+a=6378.137 +units=km'
-        res = proj_units_to_meters(prj)
-        self.assertEqual(res, '+a=6378137.000')
-        prj = '+a=6378.137 +b=6378.137'
-        res = proj_units_to_meters(prj)
-        self.assertEqual(res, '+a=6378137.000 +b=6378137.000')
-        prj = '+a=6378.137 +b=6378.137 +h=35785.863'
-        res = proj_units_to_meters(prj)
-        self.assertEqual(res, '+a=6378137.000 +b=6378137.000 +h=35785863.000')
+        assert proj_units_to_meters(prj) == exp_prj
 
 
 class TestGetSatPos:
@@ -273,7 +192,7 @@ def test_get_satpos(self, included_prefixes, preference, expected_result):
         "attrs",
         (
             {},
-             {'orbital_parameters': {'projection_longitude': 1}},
+            {'orbital_parameters': {'projection_longitude': 1}},
             {'satellite_altitude': 1}
         )
     )
@@ -288,16 +207,17 @@ def test_get_satpos_from_satname(self, caplog):
         import pyorbital.tlefile
 
         data_arr = xr.DataArray(
-                (),
-                attrs={
-                    "platform_name": "Meteosat-42",
-                    "sensor": "irives",
-                    "start_time": datetime.datetime(2031, 11, 20, 19, 18, 17)})
+            (),
+            attrs={
+                "platform_name": "Meteosat-42",
+                "sensor": "irives",
+                "start_time": datetime.datetime(2031, 11, 20, 19, 18, 17)
+            })
         with mock.patch("pyorbital.tlefile.read") as plr:
             plr.return_value = pyorbital.tlefile.Tle(
-                    "Meteosat-42",
-                    line1="1 40732U 15034A   22011.84285506  .00000004  00000+0  00000+0 0  9995",
-                    line2="2 40732   0.2533 325.0106 0000976 118.8734 330.4058  1.00272123 23817")
+                "Meteosat-42",
+                line1="1 40732U 15034A   22011.84285506  .00000004  00000+0  00000+0 0  9995",
+                line2="2 40732   0.2533 325.0106 0000976 118.8734 330.4058  1.00272123 23817")
         with caplog.at_level(logging.WARNING):
             (lon, lat, alt) = get_satpos(data_arr, use_tle=True)
         assert "Orbital parameters missing from metadata" in caplog.text
@@ -319,13 +239,15 @@ def test_make_fake_scene():
     assert make_fake_scene({}).keys() == []
     sc = make_fake_scene({
-        "six": np.arange(25).reshape(5, 5)})
+        "six": np.arange(25).reshape(5, 5)
+    })
     assert len(sc.keys()) == 1
     assert sc.keys().pop()['name'] == "six"
     assert sc["six"].attrs["area"].shape == (5, 5)
     sc = make_fake_scene({
-        "seven": np.arange(3*7).reshape(3, 7),
-        "eight": np.arange(3*8).reshape(3, 8)},
+        "seven": np.arange(3 * 7).reshape(3, 7),
+        "eight": np.arange(3 * 8).reshape(3, 8)
+    },
         daskify=True,
         area=False,
         common_attrs={"repetency": "fourteen hundred per centimetre"})
@@ -335,9 +257,10 @@ def test_make_fake_scene():
     assert isinstance(sc["seven"].data, da.Array)
     sc = make_fake_scene({
         "nine": xr.DataArray(
-            np.arange(2*9).reshape(2, 9),
+            np.arange(2 * 9).reshape(2, 9),
             dims=("y", "x"),
-            attrs={"please": "preserve", "answer": 42})},
+            attrs={"please": "preserve", "answer": 42})
+    },
         common_attrs={"bad words": "semprini bahnhof veerooster winterbanden"})
     assert sc["nine"].attrs.keys() >= {"please", "answer", "bad words", "area"}
@@ -376,6 +299,7 @@ def depwarn():
             DeprecationWarning,
             stacklevel=2
         )
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
     debug_on(False)
     filts_before = warnings.filters.copy()
@@ -497,6 +421,53 @@ def test_get_legacy_chunk_size():
     assert get_legacy_chunk_size() == 2048
 
 
+@pytest.mark.parametrize(
+    ("chunks", "shape", "previous_chunks", "lr_mult", "chunk_dtype", "exp_result"),
+    [
+        # 1km swath
+        (("auto", -1), (1000, 3200), (40, 40), (4, 4), np.float32, (160, -1)),
+        # 5km swath
+        (("auto", -1), (1000 // 5, 3200 // 5), (40, 40), (20, 20), np.float32, (160 / 5, -1)),
+        # 250m swath
+        (("auto", -1), (1000 * 4, 3200 * 4), (40, 40), (1, 1), np.float32, (160 * 4, -1)),
+        # 1km area (ABI chunk 226):
+        (("auto", "auto"), (21696 // 2, 21696 // 2), (226*4, 226*4), (2, 2), np.float32, (1356, 1356)),
+        # 1km area (64-bit)
+        (("auto", "auto"), (21696 // 2, 21696 // 2), (226*4, 226*4), (2, 2), np.float64, (904, 904)),
+        # 3km area
+        (("auto", "auto"), (21696 // 3, 21696 // 3), (226*4, 226*4), (6, 6), np.float32, (452, 452)),
+        # 500m area
+        (("auto", "auto"), (21696, 21696), (226*4, 226*4), (1, 1), np.float32, (1356 * 2, 1356 * 2)),
+        # 500m area (64-bit)
+        (("auto", "auto"), (21696, 21696), (226*4, 226*4), (1, 1), np.float64, (904 * 2, 904 * 2)),
+        # 250m swath with bands:
+        ((1, "auto", -1), (7, 1000 * 4, 3200 * 4), (1, 40, 40), (1, 1, 1), np.float32, (1, 160 * 4, -1)),
+        # lots of dimensions:
+        ((1, 1, "auto", -1), (1, 7, 1000, 3200), (1, 1, 40, 40), (1, 1, 1, 1), np.float32, (1, 1, 1000, -1)),
+    ],
+)
+def test_resolution_chunking(chunks, shape, previous_chunks, lr_mult, chunk_dtype, exp_result):
+    """Test normalize_low_res_chunks helper function."""
+    import dask.config
+
+    from satpy.utils import normalize_low_res_chunks
+
+    with dask.config.set({"array.chunk-size": "32MiB"}):
+        chunk_results = normalize_low_res_chunks(
+            chunks,
+            shape,
+            previous_chunks,
+            lr_mult,
+            chunk_dtype,
+        )
+    assert chunk_results == exp_result
+    for chunk_size in chunk_results:
+        assert isinstance(chunk_size[0], int) if isinstance(chunk_size, tuple) else isinstance(chunk_size, int)
+
+    # make sure the chunks are understandable by dask
+    da.zeros(shape, dtype=chunk_dtype, chunks=chunk_results)
+
+
 def test_convert_remote_files_to_fsspec_local_files():
     """Test conversion of remote files to fsspec objects.
@@ -615,19 +586,21 @@ def test_find_in_ancillary():
     """Test finding a dataset in ancillary variables."""
     from satpy.utils import find_in_ancillary
     index_finger = xr.DataArray(
-            data=np.arange(25).reshape(5, 5),
-            dims=("y", "x"),
-            attrs={"name": "index-finger"})
+        data=np.arange(25).reshape(5, 5),
+        dims=("y", "x"),
+        attrs={"name": "index-finger"})
     ring_finger = xr.DataArray(
-            data=np.arange(25).reshape(5, 5),
-            dims=("y", "x"),
-            attrs={"name": "ring-finger"})
+        data=np.arange(25).reshape(5, 5),
+        dims=("y", "x"),
+        attrs={"name": "ring-finger"})
     hand = xr.DataArray(
-            data=np.arange(25).reshape(5, 5),
-            dims=("y", "x"),
-            attrs={"name": "hand",
-                   "ancillary_variables": [index_finger, index_finger, ring_finger]})
+        data=np.arange(25).reshape(5, 5),
+        dims=("y", "x"),
+        attrs={
+            "name": "hand",
+            "ancillary_variables": [index_finger, index_finger, ring_finger]
+        })
 
     assert find_in_ancillary(hand, "ring-finger") is ring_finger
     with pytest.raises(
diff --git a/satpy/utils.py b/satpy/utils.py
index a9785a544a..67150fed9d 100644
--- a/satpy/utils.py
+++ b/satpy/utils.py
@@ -26,7 +26,7 @@
 import warnings
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import Mapping, Optional
+from typing import Literal, Mapping, Optional
 from urllib.parse import urlparse
 
 import dask.utils
@@ -35,6 +35,8 @@
 import yaml
 from yaml import BaseLoader, UnsafeLoader
 
+from satpy._compat import DTypeLike
+
 _is_logging_on = False
 TRACE_LEVEL = 5
@@ -631,6 +633,83 @@ def _get_pytroll_chunk_size():
     return None
 
 
+def normalize_low_res_chunks(
+    chunks: tuple[int | Literal["auto"], ...],
+    input_shape: tuple[int, ...],
+    previous_chunks: tuple[int, ...],
+    low_res_multipliers: tuple[int, ...],
+    input_dtype: DTypeLike,
+) -> tuple[int, ...]:
+    """Compute dask chunk sizes based on data resolution.
+
+    First, chunks are computed for the highest resolution version of the data.
+    This is done by multiplying the input array shape by the
+    ``low_res_multiplier`` and then using Dask's utility functions and
+    configuration to produce a chunk size to fit into a specific number of
+    bytes. See :doc:`dask:array-chunks` for more information.
+    Next, the same multiplier is used to reduce the high resolution chunk sizes
+    to the lower resolution of the input data. The end result of reading
+    multiple resolutions of data is that each dask chunk covers the same
+    geographic region. This also means replicating or aggregating one
+    resolution and then combining arrays should not require any rechunking.
+
+    Args:
+        chunks: Requested chunk size for each dimension. This is passed
+            directly to dask. Use ``"auto"`` for dimensions that should have
+            chunks determined for them, ``-1`` for dimensions that should be
+            whole (not chunked), and ``1`` or any other positive integer for
+            dimensions that have a known chunk size beforehand.
+        input_shape: Shape of the array to compute dask chunk size for.
+        previous_chunks: Any previous chunking or structure of the data. This
+            can also be thought of as the smallest number of high (fine) resolution
+            elements that make up a single "unit" or chunk of data. This could
+            be a multiple or factor of the scan size for some instruments and/or
+            could be based on the on-disk chunk size. This value ensures that
+            chunks are aligned to the underlying data structure for best
+            performance. On-disk chunk sizes should be multiplied by the
+            largest low resolution multiplier if it is the same between all
+            files (ex. 500m file has 226 chunk size, 1km file has 226 chunk
+            size, etc.).
+            Otherwise, the resulting low resolution chunks may
+            not be aligned to the on-disk chunks. For example, if dask decides
+            on a chunk size of 226 * 3 for 500m data, that becomes 226 * 3 / 2
+            for 1km data which is not aligned to the on-disk chunk size of 226.
+        low_res_multipliers: Number of high (fine) resolution pixels that fit
+            in a single low (coarse) resolution pixel.
+        input_dtype: Dtype for the final unscaled array. This is usually
+            32-bit float (``np.float32``) or 64-bit float (``np.float64``)
+            for non-category data. If this doesn't represent the final data
+            type of the data then the final size of chunks in memory will not
+            match the user's request via dask's ``array.chunk-size``
+            configuration. Sometimes it is useful to keep this as a single
+            dtype for all reading functionality (ex. ``np.float32``) in order
+            to keep all read variable chunks the same size regardless of dtype.
+
+    Returns:
+        A tuple where each element is the chunk size for that axis/dimension.
+
+    """
+    if any(len(input_shape) != len(param) for param in (low_res_multipliers, chunks, previous_chunks)):
+        raise ValueError("Input shape, low res multipliers, chunks, and previous chunks must all be the same size")
+    high_res_shape = tuple(dim_size * lr_mult for dim_size, lr_mult in zip(input_shape, low_res_multipliers))
+    chunks_for_high_res = dask.array.core.normalize_chunks(
+        chunks,
+        shape=high_res_shape,
+        dtype=input_dtype,
+        previous_chunks=previous_chunks,
+    )
+    low_res_chunks: list[int] = []
+    for req_chunks, hr_chunks, prev_chunks, lr_mult in zip(
+            chunks, chunks_for_high_res,
+            previous_chunks, low_res_multipliers
+    ):
+        if req_chunks != "auto":
+            low_res_chunks.append(req_chunks)
+            continue
+        low_res_chunks.append(round(max(hr_chunks[0] / lr_mult, prev_chunks / lr_mult)))
+    return tuple(low_res_chunks)
+
+
 def convert_remote_files_to_fsspec(filenames, storage_options=None):
     """Check filenames for transfer protocols, convert to FSFile objects if possible."""
    if storage_options is None:
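
Supplementary usage sketch (not part of the applied diff): the new
satpy.utils.normalize_low_res_chunks helper added above is easiest to follow
from a single call. The snippet mirrors the "1km swath" case of
test_resolution_chunking and pins dask's "array.chunk-size" to the same
32 MiB budget the test uses, since the computed row chunk depends on that
setting.

    import dask.config
    import numpy as np

    from satpy.utils import normalize_low_res_chunks

    with dask.config.set({"array.chunk-size": "32MiB"}):
        chunks = normalize_low_res_chunks(
            ("auto", -1),  # choose row chunks automatically, never split columns
            (1000, 3200),  # 1km band shape (rows, columns)
            (40, 40),      # underlying structure: one 250m scan is 40 fine rows
            (4, 4),        # four 250m pixels per 1km pixel in each dimension
            np.float32,    # dtype used to budget bytes per chunk
        )
    print(chunks)  # (160, -1) per the test case: 1km row chunks stay scan-aligned

Because chunk sizes are computed on the virtual high-resolution grid first and
then divided back down, loading the 250m, 500m, and 1km variants of the same
scene produces chunks that cover identical geographic regions, which is what
the reader changes in hdfeos_base.py and modis_l1b.py rely on.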