diff --git a/src/fibsem_tools/metadata/split_annotation.py b/src/fibsem_tools/metadata/split_annotation.py index af6c9d0..5172f35 100755 --- a/src/fibsem_tools/metadata/split_annotation.py +++ b/src/fibsem_tools/metadata/split_annotation.py @@ -1,5 +1,5 @@ import os -from typing import Dict +from typing import Dict, Literal, Tuple, TypeVar, Union import click from xarray import DataArray from fibsem_tools import read_xarray @@ -20,47 +20,51 @@ CropGroupAttrs, wrap_attributes ) +from xarray_multiscale import multiscale, windowed_mode ome_adapters = get_adapters("0.4") -out_chunks = (256,) * 3 +out_chunks = (64,) * 3 annotation_type = SemanticSegmentation(encoding={"absent": 0, "present": 1}) - def create_spec( - data: DataArray, crop_name: str, array_name: str, class_encoding: Dict[str, int] + data: dict[str, DataArray], + crop_name: str, + class_encoding: Dict[str, int] ): - ome_meta = ome_adapters.multiscale_metadata([data], [array_name]) + ome_meta = ome_adapters.multiscale_metadata(tuple(data.values()), tuple(data.keys())) annotation_group_specs: dict[str, GroupSpec] = {} for class_name, value in class_encoding.items(): - data_unique = (data == value).astype("uint8") - num_present = int(data_unique.sum()) - num_absent = data_unique.size - num_present - hist = {"absent": num_absent} - - annotation_group_attrs = AnnotationGroupAttrs( - class_name=class_name, description="", annotation_type=annotation_type - ) - - annotation_array_attrs = AnnotationArrayAttrs( - class_name=class_name, histogram=hist, annotation_type=annotation_type - ) - label_array_spec = ArraySpec.from_array( - data_unique, - chunks=out_chunks, - compressor=Blosc(cname="zstd"), - attrs={"cellmap": {"annotation": annotation_array_attrs}}, - ) + label_array_specs = {} + for array_name, array_data in data.items(): + data_unique = (array_data == value).astype("uint8") + num_present = int(data_unique.sum()) + num_absent = data_unique.size - num_present + hist = {"absent": num_absent} + + annotation_group_attrs = AnnotationGroupAttrs( + class_name=class_name, description="", annotation_type=annotation_type + ) + + annotation_array_attrs = AnnotationArrayAttrs( + class_name=class_name, histogram=hist, annotation_type=annotation_type + ) + label_array_specs[array_name] = ArraySpec.from_array( + data_unique, + chunks=out_chunks, + compressor=Blosc(cname="zstd"), + attrs=wrap_attributes(annotation_array_attrs).dict(), + ) annotation_group_specs[class_name] = GroupSpec( attrs={ - "cellmap": {"annotation": annotation_group_attrs}, + **wrap_attributes(annotation_group_attrs).dict(), "multiscales": [ome_meta.dict()], }, - members={array_name: label_array_spec}, + members=label_array_specs, ) crop_attrs = CropGroupAttrs( @@ -97,9 +101,18 @@ def guess_format(path: str): def split_annotations( - source: str, dest: str, crop_name: str, class_encoding: Dict[str, int] + source: str, + dest: str, + crop_name: str, + class_encoding: Dict[str, int], + chunks: Union[Literal['auto'], Tuple[Tuple[int, ...],...]] = 'auto' ) -> zarr.Group: + if chunks == 'auto': + out_chunks = (64,64,64) + else: + out_chunks = chunks + pre, post, _ = split_by_suffix(dest, (".zarr",)) # fail fast if there's already a group there @@ -123,9 +136,8 @@ def split_annotations( multi = {m.name: m for m in multiscale(data, windowed_mode, (2,2,2), chunks=out_chunks)} spec = create_spec( - data=data, + data=multi, crop_name=crop_name, - array_name=array_name, class_encoding=class_encoding, ) @@ -134,13 +146,14 @@ def split_annotations( ) for class_name, value in class_encoding.items(): - data_unique = np.array((data == value).astype("uint8")) - arr = zarr.Array( - store=crop_group.store, - path=os.path.join(crop_group.path, class_name, array_name), - write_empty_chunks=False, - ) - arr[:] = data_unique + for array_name, data in multi.items(): + data_unique = np.array((data == value).astype("uint8")) + arr = zarr.Array( + store=crop_group.store, + path=os.path.join(crop_group.path, class_name, array_name), + write_empty_chunks=False, + ) + arr[:] = data_unique return crop_group @@ -151,7 +164,7 @@ def split_annotations( @click.argument("name", type=click.STRING) def cli(source, dest, name): class_encoding = class_encoding_from_airtable_by_image(name) - split_annotations(source, dest, name, class_encoding) + split_annotations(source, dest, name, class_encoding, chunks) if __name__ == "__main__": cli()