diff --git a/.github/workflows/upload_pypi.yml b/.github/workflows/upload_pypi.yml
new file mode 100644
index 0000000..1780ce9
--- /dev/null
+++ b/.github/workflows/upload_pypi.yml
@@ -0,0 +1,24 @@
+name: Upload Python Package
+
+on:
+  release:
+    types: [created]
+
+jobs:
+  deploy:
+    runs-on: ubuntu-20.04
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install poetry==1.4.1
+      - name: Build and publish
+        env:
+          PYPI_USERNAME: __token__
+          PYPI_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
+        run: |
+          poetry build
+          poetry publish --username "$PYPI_USERNAME" --password "$PYPI_PASSWORD"
\ No newline at end of file
diff --git a/ground_truth_test.py b/ground_truth_test.py
deleted file mode 100755
index ce76432..0000000
--- a/ground_truth_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# from fibsem_tools.metadata.groundtruth import GroupMetadata
-from fibsem_tools import read_xarray
-import json
-from fibsem_tools.metadata.groundtruth import (
-    AnnotationEncoding,
-    classNameDict,
-    AnnotationArrayAttrs,
-    AnnotationClassAttrs,
-    AnnotationCropAttrs,
-)
-from rich import print_json
-import numpy as np
-import datetime
-from typing import Dict, TypedDict, List, TypeVar
-
-Key = TypeVar("Key", bound=str)
-
-
-class CropMeta(TypedDict):
-    maxId: int
-    name: str
-    offset: List[float]
-    offset_unit: str
-    resolution: List[float]
-    resulution_unit: str
-    type: str
-
-
-dataset = "jrc_hela-2"
-bucket = "janelia-cosem-datasets"
-uri = f"s3://{bucket}/{dataset}/{dataset}.n5/labels/gt/"
-out_dtype = "uint8"
-out_dtype_max = np.iinfo(out_dtype).max
-
-crop_key: Key = "Crop13"
-group = read_xarray(uri)
-arr = group["s0"].data
-subvolumeMeta: Dict[Key, CropMeta] = arr.attrs["subvolumes"]
-sMeta = subvolumeMeta[crop_key]
-dims = ("x", "y", "z")
-
-scales = arr.attrs["transform"]["scale"][::-1]
-offsets = np.multiply(sMeta["offset"], np.divide(scales, sMeta["resolution"]))
-selecter = {
-    d: (np.arange(100) * scale) + offset
-    for d, offset, scale in zip(dims, offsets, scales)
-}
-
-crop = arr.sel(selecter, method="nearest")
-crop_attrs = AnnotationCropAttrs(
-    name=crop_key, description="A crop", protocol=None, doi=None
-)
-
-out_attrs = {}
-out_attrs[f"/{crop_key}"] = {"annotation": crop_attrs.dict()}
-# partition the subvolume into separate integer classes
-vals = np.unique(crop)
-
-
-for v in vals:
-    name, description = classNameDict[v].short, classNameDict[v].long
-
-    subvol = (crop == v).astype(out_dtype)
-    census = {k: np.sum(subvol == k) for k in np.unique(subvol)}
-    encoding: AnnotationEncoding = {"absent": 0, "unknown": 255}
-    array_attrs = AnnotationArrayAttrs(census=census, encoding=encoding, object=name)
-
-    group_attrs = AnnotationClassAttrs(
-        name=name,
-        description=description,
-        created_by=[
-            "Cellmap annotators",
-        ],
-        created_with=["Amira", "Paintera"],
-        start_date=datetime.datetime.now().isoformat(),
-        duration_days=10,
-        encoding=encoding,
-        type="instance",
-    )
-
-    out_attrs[f"/{crop_key}/{name}"] = {"annotation": group_attrs.dict()}
-    out_attrs[f"/{crop_key}/{name}/s0"] = {"annotation": array_attrs.dict()}
-
-
-print_json(json.dumps(out_attrs))
diff --git a/poetry.lock b/poetry.lock
index 8efac61..5ff1394
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry and should not be changed by hand.
 
 [[package]]
 name = "aiobotocore"
