Skip to content

Commit

Permalink
Implemented SYX compressor mode for COG
Browse files Browse the repository at this point in the history
support multi-band inputs in planar config
  • Loading branch information
Kirill888 committed Oct 3, 2023
1 parent e37ce3a commit 5d23028
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 35 deletions.
10 changes: 7 additions & 3 deletions odc/geo/cog/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ def num_tiles(self):
ny, nx = self.chunked.yx
return self.num_planes * ny * nx

def tidx(self) -> Iterator[Tuple[int, int, int]]:
"""``[(plane_idx, iy, ix), ...]``"""
yield from np.ndindex((self.num_planes, *self.chunked.yx))
def tidx(self, sample_idx: Optional[int] = None) -> Iterator[Tuple[int, int, int]]:
"""``[([sample|plane]_idx, iy, ix), ...]``"""
if sample_idx is not None:
assert sample_idx < self.num_planes
yield from ((sample_idx, y, x) for y, x in np.ndindex(self.chunked.yx))
else:
yield from np.ndindex((self.num_planes, *self.chunked.yx))

def flat_tile_idx(self, idx: Tuple[int, int, int]) -> int:
"""Convert from sample,iy,ix to flat tile index."""
Expand Down
121 changes: 89 additions & 32 deletions odc/geo/cog/_tifffile.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,12 @@ def _sh(shape: Shape2d) -> Tuple[int, ...]:
return meta, buf.getbuffer()


