diff --git a/odc/geo/_cog.py b/odc/geo/_cog.py index 3004a34e..599ce447 100644 --- a/odc/geo/_cog.py +++ b/odc/geo/_cog.py @@ -17,8 +17,8 @@ from rasterio.shutil import copy as rio_copy # pylint: disable=no-name-in-module from .geobox import GeoBox -from .math import align_up -from .types import SomeShape, shape_, wh_ +from .math import align_down_pow2, align_up +from .types import Shape2d, SomeShape, shape_, wh_ from .warp import resampling_s2rio # pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements @@ -60,6 +60,33 @@ def _adjust_blocksize(block: int, dim: int = 0) -> int: return align_up(block, 16) +def _num_overviews(block: int, dim: int) -> int: + c = 0 + while block < dim: + dim = dim // 2 + c += 1 + return c + + +def _compute_cog_spec( + data_shape: SomeShape, + tile_shape: SomeShape, + *, + max_pad: Optional[int] = None, +) -> Tuple[Shape2d, Shape2d, int]: + data_shape = shape_(data_shape) + tile_shape = shape_(shape_(tile_shape).map(_adjust_blocksize)) + n1, n2 = (_num_overviews(b, dim) for dim, b in zip(data_shape.xy, tile_shape.xy)) + n = max(n1, n2) + pad = 2**n + if max_pad is not None and max_pad < pad: + pad = 0 if max_pad == 0 else align_down_pow2(max_pad) + + if pad > 0: + data_shape = shape_(data_shape.map(lambda d: align_up(d, pad))) + return (data_shape, tile_shape, n) + + def _default_cog_opts( *, blocksize: int = 512, shape: SomeShape = (0, 0), is_float: bool = False, **other ) -> Dict[str, Any]: @@ -101,7 +128,7 @@ def _write_cog( ovr_blocksize: Optional[int] = None, use_windowed_writes: bool = False, intermediate_compression: Union[bool, str, Dict[str, Any]] = False, - **extra_rio_opts + **extra_rio_opts, ) -> Union[Path, bytes]: if blocksize is None: blocksize = 512 @@ -233,7 +260,7 @@ def write_cog( overview_levels: Optional[List[int]] = None, use_windowed_writes: bool = False, intermediate_compression: Union[bool, str, Dict[str, Any]] = False, - **extra_rio_opts + **extra_rio_opts, ) -> Union[Path, 
@pytest.mark.parametrize(
    ["block", "dim", "n_expect"],
    [
        (2**5, 2**10, 5),
        (2**5, 2**10 - 1, 5),
        (2**5, 2**10 + 1, 5),
        (256, 78, 0),
        (1024, 1_000_000, -1),
        (512, 3_040_000, -1),
    ],
)
def test_num_overviews(block: int, dim: int, n_expect: int):
    """
    Check ``_num_overviews`` against known counts.

    A negative ``n_expect`` means no exact count is pinned; only the bound
    "``dim`` reduced by the reported number of halvings fits in ``block``"
    is verified.
    """
    levels = _num_overviews(block, dim)
    if n_expect < 0:
        assert dim // (2**levels) <= block
    else:
        assert levels == n_expect


@pytest.mark.parametrize(
    ("shape", "tshape", "max_pad"),
    [
        [(1024, 2048), (256, 128), None],
    ],
)
def test_cog_spec(
    shape: Tuple[int, int],
    tshape: Tuple[int, int],
    max_pad: Optional[int],
):
    """
    Check invariants of the spec returned by ``_compute_cog_spec``.
    """
    _shape, _tshape, nlevels = _compute_cog_spec(shape, tshape, max_pad=max_pad)

    # the image is never reported smaller than requested
    for axis in (0, 1):
        assert _shape[axis] >= shape[axis]

    # tiles are never reported smaller than requested ...
    for axis in (0, 1):
        assert _tshape[axis] >= tshape[axis]

    # ... and each side is a multiple of 16
    for axis in (0, 1):
        assert _tshape[axis] % 16 == 0

    # largest side reduced by all overview levels fits the requested tile
    assert max(_shape) // (2**nlevels) <= max(tshape)

    if max_pad is None:
        return

    # padding added on each axis stays within the requested limit
    for axis in (0, 1):
        assert _shape[axis] - shape[axis] <= max_pad