Commit 34f6d67

Merge pull request #57 from janelia-cosem/groundtruth_model

d-v-b authored Sep 15, 2023
2 parents 0b237e0 + d8c9108
Showing 3 changed files with 124 additions and 49 deletions.
78 changes: 59 additions & 19 deletions src/fibsem_tools/io/dask.py
@@ -19,12 +19,13 @@
 from dask.delayed import Delayed
 from dask.highlevelgraph import HighLevelGraph
 from dask.optimization import fuse
-from dask.utils import is_arraylike, parse_bytes
+from dask.utils import parse_bytes
 from zarr.util import normalize_chunks as normalize_chunksize
 from numpy.typing import NDArray, DTypeLike
 import random
 from fibsem_tools.io.core import access, read
 from fibsem_tools.io.zarr import are_chunks_aligned
+from dask import delayed

 random.seed(0)
@@ -65,33 +66,49 @@ def sequential_rechunk(


 @backoff.on_exception(backoff.expo, (ServerDisconnectedError, OSError))
-def store_chunk(x: NDArray[Any], out: Any, index: Tuple[slice, ...]) -> Literal[0]:
+def store_chunk(
+    target: NDArray[Any], key: Tuple[slice, ...], value: NDArray[Any]
+) -> Literal[0]:
     """
     A function inserted in a Dask graph for storing a chunk.

     Parameters
     ----------
-    x: array-like
-        An array (potentially a NumPy one)
-    out: array-like
-        Where to store results to.
-    index: slice-like
-        Where to store result from ``x`` in ``out``.
-
-    Examples
-    --------
-    >>> a = np.ones((5, 6))
-    >>> b = np.empty(a.shape)
-    >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
+    target: NDArray
+        Where to store the value.
+    key: Tuple[slice, ...]
+        The location in the array for the value.
+    value: NDArray
+        The value to be stored.
     """
-    if is_arraylike(x):
-        out[index] = x
-    else:
-        out[index] = np.asanyarray(x)
+    target[key] = value
     return 0
+
+
+@backoff.on_exception(backoff.expo, (ServerDisconnectedError, OSError))
+def store_value(
+    target: NDArray[Any], key: Tuple[slice, ...], value: NDArray[Any]
+) -> Tuple[slice, ...]:
+    """
+    A function inserted in a Dask graph for storing a chunk. Unlike
+    `store_chunk`, it returns the key, i.e. the region of `target` written to.
+
+    Parameters
+    ----------
+    target: NDArray
+        Where to store the value.
+    key: Tuple[slice, ...]
+        The location in the array for the value.
+    value: NDArray
+        The value to be stored.
+    """
+    target[key] = value
+    return key

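As a quick illustration (not part of this commit), the new calling convention looks like this; the arrays are made up for the example:

    import numpy as np

    target = np.zeros((4, 4))
    key = (slice(0, 2), slice(0, 2))
    value = np.ones((2, 2))

    store_chunk(target, key, value)  # writes value into target[0:2, 0:2], returns 0
    store_value(target, key, value)  # same write, but returns key for bookkeeping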
def ndwrapper(func: Callable[[Any], Any], ndim: int, *args: Any, **kwargs: Any):
Expand Down Expand Up @@ -124,14 +141,14 @@ def write_blocks(source, target, region: Optional[Tuple[slice, ...]]) -> da.Arra
dsk = {}
chunks = tuple((1,) * s for s in source.blocks.shape)

for slice, key in zip(slices, flatten(source.__dask_keys__())):
for slce, key in zip(slices, flatten(source.__dask_keys__())):
dsk[(store_name,) + key[1:]] = (
ndwrapper,
store_chunk,
source.ndim,
key,
target,
slice,
slce,
key,
)

layers[store_name] = dsk
@@ -181,6 +198,29 @@ def store_blocks(sources, targets, regions: Optional[slice] = None) -> List[da.Array]:
     return result


+def write_blocks_delayed(
+    source, target, region: Optional[Tuple[slice, ...]] = None
+) -> Sequence[Any]:
+    """
+    Return a collection of tasks; each task returns the result of writing
+    one chunk of `source` to `target`.
+    """
+
+    # handle xarray
+    if hasattr(source, "data") and isinstance(source.data, da.Array):
+        source = source.data
+
+    slices = slices_from_chunks(source.chunks)
+    if region:
+        slices = [fuse_slice(region, slc) for slc in slices]
+    blocks_flat = source.blocks.ravel()
+    assert len(slices) == len(blocks_flat)
+    return [
+        delayed(store_value)(target, slce, block)
+        for slce, block in zip(slices, blocks_flat)
+    ]
+
+
 def ensure_minimum_chunksize(array, chunksize):
     old_chunks = np.array(array.chunksize)
     new_chunks = old_chunks.copy()
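A minimal sketch of driving `write_blocks_delayed` (the source array and in-memory target are illustrative; the new test in tests/test_dask.py below exercises the zarr-backed version):

    import dask
    import dask.array as da
    import numpy as np

    source = da.arange(16, chunks=4)
    target = np.empty(16, dtype=source.dtype)

    tasks = write_blocks_delayed(source, target)
    written = dask.compute(*tasks)  # each task returns the slice it wrote via store_value
    assert np.array_equal(target, np.arange(16))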
80 changes: 50 additions & 30 deletions src/fibsem_tools/io/zarr.py
@@ -2,7 +2,6 @@
 import logging
 import os
 from os import PathLike
-import time
 from pathlib import Path
 from typing import Any, Dict, Generator, List, Literal, Sequence, Tuple, Union
 from datatree import DataTree
@@ -12,7 +11,7 @@
 import numpy as np
 import xarray
 import zarr
-from zarr.storage import FSStore, contains_array, contains_group
+from zarr.storage import FSStore
 from dask import bag, delayed
 from distributed import Client, Lock
 from toolz import concat
@@ -21,6 +20,7 @@
 from fibsem_tools.io.xr import stt_coord
 from fibsem_tools.metadata.transform import STTransform
 from xarray_ome_ngff.registry import get_adapters
+from zarr.errors import ReadOnlyError

 ureg = pint.UnitRegistry()

@@ -31,11 +31,53 @@
 # default axis order of raw n5 spatial metadata
 # is x,y,z
 N5_AXES_3D = ZARR_AXES_3D[::-1]
-DEFAULT_ZARR_STORE = FSStore
-DEFAULT_N5_STORE = zarr.N5FSStore
 logger = logging.getLogger(__name__)


+class FSStorePatched(FSStore):
+    """
+    Patch delitems to delete "blind", i.e. without checking whether the
+    to-be-deleted keys exist.
+    This is temporary and should be removed when
+    https://github.com/zarr-developers/zarr-python/issues/1336
+    is resolved.
+    """
+
+    def delitems(self, keys):
+        if self.mode == "r":
+            raise ReadOnlyError()
+        try:  # should be much faster
+            nkeys = [self._normalize_key(key) for key in keys]
+            # rm errors if you pass an empty collection
+            self.map.delitems(nkeys)
+        except FileNotFoundError:
+            nkeys = [self._normalize_key(key) for key in keys if key in self]
+            # rm errors if you pass an empty collection
+            if len(nkeys) > 0:
+                self.map.delitems(nkeys)
+
+
+class N5FSStorePatched(zarr.N5FSStore):
+    """
+    Patch delitems to delete "blind", i.e. without checking whether the
+    to-be-deleted keys exist.
+    This is temporary and should be removed when
+    https://github.com/zarr-developers/zarr-python/issues/1336
+    is resolved.
+    """
+
+    def delitems(self, keys):
+        if self.mode == "r":
+            raise ReadOnlyError()
+        try:  # should be much faster
+            nkeys = [self._normalize_key(key) for key in keys]
+            # rm errors if you pass an empty collection
+            self.map.delitems(nkeys)
+        except FileNotFoundError:
+            nkeys = [self._normalize_key(key) for key in keys if key in self]
+            # rm errors if you pass an empty collection
+            if len(nkeys) > 0:
+                self.map.delitems(nkeys)
+
+
 def get_arrays(obj: Any) -> Tuple[zarr.Array]:
     result = ()
     if isinstance(obj, zarr.core.Array):
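A minimal sketch of the behavior the patch is after, assuming a local store (the path is illustrative):

    import zarr

    store = FSStorePatched("/tmp/example.zarr", mode="w")  # illustrative path
    arr = zarr.open(store, path="a", mode="w", shape=(4,), chunks=(2,), dtype="uint8")
    arr[:] = 1
    # "blind" deletion: a mix of existing and missing chunk keys no longer raises
    store.delitems(["a/0", "a/no-such-chunk"])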
@@ -64,6 +106,10 @@ def delete_zbranch(branch: Union[zarr.Group, zarr.Array], compute: bool = True):
     )


+DEFAULT_ZARR_STORE = FSStorePatched
+DEFAULT_N5_STORE = N5FSStorePatched
+
+
 def delete_zgroup(zgroup: zarr.Group, compute: bool = True):
     """
     Delete all arrays in a zarr group
@@ -188,32 +234,6 @@ def access_zarr(
     attrs = kwargs.pop("attrs", {})
     access_mode = kwargs.pop("mode", "a")

-    if access_mode == "w":
-        if contains_group(store, path) or contains_array(store, path):
-            # zarr is extremely slow to delete existing directories, so we do it in
-            # parallel
-            existing = zarr.open(store, path=path, **kwargs, mode="a")
-            # todo: move this logic to methods on the stores themselves
-            if isinstance(
-                existing.store,
-                (
-                    zarr.N5Store,
-                    zarr.N5FSStore,
-                    zarr.DirectoryStore,
-                    zarr.NestedDirectoryStore,
-                ),
-            ):
-                url = os.path.join(existing.store.path, existing.path)
-                logger.info(f"Beginning parallel deletion of chunks in {url}...")
-                pre = time.time()
-                delete_zbranch(existing)
-                logger.info(
-                    f"""
-                    Completed parallel deletion of chunks in {url} in
-                    {time.time() - pre}s.
-                    """
-                )
-
     array_or_group = zarr.open(store, path=path, **kwargs, mode=access_mode)

     if access_mode != "r" and len(attrs) > 0:
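With that block removed, `access_zarr` now hands overwriting directly to `zarr.open`. A rough sketch of the resulting behavior, assuming an illustrative path:

    import zarr

    store = DEFAULT_ZARR_STORE("/tmp/example2.zarr")  # illustrative path
    # mode="w" truncates whatever exists at this path before creating the array
    zarr.open(store, path="volume", mode="w", shape=(8, 8), chunks=(4, 4), dtype="uint8")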
15 changes: 15 additions & 0 deletions tests/test_dask.py
@@ -3,11 +3,15 @@
     copy_array,
     ensure_minimum_chunksize,
     autoscale_chunk_shape,
+    write_blocks_delayed,
 )
 import pytest
 import zarr
 import numpy as np
 import dask
+from pydantic_zarr import ArraySpec
+from numpy.testing import assert_array_equal
+from dask.array.core import slices_from_chunks


 def test_ensure_minimum_chunksize():
@@ -85,3 +89,14 @@ def test_array_copy_from_path(temp_zarr, shape):
     copy_op = copy_array(arr_1, arr_2)
     dask.compute(copy_op)
     assert np.array_equal(arr_2, arr_1)
+
+
+def test_write_blocks_delayed():
+    arr = da.random.randint(0, 255, (10, 10, 10), dtype="uint8")
+    store = zarr.MemoryStore()
+    arr_spec = ArraySpec.from_array(arr, chunks=(2, 2, 2))
+    z_arr = arr_spec.to_zarr(store, path="array")
+    w_ops = write_blocks_delayed(arr, z_arr)
+    result = dask.compute(w_ops)[0]
+    assert result == slices_from_chunks(arr.chunks)
+    assert_array_equal(np.array(arr), z_arr)