def _cog_block_compressor(
def _cog_block_compressor_yxs(
block: np.ndarray,
*,
tile_shape: Tuple[int, ...] = (),
encoder: Any = None,
predictor: Any = None,
axis: int = 1,
fill_value: Union[float, int] = 0,
**kw,
) -> bytes:
Expand All @@ -175,7 +174,7 @@ def _cog_block_compressor(
block = np.pad(block, pad, "constant", constant_values=(fill_value,))

if predictor is not None:
block = predictor(block, axis=axis)
block = predictor(block, axis=1)
if encoder:
try:
return encoder(block, **kw)
Expand All @@ -185,16 +184,52 @@ def _cog_block_compressor(
return bytes(block.data)


def _mk_tile_compressor(meta: CogMeta) -> Callable[[np.ndarray], bytes]:
def _cog_block_compressor_syx(
block: np.ndarray,
*,
tile_shape: Tuple[int, int] = (0, 0),
encoder: Any = None,
predictor: Any = None,
fill_value: Union[float, int] = 0,
sample_idx: int = 0,
**kw,
) -> bytes:
assert isinstance(block, np.ndarray)

if block.ndim == 2:
pass
elif block.shape[0] == 1:
block = block[0, :, :]
else:
block = block[sample_idx, :, :]

assert block.ndim == 2
if tile_shape != block.shape:
pad = tuple((0, want - have) for want, have in zip(tile_shape, block.shape))
block = np.pad(block, pad, "constant", constant_values=(fill_value,))

if predictor is not None:
block = predictor(block, axis=1)

if encoder:
try:
return encoder(block, **kw)
except Exception: # pylint: disable=broad-except
return b""

return bytes(block.data)


def _mk_tile_compressor(
meta: CogMeta, sample_idx: int = 0
) -> Callable[[np.ndarray], bytes]:
# pylint: disable=import-outside-toplevel
have.check_or_error("tifffile")
from tifffile import TIFF

tile_shape = meta.chunks
encoder = TIFF.COMPRESSORS[meta.compression]

# TODO: handle SYX in planar mode
axis = 1
predictor = None
if meta.predictor != 1:
predictor = TIFF.PREDICTORS[meta.predictor]
Expand All @@ -203,12 +238,22 @@ def _mk_tile_compressor(meta: CogMeta) -> Callable[[np.ndarray], bytes]:
if meta.nodata is not None:
fill_value = float(meta.nodata) if isinstance(meta.nodata, str) else meta.nodata

if meta.axis == "SYX":
return partial(
_cog_block_compressor_syx,
tile_shape=meta.tile.yx,
encoder=encoder,
predictor=predictor,
fill_value=fill_value,
sample_idx=sample_idx,
**meta.compressionargs,
)

return partial(
_cog_block_compressor,
_cog_block_compressor_yxs,
tile_shape=tile_shape,
encoder=encoder,
predictor=predictor,
axis=axis,
fill_value=fill_value,
**meta.compressionargs,
)
Expand Down Expand Up @@ -237,19 +282,29 @@ def _compress_tiles(

from .._interop import is_dask_collection

# TODO: deal with SYX planar data
assert meta.num_planes == 1
assert meta.axis in ("YX", "YXS")
assert meta.num_planes == 1
src_ydim = 0 # for now assume Y,X or Y,X,S

encoder = _mk_tile_compressor(meta)
data = xx.data
assert is_dask_collection(data)

if meta.axis == "SYX":
src_ydim = 1
if data.ndim == 2:
_chunks: Tuple[int, ...] = meta.tile.yx
elif len(data.chunks[0]) == 1:
# if 1 single chunk with all "samples", keep it that way
_chunks = (data.shape[0], *meta.tile.yx)
else:
# else have 1 chunk per "sample"
_chunks = (1, *meta.tile.yx)

if data.chunksize != meta.chunks:
data = data.rechunk(meta.chunks)
if data.chunksize != _chunks:
data = data.rechunk(_chunks)
else:
assert meta.num_planes == 1
src_ydim = 0
if data.chunksize != meta.chunks:
data = data.rechunk(meta.chunks)

assert is_dask_collection(data)
encoder = _mk_tile_compressor(meta, sample_idx)

tk = tokenize(
data,
Expand All @@ -260,24 +315,26 @@ def _compress_tiles(
meta.compression,
meta.compressionargs,
)
plane_id = "" if scale_idx == 0 else f"_{scale_idx}"
plane_id += "" if sample_idx == 0 else f"@{sample_idx}"
cc_id = "" if scale_idx == 0 else f"_{scale_idx}"
cc_id += "" if meta.num_planes == 1 else f"@{sample_idx}"

name = f"compress{plane_id}-{tk}"
name = f"compress{cc_id}-{tk}"

src_data_name = data.name

def block_name(p, y, x):
def block_name(s, y, x):
if data.ndim == 2:
return (src_data_name, y, x)
if src_ydim == 0:
return (src_data_name, y, x, p)
return (src_data_name, p, y, x)
return (src_data_name, y, x, s)
if len(data.chunks[0]) == 1:
return (src_data_name, 0, y, x)
return (src_data_name, s, y, x)

dsk = {}
for i, (p, y, x) in enumerate(meta.tidx()):
block = block_name(p, y, x)
dsk[name, i] = (_compress_cog_tile, encoder, block, quote((scale_idx, p, y, x)))
for i, (s, y, x) in enumerate(meta.tidx(sample_idx)):
block = block_name(s, y, x)
dsk[name, i] = (_compress_cog_tile, encoder, block, quote((scale_idx, s, y, x)))

nparts = len(dsk)
dsk = HighLevelGraph.from_collections(name, dsk, dependencies=[data])
Expand Down Expand Up @@ -474,11 +531,11 @@ def save_cog_with_dask(

_tiles: List["dask.bag.Bag"] = []
for scale_idx, (mm, img) in enumerate(zip(meta.flatten(), layers)):
# TODO: SYX order
tt = _compress_tiles(img, mm, scale_idx=scale_idx)
if tt.npartitions > 20:
tt = tt.repartition(npartitions=tt.npartitions // 4)
_tiles.append(tt)
for sample_idx in range(meta.num_planes):
tt = _compress_tiles(img, mm, scale_idx=scale_idx, sample_idx=sample_idx)
if tt.npartitions > 20:
tt = tt.repartition(npartitions=tt.npartitions // 4)
_tiles.append(tt)

if dst == "":
return {
Expand Down

0 comments on commit 5d23028

Please sign in to comment.