Skip to content

Commit

Permalink
Compute and render pixel statistics to COG
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Oct 17, 2023
1 parent 41d45ee commit 40fdccd
Showing 1 changed file with 113 additions and 6 deletions.
119 changes: 113 additions & 6 deletions odc/geo/cog/_tifffile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""
Write Cloud Optimized GeoTIFFs from xarrays.
"""
from __future__ import annotations

import itertools
from functools import partial
from io import BytesIO
Expand All @@ -27,11 +29,92 @@
)

if TYPE_CHECKING:
import dask.array
import dask.bag
from dask.delayed import Delayed

# pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements,too-many-instance-attributes


def _render_gdal_metadata(
band_stats: list[dict[str, float]] | dict[str, float],
precision: int = 10,
pad: int = 0,
eol: str = "",
) -> str:
def _item(sample: int, stats: dict[str, float]) -> str:
return eol.join(
[
f'<Item name="STATISTICS_{k.upper()}" sample="{sample:d}">{v:{pad}.{precision}f}</Item>'
for k, v in stats.items()
]
)

if isinstance(band_stats, dict):
band_stats = [band_stats]

body = eol.join([_item(sample, stats) for sample, stats in enumerate(band_stats)])
return eol.join(["<GDALMetadata>", body, "</GDALMetadata>"])


def _unwrap_stats(stats, ndim):
if ndim == 2:
return [{k: float(v) for k, v in stats.items()}]

n = {len(v) for v in stats.values()}.pop()
return [{k: v[idx] for k, v in stats.items()} for idx in range(n)]


def _stats_from_layer(
    pix: "dask.array.Array", nodata=None, yaxis: int = 0
) -> "Delayed":
    """
    Build a lazy computation of pixel statistics for one image layer.

    Statistics are reduced over the two spatial axes ``(yaxis, yaxis + 1)``
    and then passed through ``_unwrap_stats`` together with ``pix.ndim``.

    :param pix: Dask array with pixel data.
    :param nodata: Optional nodata value. A non-NaN value is masked out with
        ``da.ma.masked_equal`` before reducing. A NaN value is handled by the
        NaN-aware reductions instead.
    :param yaxis: Index of the Y axis; the X axis is taken to be ``yaxis + 1``.
    :return: Delayed object wrapping the ``_unwrap_stats`` call over the
        ``minimum``, ``maximum``, ``mean``, ``stddev`` and ``valid_percent``
        reductions.
    """
    # pylint: disable=import-outside-toplevel
    import math

    from dask import array as da
    from dask import delayed

    unwrap = delayed(_unwrap_stats, pure=True, traverse=True)

    axis = (yaxis, yaxis + 1)
    npix = pix.shape[yaxis] * pix.shape[yaxis + 1]

    # NaN never compares equal to itself, so ``masked_equal(pix, nan)`` would
    # mask nothing and the plain reductions below would then return NaN for
    # any band containing a NaN pixel. Treat NaN nodata like "no nodata" and
    # let the NaN-aware branch deal with it.
    mask_nodata = nodata is not None and not math.isnan(nodata)
    dd = da.ma.masked_equal(pix, nodata) if mask_nodata else pix

    if not mask_nodata and pix.dtype.kind == "f":
        # Floating point data without an explicit nodata mask: skip NaNs.
        v_min = da.nanmin(dd, axis=axis)
        v_max = da.nanmax(dd, axis=axis)
        v_mean = da.nanmean(dd, axis=axis)
        v_std = da.nanstd(dd, axis=axis)
    else:
        # Masked data, or non-float data: plain reductions.
        v_min = dd.min(axis=axis)
        v_max = dd.max(axis=axis)
        v_mean = dd.mean(axis=axis)
        v_std = dd.std(axis=axis)

    # Key order is kept as minimum, maximum, mean, stddev, valid_percent.
    return unwrap(
        {
            "minimum": v_min,
            "maximum": v_max,
            "mean": v_mean,
            "stddev": v_std,
            # Count of finite pixels as a percentage of all pixels in a band.
            # NOTE(review): for the masked case this presumably excludes
            # masked pixels from the count -- confirm against dask/numpy.ma.
            "valid_percent": da.isfinite(dd).sum(axis=axis) * (100 / npix),
        },
        pix.ndim,
    )


def _make_empty_cog(
shape: Tuple[int, ...],
dtype: Any,
Expand Down Expand Up @@ -379,25 +462,35 @@ def _extract_tile_info(
return tile_info


# NOTE(review): this span appears to be a rendered diff in which pre-change
# and post-change lines are interleaved (two ``def`` lines, two
# ``_extract_tile_info`` calls, two ``tags[324].overwrite`` calls). The
# ``_build_hdr`` lines look like the removed version and the ``_patch_hdr``
# lines the added one -- confirm against the actual file before editing.
#
# Purpose: take the in-memory TIFF header ``hdr0``, write tile offsets (tag
# 324) and tile byte counts (tag 325) for every page, optionally overwrite
# tag 42112 with rendered statistics, and return the patched header bytes.
def _build_hdr(
tiles: List[Tuple[int, Tuple[int, int, int, int]]], meta: CogMeta, hdr0: bytes
def _patch_hdr(
tiles: List[Tuple[int, Tuple[int, int, int, int]]],
meta: CogMeta,
hdr0: bytes,
stats: Optional[list[dict[str, float]]] = None,
) -> bytes:
# pylint: disable=import-outside-toplevel
from tifffile import TiffFile

# Re-pack each ``(size, index_tuple)`` entry as ``(*index_tuple, size)``.
_tiles = [(*idx, sz) for sz, idx in tiles]
# Variant using ``len(hdr0)`` as the last argument (presumably the start
# offset of the tile data -- confirm against ``_extract_tile_info``).
tile_info = _extract_tile_info(meta, _tiles, len(hdr0))
# Variant using ``0`` instead; ``hdr_sz`` is then added to every offset below.
tile_info = _extract_tile_info(meta, _tiles, 0)

_bio = BytesIO(hdr0)
with TiffFile(_bio, mode="r+", name=":mem:") as tr:
# One ``(offsets, lengths)`` entry is expected per TIFF page.
assert len(tile_info) == len(tr.pages)
if stats is not None:
# Tag 42112 on the first page receives the rendered statistics XML;
# the tag has to be present in the header already.
md_tag = tr.pages[0].tags.get(42112, None)
assert md_tag is not None
gdal_metadata = _render_gdal_metadata(stats, precision=6)
md_tag.overwrite(gdal_metadata)

# Buffer size is measured after the metadata overwrite above.
# NOTE(review): presumably because overwriting the tag can grow the
# buffer -- confirm against tifffile's ``TiffTag.overwrite`` behaviour.
hdr_sz = len(_bio.getbuffer())

# 324 -- offsets
# 325 -- byte counts
for info, page in zip(tile_info, tr.pages):
tags = page.tags
offsets, lengths = info
# Variant writing the offsets unchanged.
tags[324].overwrite(offsets)
# Variant shifting every offset by the size of the patched header.
tags[324].overwrite([off + hdr_sz for off in offsets])
tags[325].overwrite(lengths)

return bytes(_bio.getbuffer())
Expand Down Expand Up @@ -490,6 +583,7 @@ def save_cog_with_dask(
overview_resampling: Union[int, str] = "nearest",
aws: Optional[Dict[str, Any]] = None,
client: Any = None,
stats: bool | int = True,
**kw,
) -> Any:
# pylint: disable=import-outside-toplevel
Expand All @@ -513,6 +607,8 @@ def save_cog_with_dask(
if isinstance(blocksize, Unset):
blocksize = [data_chunks, int(max(*data_chunks) // 2)]

gdal_metadata = None if stats is False else ""

meta, hdr0 = _make_empty_cog(
xx.shape,
xx.dtype,
Expand All @@ -523,12 +619,22 @@ def save_cog_with_dask(
blocksize=blocksize,
bigtiff=bigtiff,
nodata=xx_odc.nodata,
gdal_metadata=gdal_metadata,
**kw,
)
hdr0 = bytes(hdr0)

layers = _pyramids_from_cog_metadata(xx, meta, resampling=overview_resampling)

if stats is True:
stats = len(layers) // 2

_stats: "Delayed" | None = None
if stats is not False:
_stats = _stats_from_layer(
layers[stats].data, nodata=xx_odc.nodata, yaxis=xx_odc.ydim
)

_tiles: List["dask.bag.Bag"] = []
for scale_idx, (mm, img) in enumerate(zip(meta.flatten(), layers)):
for sample_idx in range(meta.num_planes):
Expand All @@ -543,6 +649,7 @@ def save_cog_with_dask(
"hdr0": hdr0,
"tiles": _tiles,
"layers": layers,
"_stats": _stats,
}

bucket, key = s3_parse_url(dst)
Expand All @@ -569,8 +676,8 @@ def save_cog_with_dask(
s3_sink.cancel("all")
return s3_sink.upload(
tiles_write_order,
mk_header=_build_hdr,
user_kw={"meta": meta, "hdr0": hdr0},
mk_header=_patch_hdr,
user_kw={"meta": meta, "hdr0": hdr0, "stats": _stats},
client=client,
**upload_params,
)
Expand Down

0 comments on commit 40fdccd

Please sign in to comment.