Skip to content

Commit

Permalink
Merge pull request #14 from janelia-cellmap/feature/from-array
Browse files Browse the repository at this point in the history
Feature/from array
  • Loading branch information
d-v-b authored Sep 1, 2024
2 parents c0677a3 + 72db1e9 commit b883824
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "cellmap-schemas"
dynamic = ["version"]
description = 'Schemas for data used by the Cellmap project team at Janelia Research Campus.'
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = "MIT"
keywords = ["cellmap", "ngff", "n5", "zarr"]
authors = [
Expand Down Expand Up @@ -45,7 +45,7 @@ dependencies = [
]

[[tool.hatch.envs.test.matrix]]
python = ["3.9", "3.10", "3.11"]
python = ["3.10", "3.11", "3.12"]

[tool.hatch.envs.test.scripts]
run-coverage = "pytest --cov-config=pyproject.toml --cov=pkg --cov=tests"
Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_schemas/__about__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2022-present Howard Hughes Medical Institute
#
# SPDX-License-Identifier: MIT
__version__ = "0.7.0"
__version__ = "0.7.1"
77 changes: 68 additions & 9 deletions src/cellmap_schemas/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from typing import (
Any,
Generic,
List,
Literal,
Mapping,
Optional,
TypeVar,
Union,
)
import numpy as np
import numpy.typing as npt
from typing_extensions import Self
from pydantic_zarr.v2 import GroupSpec, ArraySpec
from pydantic import BaseModel, model_validator, field_serializer
import zarr
Expand Down Expand Up @@ -111,7 +112,7 @@ class Label(BaseModel, extra="forbid"):


class LabelList(BaseModel, extra="forbid"):
labels: List[Label]
labels: list[Label]
annotation_type: AnnotationType = "semantic"


Expand Down Expand Up @@ -191,7 +192,7 @@ class SemanticSegmentation(BaseModel, extra="forbid"):
"""

type: Literal["semantic_segmentation"] = "semantic_segmentation"
encoding: dict[Union[Possibility, Literal["present"]], int]
encoding: dict[Literal[Possibility, Literal["present"]], int]


class InstanceSegmentation(BaseModel, extra="forbid"):
Expand Down Expand Up @@ -221,7 +222,7 @@ class InstanceSegmentation(BaseModel, extra="forbid"):
encoding: dict[Possibility, int]


AnnotationType = Union[SemanticSegmentation, InstanceSegmentation]
AnnotationType = SemanticSegmentation | InstanceSegmentation

TName = TypeVar("TName", bound=str)

Expand All @@ -247,16 +248,55 @@ class AnnotationArrayAttrs(BaseModel, Generic[TName]):

class_name: TName
# a mapping from values to frequencies
complement_counts: Optional[dict[Possibility, int]]
complement_counts: dict[Possibility, int] | None
# a mapping from class names to values
# this is array metadata because labels might disappear during downsampling
annotation_type: AnnotationType

@model_validator(mode="after")
def check_encoding(self: "AnnotationArrayAttrs"):
assert set(self.annotation_type.encoding.keys()).issuperset((self.complement_counts.keys()))
if (
isinstance(self.annotation_type, SemanticSegmentation)
and self.complement_counts is not None
):
assert set(self.annotation_type.encoding.keys()).issuperset(
(self.complement_counts.keys())
)
return self

@classmethod
def from_array(
cls,
array: np.ndarray,
class_name: TName,
annotation_type: AnnotationType,
complement_counts: None | dict[Possibility, int] | Literal["auto"],
) -> Self:
if complement_counts == "auto":
num_unknown = (array == annotation_type.encoding["unknown"]).sum()
num_absent = (array == annotation_type.encoding["absent"]).sum()
num_present = array.size - (num_unknown + num_absent)

if isinstance(annotation_type, SemanticSegmentation):
complement_counts_parsed = {
"unknown": num_unknown,
"absent": num_absent,
"present": num_present,
}
elif isinstance(annotation_type, InstanceSegmentation):
complement_counts_parsed = {
"unknown": num_unknown,
"absent": num_absent,
}
else:
complement_counts_parsed = complement_counts

return cls(
class_name=class_name,
annotation_type=annotation_type,
complement_counts=complement_counts_parsed,
)


class AnnotationGroupAttrs(BaseModel, Generic[TName]):
"""
Expand Down Expand Up @@ -327,11 +367,15 @@ class CropGroupAttrs(BaseModel, Generic[TName], validate_assignment=True):
class_names: list[TName]

@field_serializer("start_date")
def ser_end_date(value: date):
def ser_end_date(value: date) -> None | str:
if value is None:
return None
return serialize_date(value)

@field_serializer("end_date")
def ser_start_date(value: date):
def ser_start_date(value: date) -> None | str:
if value is None:
return None
return serialize_date(value)


Expand All @@ -352,6 +396,21 @@ class AnnotationArray(ArraySpec):

attributes: CellmapWrapper[AnnotationWrapper[AnnotationArrayAttrs]]

@classmethod
def from_array_infer_attrs(
cls,
array: npt.NDArray[Any],
class_name: TName,
annotation_type: AnnotationType,
complement_counts: None | dict[Possibility, int] | Literal["auto"],
**kwargs,
) -> Self:
annotation_attrs = AnnotationArrayAttrs.from_array(
array, class_name, annotation_type, complement_counts
)
annotation_attrs_wrapped = wrap_attributes(annotation_attrs)
return super().from_array(array, attributes=annotation_attrs_wrapped.model_dump(), **kwargs)


class AnnotationGroup(GroupSpec):
"""
Expand Down

0 comments on commit b883824

Please sign in to comment.