Skip to content

Commit

Permalink
Merge pull request #59 from janelia-cosem/multiscale_metadata_fix
Browse files Browse the repository at this point in the history
feat: overhaul handling of chunks for multiscale GroupSpecs.
  • Loading branch information
d-v-b authored Sep 16, 2023
2 parents 107fb2d + 7aebe25 commit 1878d4c
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 102 deletions.
61 changes: 22 additions & 39 deletions src/fibsem_tools/io/multiscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from xarray import DataArray

import zarr
from fibsem_tools.io.util import normalize_chunks
from fibsem_tools.metadata.cosem import (
CosemMultiscaleGroupV1,
CosemMultiscaleGroupV2,
Expand All @@ -20,36 +21,11 @@
multiscale_metadata_types = ["neuroglancer", "cellmap", "cosem", "ome-ngff"]


def _normalize_chunks(
    arrays: Sequence[DataArray],
    chunks: Union[Tuple[Tuple[int, ...], ...], Tuple[int, ...], None],
) -> Tuple[Tuple[int, ...], ...]:
    """
    Normalize a chunk specification against a sequence of arrays.

    Parameters
    ----------
    arrays : Sequence[DataArray]
        The arrays to define chunks for.
    chunks : Union[Tuple[Tuple[int, ...], ...], Tuple[int, ...], None]
        Either an explicit per-array chunking (a tuple of tuples of ints,
        passed through unchanged), a single chunking (a tuple of ints,
        broadcast to every array), or None, in which case each array's dask
        chunk size is used.

    Returns
    -------
    Tuple[Tuple[int, ...], ...]
        One tuple of per-dimension chunk sizes for each input array.

    Raises
    ------
    ValueError
        If `chunks` contains non-int values, or if the normalized result does
        not provide one chunk tuple of the correct rank per array.
    """
    if chunks is None:
        # Inherit chunks from the (dask-backed) data of each array.
        result: Tuple[Tuple[int, ...], ...] = tuple(v.data.chunksize for v in arrays)
    elif all(isinstance(c, tuple) for c in chunks):
        # Already one chunk tuple per array; pass through unchanged.
        result = chunks
    elif all(isinstance(c, int) for c in chunks):
        # A single chunk tuple; broadcast it to every array.
        result = (chunks,) * len(arrays)
    else:
        # NOTE: the original wrapped this check in a pointless
        # `try/except TypeError as e: raise e`, which merely re-raised.
        msg = f"All values in chunks must be ints. Got {chunks}"
        raise ValueError(msg)

    # Validate with real exceptions rather than `assert`, which is stripped
    # when Python runs with -O.
    if len(result) != len(arrays):
        msg = f"Got {len(result)} chunk specs for {len(arrays)} arrays."
        raise ValueError(msg)
    if tuple(map(len, result)) != tuple(x.ndim for x in arrays):
        msg = "Number of chunks per array does not equal rank of arrays"
        raise ValueError(msg)
    return result


def multiscale_group(
arrays: Sequence[DataArray],
metadata_types: List[str],
array_paths: Union[List[str], Literal["auto"]] = "auto",
chunks: Union[Tuple[Tuple[int, ...], ...], Literal["auto"]] = "auto",
name: Optional[str] = None,
**kwargs,
) -> GroupSpec:
Expand All @@ -65,6 +41,13 @@ def multiscale_group(
The metadata flavor(s) to use.
array_paths : Sequence[str]
The path for each array in storage, relative to the parent group.
chunks : Union[Tuple[Tuple[int, ...], ...], Literal["auto"]], default is "auto"
The chunks for the arrays instances. Either an explicit collection of
chunk sizes, one per array, or the string "auto". If `chunks` is "auto" and
the `data` attribute of the arrays is chunked, then each stored array
will inherit the chunks of the input arrays. If the `data` attribute
is not chunked, then each stored array will have chunks equal to the shape of
the input array.
name : Optional[str]
The name for the multiscale group. Only relevant for metadata flavors that
support this field, e.g. ome-ngff
Expand All @@ -77,6 +60,8 @@ def multiscale_group(
"""
if array_paths == "auto":
array_paths = [f"s{idx}" for idx in range(len(arrays))]
_chunks = normalize_chunks(arrays, chunks)

group_attrs = {}
array_attrs = {path: {} for path in array_paths}

Expand All @@ -93,20 +78,19 @@ def multiscale_group(
flave, _, version = flavor.partition("@")

if flave == "neuroglancer":
g_spec = NeuroglancerN5Group.from_xarrays(arrays, **kwargs)
g_spec = NeuroglancerN5Group.from_xarrays(arrays, chunks=_chunks, **kwargs)
group_attrs.update(g_spec.attrs.dict())
elif flave == "cosem":
if version == "2":
g_spec = CosemMultiscaleGroupV2.from_xarrays(
arrays, name=name, **kwargs
arrays, name=name, chunks=_chunks, **kwargs
)
else:
g_spec = CosemMultiscaleGroupV1.from_xarrays(
arrays, name=name, **kwargs
arrays, name=name, chunks=_chunks, **kwargs
)
group_attrs.update(g_spec.attrs.dict())

for key, value in g_spec.items.items():
for key, value in g_spec.members.items():
array_attrs[key].update(**value.attrs.dict())
elif flave == "ome-ngff":
if version == "":
Expand All @@ -118,17 +102,16 @@ def multiscale_group(
).dict()
]
else:
raise ValueError(
f"""
Multiscale metadata type {flavor} is unknown. Try one of
{multiscale_metadata_types}
"""
msg = (
"Multiscale metadata type {flavor} is unknown."
f"Try one of {multiscale_metadata_types}"
)
raise ValueError(msg)

members = {
path: ArraySpec.from_array(arr, attrs=array_attrs[path], **kwargs)
for arr, path in zip(arrays, array_paths)
path: ArraySpec.from_array(arr, attrs=array_attrs[path], chunks=cnks, **kwargs)
for arr, path, cnks in zip(arrays, array_paths, _chunks)
}

return GroupSpec(attrs=group_attrs, members=members)


Expand Down
50 changes: 50 additions & 0 deletions src/fibsem_tools/io/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dask import bag, delayed
from typing import Protocol, runtime_checkable

from xarray import DataArray

JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None]
Attrs = Dict[str, JSON]
PathLike = Union[Path, str]
Expand Down Expand Up @@ -143,3 +145,51 @@ def split_by_suffix(uri: PathLike, suffixes: Sequence[str]) -> Tuple[str, str, s
if protocol:
pre = f"{protocol}://{pre}"
return pre, post, suffix


def normalize_chunks(
    arrays: Sequence[DataArray],
    chunks: Union[Tuple[Tuple[int, ...], ...], Tuple[int, ...], Literal["auto"]],
) -> Tuple[Tuple[int, ...], ...]:
    """
    Normalize a chunk specification, given a list of arrays.

    Parameters
    ----------
    arrays : Sequence[DataArray]
        The list of arrays to define chunks for.
    chunks : Union[Tuple[Tuple[int, ...], ...], Tuple[int, ...], Literal["auto"]]
        The specification of chunks. This parameter is either a tuple of
        tuples of ints, in which case it is already normalized and passes
        right through; a tuple of ints, which will be "broadcast" to the
        length of `arrays`; or the string "auto", in which case the existing
        chunks of each array are used if it is chunked, and otherwise the
        chunks are set to the shape of that array.

    Returns
    -------
    Tuple[Tuple[int, ...], ...]
        One tuple of per-dimension chunk sizes for each input array.

    Raises
    ------
    ValueError
        If `chunks` contains non-int values, or if the normalized result does
        not provide one chunk tuple of the correct rank per array.
    """
    result: Tuple[Tuple[int, ...], ...] = ()
    if chunks == "auto":
        for arr in arrays:
            if arr.chunks is None:
                # Unchunked array: one chunk spanning the entire shape.
                result += (arr.shape,)
            else:
                # `DataArray.chunks` is a per-dimension tuple of *block*
                # sizes (a tuple of tuples), not a flat chunk shape; use the
                # dask chunksize (largest block per dimension) to get a
                # tuple of ints, matching the historical behavior.
                result += (arr.data.chunksize,)
    elif all(isinstance(c, tuple) for c in chunks):
        result = chunks
    elif all(isinstance(c, int) for c in chunks):
        result = (chunks,) * len(arrays)
    else:
        msg = f"All values in chunks must be ints. Got {chunks}"
        raise ValueError(msg)

    # Validate with real exceptions rather than `assert`, which is stripped
    # when Python runs with -O.
    if len(result) != len(arrays):
        msg = f"Got {len(result)} chunk specs for {len(arrays)} arrays."
        raise ValueError(msg)
    if tuple(map(len, result)) != tuple(x.ndim for x in arrays):
        msg = "Number of chunks per array does not equal rank of arrays"
        raise ValueError(msg)
    return result
Loading

0 comments on commit 1878d4c

Please sign in to comment.