Skip to content

Commit

Permalink
feat: add multiscale to annotation conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Oct 16, 2023
1 parent 0b1997c commit 50d74b7
Showing 1 changed file with 49 additions and 36 deletions.
85 changes: 49 additions & 36 deletions src/fibsem_tools/metadata/split_annotation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Dict
from typing import Dict, Literal, Tuple, TypeVar, Union
import click
from xarray import DataArray
from fibsem_tools import read_xarray
Expand All @@ -20,47 +20,51 @@
CropGroupAttrs,
wrap_attributes
)
from xarray_multiscale import multiscale, windowed_mode

ome_adapters = get_adapters("0.4")
out_chunks = (256,) * 3
out_chunks = (64,) * 3

annotation_type = SemanticSegmentation(encoding={"absent": 0, "present": 1})


def create_spec(
data: DataArray, crop_name: str, array_name: str, class_encoding: Dict[str, int]
data: dict[str, DataArray],
crop_name: str,
class_encoding: Dict[str, int]
):

ome_meta = ome_adapters.multiscale_metadata([data], [array_name])
ome_meta = ome_adapters.multiscale_metadata(tuple(data.values()), tuple(data.keys()))

annotation_group_specs: dict[str, GroupSpec] = {}

for class_name, value in class_encoding.items():
data_unique = (data == value).astype("uint8")
num_present = int(data_unique.sum())
num_absent = data_unique.size - num_present
hist = {"absent": num_absent}

annotation_group_attrs = AnnotationGroupAttrs(
class_name=class_name, description="", annotation_type=annotation_type
)

annotation_array_attrs = AnnotationArrayAttrs(
class_name=class_name, histogram=hist, annotation_type=annotation_type
)
label_array_spec = ArraySpec.from_array(
data_unique,
chunks=out_chunks,
compressor=Blosc(cname="zstd"),
attrs={"cellmap": {"annotation": annotation_array_attrs}},
)
label_array_specs = {}
for array_name, array_data in data.items():
data_unique = (array_data == value).astype("uint8")
num_present = int(data_unique.sum())
num_absent = data_unique.size - num_present
hist = {"absent": num_absent}

annotation_group_attrs = AnnotationGroupAttrs(
class_name=class_name, description="", annotation_type=annotation_type
)

annotation_array_attrs = AnnotationArrayAttrs(
class_name=class_name, histogram=hist, annotation_type=annotation_type
)
label_array_specs[array_name] = ArraySpec.from_array(
data_unique,
chunks=out_chunks,
compressor=Blosc(cname="zstd"),
attrs=wrap_attributes(annotation_array_attrs).dict(),
)

annotation_group_specs[class_name] = GroupSpec(
attrs={
"cellmap": {"annotation": annotation_group_attrs},
**wrap_attributes(annotation_group_attrs).dict(),
"multiscales": [ome_meta.dict()],
},
members={array_name: label_array_spec},
members=label_array_specs,
)

crop_attrs = CropGroupAttrs(
Expand Down Expand Up @@ -97,9 +101,18 @@ def guess_format(path: str):


def split_annotations(
source: str, dest: str, crop_name: str, class_encoding: Dict[str, int]
source: str,
dest: str,
crop_name: str,
class_encoding: Dict[str, int],
chunks: Union[Literal['auto'], Tuple[Tuple[int, ...],...]] = 'auto'
) -> zarr.Group:

if chunks == 'auto':
out_chunks = (64,64,64)
else:
out_chunks = chunks

pre, post, _ = split_by_suffix(dest, (".zarr",))
# fail fast if there's already a group there

Expand All @@ -123,9 +136,8 @@ def split_annotations(
multi = {m.name: m for m in multiscale(data, windowed_mode, (2,2,2), chunks=out_chunks)}

spec = create_spec(
data=data,
data=multi,
crop_name=crop_name,
array_name=array_name,
class_encoding=class_encoding,
)

Expand All @@ -134,13 +146,14 @@ def split_annotations(
)

for class_name, value in class_encoding.items():
data_unique = np.array((data == value).astype("uint8"))
arr = zarr.Array(
store=crop_group.store,
path=os.path.join(crop_group.path, class_name, array_name),
write_empty_chunks=False,
)
arr[:] = data_unique
for array_name, data in multi.items():
data_unique = np.array((data == value).astype("uint8"))
arr = zarr.Array(
store=crop_group.store,
path=os.path.join(crop_group.path, class_name, array_name),
write_empty_chunks=False,
)
arr[:] = data_unique

return crop_group

Expand All @@ -151,7 +164,7 @@ def split_annotations(
@click.argument("name", type=click.STRING)
def cli(source, dest, name):
class_encoding = class_encoding_from_airtable_by_image(name)
split_annotations(source, dest, name, class_encoding)
split_annotations(source, dest, name, class_encoding, chunks)

if __name__ == "__main__":
cli()

0 comments on commit 50d74b7

Please sign in to comment.