feat: add more performant versions of fsstores, and add a new way to save arrays that don't use the dask array formalism
d-v-b committed Sep 15, 2023
1 parent 899abc9 commit d8c9108
Showing 3 changed files with 123 additions and 21 deletions.
78 changes: 59 additions & 19 deletions src/fibsem_tools/io/dask.py
@@ -19,12 +19,13 @@
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import fuse
-from dask.utils import is_arraylike, parse_bytes
from dask.utils import parse_bytes
from zarr.util import normalize_chunks as normalize_chunksize
from numpy.typing import NDArray, DTypeLike
import random
from fibsem_tools.io.core import access, read
from fibsem_tools.io.zarr import are_chunks_aligned
from dask import delayed

random.seed(0)

@@ -65,33 +66,49 @@ def sequential_rechunk(


@backoff.on_exception(backoff.expo, (ServerDisconnectedError, OSError))
-def store_chunk(x: NDArray[Any], out: Any, index: Tuple[slice, ...]) -> Literal[0]:
-    """
-    A function inserted in a Dask graph for storing a chunk.
-
-    Parameters
-    ----------
-    x: array-like
-        An array (potentially a NumPy one)
-    out: array-like
-        Where to store results to.
-    index: slice-like
-        Where to store result from ``x`` in ``out``.
-
-    Examples
-    --------
-    >>> a = np.ones((5, 6))
-    >>> b = np.empty(a.shape)
-    >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
-    """
-    if is_arraylike(x):
-        out[index] = x
-    else:
-        out[index] = np.asanyarray(x)
-
-    return 0
def store_chunk(
    target: NDArray[Any], key: Tuple[slice, ...], value: NDArray[Any]
) -> Literal[0]:
    """
    A function inserted in a Dask graph for storing a chunk.

    Parameters
    ----------
    target: NDArray
        Where to store the value.
    key: Tuple[slice, ...]
        The location in the array for the value.
    value: NDArray
        The value to be stored.
    """
    target[key] = value
    return 0


@backoff.on_exception(backoff.expo, (ServerDisconnectedError, OSError))
def store_value(
    target: NDArray[Any], key: Tuple[slice, ...], value: NDArray[Any]
) -> Tuple[slice, ...]:
    """
    A function inserted in a Dask graph for storing a chunk and returning the
    location (key) that was written.

    Parameters
    ----------
    target: NDArray
        Where to store the value.
    key: Tuple[slice, ...]
        The location in the array for the value.
    value: NDArray
        The value to be stored.
    """
    target[key] = value
    return key
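
For illustration (an editorial sketch, not part of the commit): calling these helpers directly on a NumPy target shows the difference between them: `store_chunk` returns 0, while `store_value` returns the key it wrote, which is what `write_blocks_delayed` below relies on.

import numpy as np

target = np.zeros((4, 4))
value = np.ones((2, 2))
key = (slice(0, 2), slice(0, 2))

store_chunk(target, key, value)               # writes the block, returns 0
written_at = store_value(target, key, value)  # writes the block, returns `key`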


def ndwrapper(func: Callable[[Any], Any], ndim: int, *args: Any, **kwargs: Any):
@@ -124,14 +141,14 @@ def write_blocks(source, target, region: Optional[Tuple[slice, ...]]) -> da.Array:
    dsk = {}
    chunks = tuple((1,) * s for s in source.blocks.shape)

-    for slice, key in zip(slices, flatten(source.__dask_keys__())):
    for slce, key in zip(slices, flatten(source.__dask_keys__())):
        dsk[(store_name,) + key[1:]] = (
            ndwrapper,
            store_chunk,
            source.ndim,
-            key,
            target,
-            slice,
            slce,
            key,
        )

    layers[store_name] = dsk
@@ -181,6 +198,29 @@ def store_blocks(sources, targets, regions: Optional[slice] = None) -> List[da.Array]:
    return result


def write_blocks_delayed(
    source, target, region: Optional[Tuple[slice, ...]] = None
) -> Sequence[Any]:
    """
    Return a collection of delayed tasks; each task returns the result of
    writing one chunk of `source` to `target`.
    """

    # handle xarray
    if hasattr(source, "data") and isinstance(source.data, da.Array):
        source = source.data

    slices = slices_from_chunks(source.chunks)
    if region:
        slices = [fuse_slice(region, slc) for slc in slices]
    blocks_flat = source.blocks.ravel()
    assert len(slices) == len(blocks_flat)
    return [
        delayed(store_value)(target, slce, block)
        for slce, block in zip(slices, blocks_flat)
    ]
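
A usage sketch (editorial, not from the diff; assumes `dask`, `dask.array`, and `numpy` are importable): the returned tasks can be computed with any Dask scheduler, and each computed task yields the slice it wrote.

import dask
import dask.array as da
import numpy as np

source = da.ones((4, 4), chunks=(2, 2))
target = np.zeros((4, 4))  # any object supporting slice-based __setitem__ works
tasks = write_blocks_delayed(source, target)
keys = dask.compute(tasks)[0]  # four tasks, one per 2x2 chunk; each returns its slice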


def ensure_minimum_chunksize(array, chunksize):
    old_chunks = np.array(array.chunksize)
    new_chunks = old_chunks.copy()
51 changes: 49 additions & 2 deletions src/fibsem_tools/io/zarr.py
@@ -20,6 +20,7 @@
from fibsem_tools.io.xr import stt_coord
from fibsem_tools.metadata.transform import STTransform
from xarray_ome_ngff.registry import get_adapters
from zarr.errors import ReadOnlyError

ureg = pint.UnitRegistry()

@@ -30,11 +31,53 @@
# default axis order of raw n5 spatial metadata
# is x,y,z
N5_AXES_3D = ZARR_AXES_3D[::-1]
-DEFAULT_ZARR_STORE = FSStore
-DEFAULT_N5_STORE = zarr.N5FSStore
logger = logging.getLogger(__name__)


class FSStorePatched(FSStore):
    """
    Patch delitems to delete "blindly", i.e. without checking whether the
    to-be-deleted keys exist. This is temporary and should be removed when
    https://github.com/zarr-developers/zarr-python/issues/1336
    is resolved.
    """

    def delitems(self, keys):
        if self.mode == "r":
            raise ReadOnlyError()
        try:  # should be much faster
            nkeys = [self._normalize_key(key) for key in keys]
            # rm errors if you pass an empty collection
            self.map.delitems(nkeys)
        except FileNotFoundError:
            nkeys = [self._normalize_key(key) for key in keys if key in self]
            # rm errors if you pass an empty collection
            if len(nkeys) > 0:
                self.map.delitems(nkeys)


class N5FSStorePatched(zarr.N5FSStore):
    """
    Patch delitems to delete "blindly", i.e. without checking whether the
    to-be-deleted keys exist. This is temporary and should be removed when
    https://github.com/zarr-developers/zarr-python/issues/1336
    is resolved.
    """

    def delitems(self, keys):
        if self.mode == "r":
            raise ReadOnlyError()
        try:  # should be much faster
            nkeys = [self._normalize_key(key) for key in keys]
            # rm errors if you pass an empty collection
            self.map.delitems(nkeys)
        except FileNotFoundError:
            nkeys = [self._normalize_key(key) for key in keys if key in self]
            # rm errors if you pass an empty collection
            if len(nkeys) > 0:
                self.map.delitems(nkeys)
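
A quick editorial sketch of the behavior the patch provides (assuming zarr-python v2 and a writable local path; the key names are purely illustrative): deleting a mix of existing and missing keys works without an up-front key-existence check.

store = FSStorePatched("/tmp/example.zarr", mode="w")
store["foo"] = b"bar"             # write one raw key
store.delitems(["foo", "baz"])    # "baz" never existed; missing keys are tolerated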


def get_arrays(obj: Any) -> Tuple[zarr.Array]:
    result = ()
    if isinstance(obj, zarr.core.Array):
@@ -63,6 +106,10 @@ def delete_zbranch(branch: Union[zarr.Group, zarr.Array], compute: bool = True):
)


DEFAULT_ZARR_STORE = FSStorePatched
DEFAULT_N5_STORE = N5FSStorePatched


def delete_zgroup(zgroup: zarr.Group, compute: bool = True):
"""
Delete all arrays in a zarr group
15 changes: 15 additions & 0 deletions tests/test_dask.py
@@ -3,11 +3,15 @@
    copy_array,
    ensure_minimum_chunksize,
    autoscale_chunk_shape,
    write_blocks_delayed,
)
import pytest
import zarr
import numpy as np
import dask
from pydantic_zarr import ArraySpec
from numpy.testing import assert_array_equal
from dask.array.core import slices_from_chunks


def test_ensure_minimum_chunksize():
@@ -85,3 +89,14 @@ def test_array_copy_from_path(temp_zarr, shape):
    copy_op = copy_array(arr_1, arr_2)
    dask.compute(copy_op)
    assert np.array_equal(arr_2, arr_1)


def test_write_blocks_delayed():
    arr = da.random.randint(0, 255, (10, 10, 10), dtype="uint8")
    store = zarr.MemoryStore()
    arr_spec = ArraySpec.from_array(arr, chunks=(2, 2, 2))
    z_arr = arr_spec.to_zarr(store, path="array")
    w_ops = write_blocks_delayed(arr, z_arr)
    result = dask.compute(w_ops)[0]
    assert result == slices_from_chunks(arr.chunks)
    assert_array_equal(np.array(arr), z_arr)
