Skip to content

Commit

Permalink
feat: add zarr scanning cli tool to check for invalid chunks and potentially delete them
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Aug 22, 2023
1 parent 0acda4b commit aa529d6
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 1 deletion.
106 changes: 106 additions & 0 deletions src/fibsem_tools/cli/zarr_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Literal, Union
import click
import zarr
from fibsem_tools import access
from rich import print
from fibsem_tools.io.zarr import get_chunk_keys
from rich.progress import track
import time
from dataclasses import dataclass

# The three states a scanned chunk can be in. These strings match the
# ``variant`` attribute of the Missing / Invalid / Valid result classes and the
# keys used to group results in the CLI.
ChunkState = Literal["valid", "missing", "invalid"]


@dataclass
class Missing:
    """Scan result for a chunk whose key was not found in the store."""

    # Unannotated, so this is a class-level tag shared by all instances rather
    # than a dataclass field.
    variant = "missing"


@dataclass
class Invalid:
    """Scan result for a chunk that could not be read or decoded.

    The exception raised while checking the chunk is kept in ``exception``.
    """

    # Unannotated, so this is a class-level tag shared by all instances rather
    # than a dataclass field; ``exception`` is the only field.
    variant = "invalid"
    exception: BaseException


@dataclass
class Valid:
    """Scan result for a chunk that was read and decoded without error."""

    # Unannotated, so this is a class-level tag shared by all instances rather
    # than a dataclass field.
    variant = "valid"


class ChunkSetResults(dict[ChunkState, dict[str, Union[Missing, Valid, Invalid]]]):
    """Scan results grouped by state.

    Maps each chunk state to a ``{chunk key: result}`` dict holding the chunks
    found in that state.
    """


def check_zarray(array: zarr.Array) -> dict[str, Union[Missing, Invalid, Valid]]:
    """Classify every chunk of ``array`` as valid, missing, or invalid.

    Each chunk key produced by ``get_chunk_keys`` is looked up in
    ``array.store`` and the stored bytes are passed to the array's
    ``_decode_chunk`` method.

    Parameters
    ----------
    array:
        The zarr array whose chunks are checked.

    Returns
    -------
    dict
        Maps each chunk key to ``Valid()`` when the chunk decoded, ``Missing()``
        when a ``KeyError`` was raised, and ``Invalid(exception=...)`` when any
        other exception was raised while reading or decoding it.
    """
    # Materialize the keys so the progress bar knows the total up front.
    ckeys = tuple(get_chunk_keys(array))
    results = {}
    for ckey in track(ckeys, description="Checking chunks..."):
        try:
            array._decode_chunk(array.store[ckey])
        except KeyError:
            results[ckey] = Missing()
        except Exception as e:
            # Codecs do not only raise OSError on corrupt data (e.g. a failed
            # decompression may raise RuntimeError), and one bad chunk must not
            # abort the scan, so every other failure is recorded as invalid.
            results[ckey] = Invalid(exception=e)
        else:
            results[ckey] = Valid()

    return results


@click.command()
@click.argument("array_path", type=click.STRING)
@click.option(
    "--valid",
    is_flag=True,
    show_default=True,
    default=False,
    help="report valid chunks",
)
@click.option(
    "--missing",
    is_flag=True,
    show_default=True,
    default=False,
    help="report missing chunks",
)
@click.option(
    "--invalid",
    is_flag=True,
    show_default=True,
    default=False,
    help="report invalid chunks",
)
@click.option(
    "--delete-invalid",
    is_flag=True,
    show_default=True,
    default=False,
    help="delete invalid chunks",
)
def cli(array_path, valid, missing, invalid, delete_invalid):
    """Scan the chunks of the array at ARRAY_PATH and report their state.

    Every chunk is classified as valid, missing, or invalid. The --valid,
    --missing and --invalid flags select which categories are printed, and
    --delete-invalid removes the invalid chunks from the store.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments during a long scan.
    start = time.perf_counter()
    array = access(array_path, mode="r")
    all_results = check_zarray(array)

    # Group the per-chunk results by their state.
    results_categorized: ChunkSetResults = {"valid": {}, "missing": {}, "invalid": {}}
    for key, value in all_results.items():
        results_categorized[value.variant][key] = value

    # Only the categories the user asked for are reported.
    requested = {"valid": valid, "missing": missing, "invalid": invalid}
    to_show = {
        name: results_categorized[name]
        for name, wanted in requested.items()
        if wanted
    }
    print(to_show)

    invalid_chunks = results_categorized["invalid"]
    # Reopen for writing only when there is something to delete, so a clean
    # scan never needs write access to the store.
    if delete_invalid and invalid_chunks:
        array_a = access(array_path, mode="a")
        num_invalid = len(invalid_chunks)
        for res in track(
            invalid_chunks,
            description=f"Deleting {num_invalid} invalid chunks...",
        ):
            del array_a.store[res]
    print(f"Completed after {time.perf_counter() - start}s")


if __name__ == "__main__":
    # Run the click command when this module is executed as a script.
    cli()
2 changes: 1 addition & 1 deletion src/fibsem_tools/io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def access_zarr(

array_or_group = zarr.open(store, path=path, **kwargs, mode=access_mode)

if access_mode != "r":
if access_mode != "r" and len(attrs) > 0:
array_or_group.attrs.update(attrs)
return array_or_group

Expand Down
23 changes: 23 additions & 0 deletions tests/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import pytest
from xarray import DataArray
from zarr.storage import FSStore
from pathlib import Path
import zarr
import numpy as np
import itertools
from fibsem_tools.io.core import read_dask, read_xarray
from fibsem_tools.io.multiscale import multiscale_group
from fibsem_tools.io.xr import stt_from_array
Expand All @@ -15,6 +17,7 @@
access_zarr,
create_dataarray,
create_datatree,
get_chunk_keys,
get_url,
to_dask,
to_xarray,
Expand Down Expand Up @@ -247,3 +250,23 @@ def test_dask(temp_zarr, chunks):
assert np.array_equal(observed, data)

assert np.array_equal(read_dask(get_url(zarray), chunks).compute(), data)


@pytest.mark.parametrize(
    "store_class", (zarr.N5Store, zarr.DirectoryStore, zarr.NestedDirectoryStore)
)
@pytest.mark.parametrize("shape", ((10,), (10, 11, 12)))
def test_chunk_keys(tmp_path: Path, store_class, shape):
    """Check ``get_chunk_keys`` against keys built by hand from the chunk grid.

    An array with chunk size 2 along every axis is created in each store type,
    and the expected keys are generated by iterating over all chunk-grid
    indices with ``itertools.product``.
    """
    store: zarr.storage.BaseStore = store_class(tmp_path)
    array_name = "test"
    arr = zarr.create(
        shape=shape,
        store=store,
        path=array_name,
        chunks=(2,) * len(shape),
        dtype="uint8",
    )

    separator = arr._dimension_separator
    grid_indices = itertools.product(*map(range, arr.cdata_shape))
    # NOTE(review): os.path.join uses the OS path separator; presumably store
    # keys are "/"-separated on every platform — confirm this expectation
    # holds on Windows.
    expected = tuple(
        os.path.join(arr.path, separator.join(str(i) for i in index))
        for index in grid_indices
    )
    assert tuple(get_chunk_keys(arr)) == expected

0 comments on commit aa529d6

Please sign in to comment.