diff --git a/pyproject.toml b/pyproject.toml
index c6661da..4467567 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "fibsem-tools"
-version = "4.0.2"
+version = "4.0.4"
 description = "Tools for processing FIBSEM datasets"
 authors = ["Davis Vann Bennett "]
 license = "MIT"
diff --git a/src/fibsem_tools/cli/zarr_scan.py b/src/fibsem_tools/cli/zarr_scan.py
new file mode 100644
index 0000000..53cf3eb
--- /dev/null
+++ b/src/fibsem_tools/cli/zarr_scan.py
@@ -0,0 +1,106 @@
+from typing import Literal, Union
+import click
+import zarr
+from fibsem_tools import access
+from rich import print
+from fibsem_tools.io.zarr import get_chunk_keys
+from rich.progress import track
+import time
+from dataclasses import dataclass
+
+ChunkState = Literal["valid", "missing", "invalid"]
+
+
+@dataclass
+class Missing:
+    variant = "missing"
+
+
+@dataclass
+class Invalid:
+    variant = "invalid"
+    exception: BaseException
+
+
+@dataclass
+class Valid:
+    variant = "valid"
+
+
+class ChunkSetResults(dict[ChunkState, dict[str, Union[Missing, Valid, Invalid]]]):
+    pass
+
+
+def check_zarray(array: zarr.Array) -> dict[str, Union[Missing, Invalid, Valid]]:
+    ckeys = tuple(get_chunk_keys(array))
+    results = {}
+    for ckey in track(ckeys, description="Checking chunks..."):
+        try:
+            array._decode_chunk(array.store[ckey])
+            results[ckey] = Valid()
+        except OSError as e:
+            results[ckey] = Invalid(exception=e)
+        except KeyError:
+            results[ckey] = Missing()
+
+    return results
+
+
+@click.command()
+@click.argument("array_path", type=click.STRING)
+@click.option(
+    "--valid",
+    is_flag=True,
+    show_default=True,
+    default=False,
+    help="report valid chunks",
+)
+@click.option(
+    "--missing",
+    is_flag=True,
+    show_default=True,
+    default=False,
+    help="report missing chunks",
+)
+@click.option(
+    "--invalid",
+    is_flag=True,
+    show_default=True,
+    default=False,
+    help="report invalid chunks",
+)
+@click.option(
+    "--delete-invalid",
+    is_flag=True,
+    show_default=True,
+    default=False,
+    help="delete invalid chunks",
+)
+def cli(array_path, valid, missing, invalid, delete_invalid):
+    start = time.time()
+    array = access(array_path, mode="r")
+    all_results = check_zarray(array)
+    # categorize
+    results_categorized: ChunkSetResults = {"valid": {}, "missing": {}, "invalid": {}}
+    for key, value in all_results.items():
+        results_categorized[value.variant][key] = value
+
+    to_show = {}
+
+    for flag, opt in zip((valid, missing, invalid), ("valid", "missing", "invalid")):
+        if flag:
+            to_show[opt] = results_categorized[opt]
+    print(to_show)
+    if delete_invalid:
+        array_a = access(array_path, mode="a")
+        num_invalid = len(results_categorized["invalid"])
+        for res in track(
+            results_categorized["invalid"],
+            description=f"Deleting {num_invalid} invalid chunks...",
+        ):
+            del array_a.store[res]
+    print(f"Completed after {time.time() - start}s")
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/src/fibsem_tools/io/zarr.py b/src/fibsem_tools/io/zarr.py
index 359f2e6..5c5ba13 100644
--- a/src/fibsem_tools/io/zarr.py
+++ b/src/fibsem_tools/io/zarr.py
@@ -216,7 +216,7 @@ def access_zarr(
 
     array_or_group = zarr.open(store, path=path, **kwargs, mode=access_mode)
 
-    if access_mode != "r":
+    if access_mode != "r" and len(attrs) > 0:
         array_or_group.attrs.update(attrs)
     return array_or_group
 
@@ -329,9 +329,10 @@ def infer_coords(array: zarr.Array) -> List[DataArray]:
     elif (multiscales := group.attrs.get("multiscales", None)) is not None:
         if len(multiscales) > 0:
             multiscale = multiscales[0]
-            if (ngff_version := multiscale.get("version", None)) == "0.4":
+            ngff_version = multiscale.get("version", None)
+            if ngff_version == "0.4":
                 from pydantic_ome_ngff.v04 import Multiscale
-            elif multiscale["version"] == "0.5-dev":
+            elif ngff_version == "0.5-dev":
                 from pydantic_ome_ngff.latest import Multiscale
             else:
                 raise ValueError(
@@ -341,7 +342,7 @@ def infer_coords(array: zarr.Array) -> List[DataArray]:
                     """
                 )
         else:
-            raise ValueError("Multiscales attribute was empty")
+            raise ValueError("Multiscales attribute was empty.")
     xarray_adapters = get_adapters(ngff_version)
     multiscales_meta = [Multiscale(**entry) for entry in multiscales]
     transforms = []
diff --git a/src/fibsem_tools/metadata/groundtruth.py b/src/fibsem_tools/metadata/groundtruth.py
index ed5fe1a..1179cd2 100644
--- a/src/fibsem_tools/metadata/groundtruth.py
+++ b/src/fibsem_tools/metadata/groundtruth.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 from enum import Enum
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, Generic, List, Literal, Optional, TypeVar, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, root_validator
+from pydantic.generics import GenericModel
 
 
-AnnotationType = Union[Literal["semantic"], Literal["instance"]]
+class StrictBase(BaseModel):
+    class Config:
+        extra = "forbid"
 
 
 class InstanceName(BaseModel):
@@ -40,10 +43,7 @@ class LabelList(BaseModel):
     1: InstanceName(short="ECS", long="Extracellular Space"),
     2: InstanceName(short="Plasma membrane", long="Plasma membrane"),
     3: InstanceName(short="Mito membrane", long="Mitochondrial membrane"),
-    4: InstanceName(
-        short="Mito lumen",
-        long="Mitochondrial lumen",
-    ),
+    4: InstanceName(short="Mito lumen", long="Mitochondrial lumen"),
     5: InstanceName(short="Mito DNA", long="Mitochondrial DNA"),
     6: InstanceName(short="Golgi Membrane", long="Golgi apparatus membrane"),
     7: InstanceName(short="Golgi lumen", long="Golgi apparatus lumen"),
@@ -57,9 +57,7 @@
     15: InstanceName(short="LD lumen", long="Lipid droplet lumen"),
     16: InstanceName(short="ER membrane", long="Endoplasmic reticulum membrane"),
     17: InstanceName(short="ER lumen", long="Endoplasmic reticulum membrane"),
-    18: InstanceName(
-        short="ERES membrane", long="Endoplasmic reticulum exit site membrane"
-    ),
+    18: InstanceName(short="ERES membrane", long="Endoplasmic reticulum exit site membrane"),
     19: InstanceName(short="ERES lumen", long="Endoplasmic reticulum exit site lumen"),
     20: InstanceName(short="NE membrane", long="Nuclear envelope membrane"),
     21: InstanceName(short="NE lumen", long="Nuclear envelope lumen"),
@@ -81,38 +79,59 @@
     37: InstanceName(short="Nucleus combined", long="Nucleus combined"),
     38: InstanceName(short="Vimentin", long="Vimentin"),
     39: InstanceName(short="Glycogen", long="Glycogen"),
+    40: InstanceName(short="Cardiac neurons", long="Cardiac neurons"),
+    41: InstanceName(short="Endothelial cells", long="Endothelial cells"),
+    42: InstanceName(short="Cardiomyocytes", long="Cardiomyocytes"),
+    43: InstanceName(short="Epicardial cells", long="Epicardial cells"),
+    44: InstanceName(short="Parietal pericardial cells", long="Parietal pericardial cells"),
+    45: InstanceName(short="Red blood cells", long="Red blood cells"),
+    46: InstanceName(short="White blood cells", long="White blood cells"),
+    47: InstanceName(short="Peroxisome membrane", long="Peroxisome membrane"),
+    48: InstanceName(short="Peroxisome lumen", long="Peroxisome lumen"),
 }
 
 
+Possibility = Literal["unknown", "absent"]
Literal["unknown", "absent"] -class SemanticAnnotation(BaseModel): - type: Literal["semantic"] - encoding: Dict[int, str] +class SemanticSegmentation(BaseModel): + type: Literal["semantic_segmentation"] = "semantic_segmentation" + encoding: Dict[Union[Possibility, Literal["present"]], int] -class InstanceAnnotation(BaseModel): - type: Literal["instance"] - encoding: Dict[int, Possibility] +class InstanceSegmentation(BaseModel): + type: Literal["instance_segmentation"] = "instance_segmentation" + encoding: Dict[Possibility, int] -Possibility = Union[Literal["unknown"], Literal["absent"], Literal["present"]] -AnnotationEncoding = Dict[Possibility, int] +AnnotationType = Union[SemanticSegmentation, InstanceSegmentation] +TName = TypeVar("TName", bound=str) -class AnnotationArrayAttrs(BaseModel): + +class AnnotationArrayAttrs(GenericModel, Generic[TName]): """ The metadata for an array of annotated values. """ - objects: str + class_name: TName # a mapping from values to frequencies - census: Dict[int, int] + histogram: Optional[Dict[Possibility, int]] # a mapping from class names to values # this is array metadata because labels might disappear during downsampling - encoding: AnnotationEncoding + annotation_type: AnnotationType + + @root_validator() + def check_encoding(cls, values): + if (typ := values.get("type", False)) and ( + hist := values.get("histogram", False) + ): + # check that everything in the histogram is encoded + assert set(typ.encoding.keys()).issuperset((hist.keys())), "Oh no" + + return values -class AnnotationClassAttrs(BaseModel): +class MultiscaleGroupAttrs(GenericModel, Generic[TName]): """ The metadata for an individual annotated semantic class. In a storage hierarchy like zarr or hdf5, this metadata is associated with a @@ -120,23 +139,30 @@ class AnnotationClassAttrs(BaseModel): annotation data in a multiscale representation. """ - name: str + class_name: TName description: str - created_by: List[str] - created_with: List[str] + created_by: list[str] + created_with: list[str] start_date: str | None end_date: str | None duration_days: int | None - type: AnnotationType - encoding: AnnotationEncoding + annotation_type: AnnotationType + + +class AnnotationProtocol(GenericModel, Generic[TName]): + url: str + class_names: list[TName] + + class Config: + allow_extra = "forbid" -class AnnotationCropAttrs(BaseModel): +class AnnotationCropAttrs(GenericModel, Generic[TName]): """ The metadata for all annotations in a single crop. 
""" name: Optional[str] description: Optional[str] - protocol: Optional[str] + protocol: AnnotationProtocol[TName] doi: Optional[str] diff --git a/tests/test_zarr.py b/tests/test_zarr.py index 8e07ae1..790c3c0 100644 --- a/tests/test_zarr.py +++ b/tests/test_zarr.py @@ -3,8 +3,10 @@ import pytest from xarray import DataArray from zarr.storage import FSStore +from pathlib import Path import zarr import numpy as np +import itertools from fibsem_tools.io.core import read_dask, read_xarray from fibsem_tools.io.multiscale import multiscale_group from fibsem_tools.io.xr import stt_from_array @@ -15,6 +17,7 @@ access_zarr, create_dataarray, create_datatree, + get_chunk_keys, get_url, to_dask, to_xarray, @@ -248,3 +251,23 @@ def test_dask(temp_zarr, chunks): assert np.array_equal(observed, data) assert np.array_equal(read_dask(get_url(zarray), chunks).compute(), data) + + +@pytest.mark.parametrize( + "store_class", (zarr.N5Store, zarr.DirectoryStore, zarr.NestedDirectoryStore) +) +@pytest.mark.parametrize("shape", ((10,), (10, 11, 12))) +def test_chunk_keys(tmp_path: Path, store_class, shape): + store: zarr.storage.BaseStore = store_class(tmp_path) + arr_path = "test" + arr = zarr.create( + shape=shape, store=store, path=arr_path, chunks=(2,) * len(shape), dtype="uint8" + ) + + dim_sep = arr._dimension_separator + chunk_idcs = itertools.product(*(range(c_s) for c_s in arr.cdata_shape)) + expected = tuple( + os.path.join(arr.path, dim_sep.join(map(str, idx))) for idx in chunk_idcs + ) + observed = tuple(get_chunk_keys(arr)) + assert observed == expected