diff --git a/src/fibsem_tools/io/multiscale.py b/src/fibsem_tools/io/multiscale.py index ee7e750..eb90712 100644 --- a/src/fibsem_tools/io/multiscale.py +++ b/src/fibsem_tools/io/multiscale.py @@ -4,6 +4,7 @@ from xarray import DataArray import zarr +from fibsem_tools.io.util import normalize_chunks from fibsem_tools.metadata.cosem import ( CosemMultiscaleGroupV1, CosemMultiscaleGroupV2, @@ -20,36 +21,11 @@ multiscale_metadata_types = ["neuroglancer", "cellmap", "cosem", "ome-ngff"] -def _normalize_chunks( - arrays: Sequence[DataArray], - chunks: Union[Tuple[Tuple[int, ...], ...], Tuple[int, ...], None], -) -> Tuple[Tuple[int, ...], ...]: - if chunks is None: - result: Tuple[Tuple[int, ...]] = tuple(v.data.chunksize for v in arrays) - elif all(isinstance(c, tuple) for c in chunks): - result = chunks - else: - try: - all_ints = all((isinstance(c, int) for c in chunks)) - if all_ints: - result = (chunks,) * len(arrays) - else: - msg = f"All values in chunks must be ints. Got {chunks}" - raise ValueError(msg) - except TypeError as e: - raise e - - assert len(result) == len(arrays) - assert tuple(map(len, result)) == tuple( - x.ndim for x in arrays - ), "Number of chunks per array does not equal rank of arrays" - return result - - def multiscale_group( arrays: Sequence[DataArray], metadata_types: List[str], array_paths: Union[List[str], Literal["auto"]] = "auto", + chunks: Union[Tuple[Tuple[int, ...], ...], Literal["auto"]] = "auto", name: Optional[str] = None, **kwargs, ) -> GroupSpec: @@ -65,6 +41,13 @@ def multiscale_group( The metadata flavor(s) to use. array_paths : Sequence[str] The path for each array in storage, relative to the parent group. + chunks : Union[Tuple[Tuple[int, ...], ...], Literal["auto"]], default is "auto" + The chunks for the arrays instances. Either an explicit collection of + chunk sizes, one per array, or the string "auto". 
If `chunks` is "auto" and + the `data` attribute of the arrays is chunked, then each stored array + will inherit the chunks of the input arrays. If the `data` attribute + is not chunked, then each stored array will have chunks equal to the shape of + the input array. name : Optional[str] The name for the multiscale group. Only relevant for metadata flavors that support this field, e.g. ome-ngff @@ -77,6 +60,8 @@ def multiscale_group( """ if array_paths == "auto": array_paths = [f"s{idx}" for idx in range(len(arrays))] + _chunks = normalize_chunks(arrays, chunks) + group_attrs = {} array_attrs = {path: {} for path in array_paths} @@ -93,20 +78,19 @@ flave, _, version = flavor.partition("@") if flave == "neuroglancer": - g_spec = NeuroglancerN5Group.from_xarrays(arrays, **kwargs) + g_spec = NeuroglancerN5Group.from_xarrays(arrays, chunks=_chunks, **kwargs) group_attrs.update(g_spec.attrs.dict()) elif flave == "cosem": if version == "2": g_spec = CosemMultiscaleGroupV2.from_xarrays( - arrays, name=name, **kwargs + arrays, name=name, chunks=_chunks, **kwargs ) else: g_spec = CosemMultiscaleGroupV1.from_xarrays( - arrays, name=name, **kwargs + arrays, name=name, chunks=_chunks, **kwargs ) group_attrs.update(g_spec.attrs.dict()) - - for key, value in g_spec.items.items(): + for key, value in g_spec.members.items(): array_attrs[key].update(**value.attrs.dict()) elif flave == "ome-ngff": if version == "": @@ -118,17 +102,16 @@ ).dict() ] else: - raise ValueError( - f""" - Multiscale metadata type {flavor} is unknown. Try one of - {multiscale_metadata_types} - """ + msg = ( + f"Multiscale metadata type {flavor} is unknown. " 
+ f"Try one of {multiscale_metadata_types}" ) + raise ValueError(msg) + members = { - path: ArraySpec.from_array(arr, attrs=array_attrs[path], **kwargs) - for arr, path in zip(arrays, array_paths) + path: ArraySpec.from_array(arr, attrs=array_attrs[path], chunks=cnks, **kwargs) + for arr, path, cnks in zip(arrays, array_paths, _chunks) } - return GroupSpec(attrs=group_attrs, members=members) diff --git a/src/fibsem_tools/io/util.py b/src/fibsem_tools/io/util.py index 774ab44..7d014d4 100644 --- a/src/fibsem_tools/io/util.py +++ b/src/fibsem_tools/io/util.py @@ -10,6 +10,8 @@ from dask import bag, delayed from typing import Protocol, runtime_checkable +from xarray import DataArray + JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] Attrs = Dict[str, JSON] PathLike = Union[Path, str] @@ -143,3 +145,51 @@ def split_by_suffix(uri: PathLike, suffixes: Sequence[str]) -> Tuple[str, str, s if protocol: pre = f"{protocol}://{pre}" return pre, post, suffix + + +def normalize_chunks( + arrays: Sequence[DataArray], + chunks: Union[Tuple[Tuple[int, ...], ...], Tuple[int, ...], Literal["auto"]], +) -> Tuple[Tuple[int, ...], ...]: + """ + Normalize a chunk specification, given a list of arrays. + + Parameters + ---------- + + arrays: Sequence[DataArray] + The list of arrays to define chunks for. + chunks: Union[Tuple[Tuple[int, ...], ...], Tuple[int, ...], Literal["auto"]] + The specification of chunks. This parameter is either a tuple of tuple of ints, + in which case it is already normalized and it passes right through, or it is + a tuple of ints, which will be "broadcast" to the length of `arrays`, or it is + the string "auto", in which case the existing chunks on the arrays with be used + if they are chunked, and otherwise chunks will be set to the shape of each + array. + + Returns + ------- + Tuple[Tuple[int, ...], ...] 
+ """ + result: Tuple[Tuple[int, ...]] = () + if chunks == "auto": + for arr in arrays: + if arr.chunks is None: + result += (arr.shape,) + else: + result += (arr.chunks,) + elif all(isinstance(c, tuple) for c in chunks): + result = chunks + else: + all_ints = all((isinstance(c, int) for c in chunks)) + if all_ints: + result = (chunks,) * len(arrays) + else: + msg = f"All values in chunks must be ints. Got {chunks}" + raise ValueError(msg) + + assert len(result) == len(arrays) + assert tuple(map(len, result)) == tuple( + x.ndim for x in arrays + ), "Number of chunks per array does not equal rank of arrays" + return result diff --git a/src/fibsem_tools/metadata/cosem.py b/src/fibsem_tools/metadata/cosem.py index 0612488..961905b 100644 --- a/src/fibsem_tools/metadata/cosem.py +++ b/src/fibsem_tools/metadata/cosem.py @@ -1,11 +1,22 @@ -from typing import Iterable, Literal, Optional, Sequence, Union +from typing import Iterable, Literal, Optional, Sequence, Tuple, Union from pydantic import BaseModel from xarray import DataArray from pydantic_zarr import GroupSpec, ArraySpec +from fibsem_tools.io.util import normalize_chunks from fibsem_tools.metadata.transform import STTransform +def normalize_paths( + arrays: Sequence[DataArray], paths: Union[Sequence[str], Literal["auto"]] +): + if paths == "auto": + _paths = [f"s{idx}" for idx in range(len(arrays))] + else: + _paths = paths + return _paths + + class ScaleMetaV1(BaseModel): path: str transform: STTransform @@ -21,7 +32,7 @@ class MultiscaleMetaV2(BaseModel): datasets: list[str] -class COSEMGroupMetadataV1(BaseModel): +class CosemGroupMetadataV1(BaseModel): """ Multiscale metadata used by COSEM for multiscale datasets saved in N5/Zarr groups. """ @@ -36,47 +47,42 @@ def from_xarrays( name: Optional[str] = None, ): """ - Generate multiscale metadata from a list or tuple of DataArrays. + Generate multiscale metadata from a sequence of DataArrays. 
Parameters ---------- - arrays : list or tuple of xarray.DataArray + arrays : Sequence[xarray.DataArray] The collection of arrays from which to generate multiscale metadata. These arrays are assumed to share the same `dims` attributes, albeit with varying `coords`. - - paths : Sequence of str or the string literal 'auto', default='auto' + paths : Union[Sequence[str], Literal["auto"]] The name on the storage backend for each of the arrays in the multiscale collection. If 'auto', then names will be automatically generated using the format s0, s1, s2, etc - - name : str, optional + name : Optional[str] The name for the multiresolution collection - - Returns an instance of COSEMGroupMetadataV1 + Returns ------- - - COSEMGroupMetadata + CosemGroupMetadataV1 """ - if paths == "auto": - paths = [f"s{idx}" for idx in range(len(arrays))] + _paths = normalize_paths(arrays, paths) multiscales = [ MultiscaleMetaV1( name=name, datasets=[ ScaleMetaV1(path=path, transform=STTransform.from_xarray(array=arr)) - for path, arr in zip(paths, arrays) + for path, arr in zip(_paths, arrays) ], ) ] - return cls(name=name, multiscales=multiscales, paths=paths) + return cls(name=name, multiscales=multiscales, paths=_paths) -class COSEMGroupMetadataV2(BaseModel): +class CosemGroupMetadataV2(BaseModel): """ Multiscale metadata used by COSEM for multiscale datasets saved in N5/Zarr groups. """ @@ -91,38 +97,37 @@ def from_xarrays( name: Optional[str] = None, ): """ - Generate multiscale metadata from a list or tuple of DataArrays. + Generate multiscale metadata from a sequence of DataArrays. Parameters ---------- - arrays : list or tuple of xarray.DataArray + arrays : Sequence[xarray.DataArray] The collection of arrays from which to generate multiscale metadata. These arrays are assumed to share the same `dims` attributes, albeit with varying `coords`. - - paths : list or tuple of str - The name on the storage backend for each of the arrays in the multiscale - collection. 
- - name : str, optional - The name for the multiresolution collection - - Returns an instance of COSEMGroupMetadataV2 + paths: Union[Sequence[str], Literal["auto"]] = "auto" + The names for each of the arrays in the multiscale + collection. If set to "auto", arrays will be named automatically according + to the scheme `s0` for the largest array, s1 for second largest, and so on. + name : Optional[str], default is None. + The name for the multiresolution collection. + + Returns ------- + COSEMGroupMetadataV2 - COSEMGroupMetadata """ - if paths == "auto": - paths = [f"s{idx}" for idx in enumerate(arrays)] + + _paths = normalize_paths(arrays, paths) multiscales = [ MultiscaleMetaV2( name=name, - datasets=paths, + datasets=_paths, ) ] - return cls(name=name, multiscales=multiscales, paths=paths) + return cls(name=name, multiscales=multiscales, paths=_paths) class CosemArrayAttrs(BaseModel): @@ -135,56 +140,105 @@ class CosemMultiscaleArray(ArraySpec): @classmethod def from_xarray(cls, array: DataArray, **kwargs): attrs = CosemArrayAttrs(transform=STTransform.from_xarray(array)) - return super().from_array(array, attrs=attrs, **kwargs) + return cls.from_array(array, attrs=attrs, **kwargs) class CosemMultiscaleGroupV1(GroupSpec): - attrs: COSEMGroupMetadataV1 - items: dict[str, CosemMultiscaleArray] + attrs: CosemGroupMetadataV1 + members: dict[str, CosemMultiscaleArray] @classmethod def from_xarrays( cls, arrays: Iterable[DataArray], + chunks: Union[Tuple[Tuple[int, ...], ...], Literal["auto"]] = "auto", paths: Union[Sequence[str], Literal["auto"]] = "auto", name: Optional[str] = None, **kwargs, ): + """ + Convert a collection of DataArray to a GroupSpec with CosemMultiscaleV1 metadata + + Parameters + ---------- - if paths == "auto": - paths = [f"s{idx}" for idx in range(len(arrays))] + arrays: Iterable[DataArray] + The arrays comprising the multiscale image. 
+ chunks : Union[Tuple[Tuple[int, ...], ...], Literal["auto"]], default is "auto" + The chunks for the `ArraySpec` instances. Either an explicit collection of + chunk sizes, one per array, or the string "auto". If `chunks` is "auto" and + the `data` attribute of the arrays is chunked, then each ArraySpec + instance will inherit the chunks of the arrays. If the `data` attribute + is not chunked, then each `ArraySpec` will have chunks equal to the shape of + the source array. + paths: Union[Sequence[str], Literal["auto"]] = "auto" + The names for each of the arrays in the multiscale + collection. If set to "auto", arrays will be named automatically according + to the scheme `s0` for the largest array, s1 for second largest, and so on. + name: Optional[str], default is None + The name for the multiscale collection. + **kwargs: + Additional keyword arguments that will be passed to the `ArraySpec` + constructor. + """ + _paths = normalize_paths(arrays, paths) - attrs = COSEMGroupMetadataV1.from_xarrays(arrays, paths, name) + _chunks = normalize_chunks(arrays, chunks) + attrs = CosemGroupMetadataV1.from_xarrays(arrays, _paths, name) array_specs = { - k: CosemMultiscaleArray.from_xarray(arr, **kwargs) - for k, arr in zip(paths, arrays) + key: CosemMultiscaleArray.from_xarray(arr, chunks=cnks, **kwargs) + for arr, cnks, key in zip(arrays, _chunks, _paths) } - return cls(attrs=attrs, items=array_specs) + return cls(attrs=attrs, members=array_specs) class CosemMultiscaleGroupV2(GroupSpec): - attrs: COSEMGroupMetadataV2 - items: dict[str, ArraySpec[CosemArrayAttrs]] + attrs: CosemGroupMetadataV2 + members: dict[str, ArraySpec[CosemArrayAttrs]] @classmethod def from_xarrays( cls, arrays: Iterable[DataArray], + chunks: Union[Tuple[Tuple[int, ...]], Literal["auto"]] = "auto", paths: Union[Sequence[str], Literal["auto"]] = "auto", name: Optional[str] = None, **kwargs, ): + """ + Convert a collection of DataArray to a GroupSpec with CosemMultiscaleV2 metadata + + Parameters + 
---------- - if paths == "auto": - paths = [f"s{idx}" for idx in range(len(arrays))] + arrays: Iterable[DataArray] + The arrays comprising the multiscale image. + chunks : Union[Tuple[Tuple[int, ...], ...], Literal["auto"]], default is "auto" + The chunks for the `ArraySpec` instances. Either an explicit collection of + chunk sizes, one per array, or the string "auto". If `chunks` is "auto" and + the `data` attribute of the arrays is chunked, then each ArraySpec + instance will inherit the chunks of the arrays. If the `data` attribute + is not chunked, then each `ArraySpec` will have chunks equal to the shape of + the source array. + paths: Union[Sequence[str], Literal["auto"]] = "auto" + The names for each of the arrays in the multiscale + collection. + name: Optional[str], default is None + The name for the multiscale collection. + **kwargs: + Additional keyword arguments that will be passed to the `ArraySpec` + constructor. + """ - attrs = COSEMGroupMetadataV2.from_xarrays(arrays, paths, name) + _paths = normalize_paths(arrays, paths) + _chunks = normalize_chunks(arrays, chunks) + attrs = CosemGroupMetadataV2.from_xarrays(arrays, _paths, name) array_specs = { - k: CosemMultiscaleArray.from_xarray(arr, **kwargs) - for k, arr in zip(paths, arrays) + key: CosemMultiscaleArray.from_xarray(arr, chunks=cnks, **kwargs) + for arr, cnks, key in zip(arrays, _chunks, _paths) } - return cls(attrs=attrs, items=array_specs) + return cls(attrs=attrs, members=array_specs) diff --git a/src/fibsem_tools/metadata/neuroglancer.py b/src/fibsem_tools/metadata/neuroglancer.py index 4d53610..e50f6a5 100644 --- a/src/fibsem_tools/metadata/neuroglancer.py +++ b/src/fibsem_tools/metadata/neuroglancer.py @@ -1,9 +1,10 @@ -from typing import Iterable, List, Sequence +from typing import Iterable, List, Literal, Sequence, Union import numpy as np from pydantic import BaseModel, PositiveInt, ValidationError, validator from xarray import DataArray from pydantic_zarr.core import GroupSpec, 
ArraySpec +from fibsem_tools.io.util import normalize_chunks from fibsem_tools.metadata.transform import STTransform @@ -126,11 +127,17 @@ def validate_members(cls, v: dict[str, ArraySpec]): @classmethod def from_xarrays( - cls, arrays: Iterable[DataArray], chunks: tuple[int, ...], **kwargs + cls, + arrays: Iterable[DataArray], + chunks: Union[tuple[tuple[int, ...]], Literal["auto"]], + **kwargs, ) -> "NeuroglancerN5Group": + + _chunks = normalize_chunks(arrays, chunks) + array_specs = { - f"s{idx}": ArraySpec.from_array(arr, chunks=chunks, **kwargs) - for idx, arr in enumerate(arrays) + f"s{idx}": ArraySpec.from_array(arr, chunks=cnks, **kwargs) + for idx, arr, cnks in zip(range(len(arrays)), arrays, _chunks) } attrs = NeuroglancerN5GroupMetadata.from_xarrays(arrays) return cls(attrs=attrs, members=array_specs) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index d3787e2..e3f4def 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -3,11 +3,11 @@ from xarray import DataArray from fibsem_tools.io.xr import stt_from_array from fibsem_tools.metadata.cosem import ( - COSEMGroupMetadataV1, + CosemGroupMetadataV1, CosemMultiscaleGroupV1, CosemMultiscaleGroupV2, MultiscaleMetaV1, - COSEMGroupMetadataV2, + CosemGroupMetadataV2, MultiscaleMetaV2, ) from fibsem_tools.metadata.neuroglancer import ( @@ -99,9 +99,9 @@ def test_cosem(version: Literal["v1", "v2"]): paths = ("s0", "s1") if version == "v1": - g_meta = COSEMGroupMetadataV1.from_xarrays(multi, paths=paths, name="data") + g_meta = CosemGroupMetadataV1.from_xarrays(multi, paths=paths, name="data") - assert g_meta == COSEMGroupMetadataV1( + assert g_meta == CosemGroupMetadataV1( multiscales=[ MultiscaleMetaV1( name="data", @@ -114,13 +114,13 @@ def test_cosem(version: Literal["v1", "v2"]): ) spec = CosemMultiscaleGroupV1.from_xarrays(multi, name="data") assert spec.attrs == g_meta - assert tuple(spec.items.keys()) == paths + assert tuple(spec.members.keys()) == paths else: - g_meta = 
COSEMGroupMetadataV2.from_xarrays(multi, paths=paths, name="data") - assert g_meta == COSEMGroupMetadataV2( + g_meta = CosemGroupMetadataV2.from_xarrays(multi, paths=paths, name="data") + assert g_meta == CosemGroupMetadataV2( multiscales=[MultiscaleMetaV2(name="data", datasets=paths)] ) spec = CosemMultiscaleGroupV2.from_xarrays(multi, name="data") assert spec.attrs == g_meta - assert tuple(spec.items.keys()) == paths + assert tuple(spec.members.keys()) == paths diff --git a/tests/test_multiscale.py b/tests/test_multiscale.py index badcaa3..1fe6beb 100644 --- a/tests/test_multiscale.py +++ b/tests/test_multiscale.py @@ -51,6 +51,5 @@ def test_multiscale_storage(temp_zarr, metadata_types: Tuple[str, ...]): array_urls = [f"{temp_zarr}/{ap}" for ap in array_paths] da.compute(store_blocks(multi, [access(a_url, mode="a") for a_url in array_urls])) - assert dict(group.attrs) == g_spec.attrs assert all(read(a).chunks == chunks for a in array_urls) diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..3bb4730 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,18 @@ +import pytest +import dask.array as da +from xarray import DataArray +from fibsem_tools.io.util import normalize_chunks + + +@pytest.mark.parametrize("chunks", ("auto", (3, 3, 3), ((3, 3, 3), (3, 3, 3)))) +def test_normalize_chunks(chunks): + arrays = DataArray(da.zeros((10, 10, 10), chunks=(4, 4, 4))), DataArray( + da.zeros((5, 5, 5), chunks=(2, 2, 2)) + ) + observed = normalize_chunks(arrays, chunks) + if chunks == "auto": + assert observed == (arrays[0].chunks, arrays[1].chunks) + elif isinstance(chunks[0], int): + assert observed == (chunks,) * len(arrays) + else: + assert observed == chunks diff --git a/tests/test_zarr.py b/tests/test_zarr.py index 790c3c0..05ee304 100644 --- a/tests/test_zarr.py +++ b/tests/test_zarr.py @@ -105,7 +105,7 @@ def test_read_datatree(temp_zarr, attrs, coords, use_dask, name): g_spec = multiscale_group( arrays=tuple(data.values()), 
array_paths=tuple(data.keys()), - chunks=(64, 64, 64), + chunks=((64, 64, 64),) * len(data), metadata_types=["cosem"], ) g_spec.attrs.update(**_attrs)