Skip to content

Commit

Permalink
add from_array method to annotation array attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Aug 14, 2024
1 parent bad8226 commit a133a64
Showing 1 changed file with 49 additions and 9 deletions.
58 changes: 49 additions & 9 deletions src/cellmap_schemas/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from typing import (
Any,
Generic,
List,
Literal,
Mapping,
Optional,
TypeVar,
Union,
)
import numpy as np
from typing_extensions import Self
from pydantic_zarr.v2 import GroupSpec, ArraySpec
from pydantic import BaseModel, model_validator, field_serializer
import zarr
Expand Down Expand Up @@ -111,7 +111,7 @@ class Label(BaseModel, extra="forbid"):


class LabelList(BaseModel, extra="forbid"):
labels: List[Label]
labels: list[Label]
annotation_type: AnnotationType = "semantic"


Expand Down Expand Up @@ -191,7 +191,7 @@ class SemanticSegmentation(BaseModel, extra="forbid"):
"""

type: Literal["semantic_segmentation"] = "semantic_segmentation"
encoding: dict[Union[Possibility, Literal["present"]], int]
encoding: dict[Literal[Possibility, Literal["present"]], int]


class InstanceSegmentation(BaseModel, extra="forbid"):
Expand Down Expand Up @@ -221,7 +221,7 @@ class InstanceSegmentation(BaseModel, extra="forbid"):
encoding: dict[Possibility, int]


AnnotationType = Union[SemanticSegmentation, InstanceSegmentation]
AnnotationType = SemanticSegmentation | InstanceSegmentation

TName = TypeVar("TName", bound=str)

Expand All @@ -247,16 +247,52 @@ class AnnotationArrayAttrs(BaseModel, Generic[TName]):

class_name: TName
# a mapping from values to frequencies
complement_counts: Optional[dict[Possibility, int]]
complement_counts: dict[Possibility, int] | None
# a mapping from class names to values
# this is array metadata because labels might disappear during downsampling
annotation_type: AnnotationType

@model_validator(mode="after")
def check_encoding(self: "AnnotationArrayAttrs"):
assert set(self.annotation_type.encoding.keys()).issuperset((self.complement_counts.keys()))
if (
isinstance(self.annotation_type, SemanticSegmentation)
and self.complement_counts is not None
):
assert set(self.annotation_type.encoding.keys()).issuperset(
(self.complement_counts.keys())
)
return self

@classmethod
def from_array(
cls,
data: np.ndarray,
class_name: TName,
annotation_type: AnnotationType,
complement_counts: None | dict[Possibility, int] | Literal["auto"],
) -> Self:
if complement_counts == "auto":
num_unknown = sum(data == annotation_type.encoding["unknown"])
num_absent = sum(data == annotation_type.encoding["absent"])
num_present = data.size - (num_unknown + num_absent)

if isinstance(annotation_type, SemanticSegmentation):
complement_counts_parsed = {"unknown": num_unknown, "absent": num_absent}
elif isinstance(annotation_type, InstanceSegmentation):
complement_counts_parsed = {
"unknown": num_unknown,
"absent": num_absent,
"present": num_present,
}
else:
complement_counts_parsed = complement_counts

return cls(
class_name=class_name,
annotation_type=annotation_type,
complement_counts=complement_counts_parsed,
)


class AnnotationGroupAttrs(BaseModel, Generic[TName]):
"""
Expand Down Expand Up @@ -327,11 +363,15 @@ class CropGroupAttrs(BaseModel, Generic[TName], validate_assignment=True):
class_names: list[TName]

@field_serializer("start_date")
def ser_end_date(value: date):
def ser_end_date(value: date) -> None | str:
if value is None:
return None
return serialize_date(value)

@field_serializer("end_date")
def ser_start_date(value: date):
def ser_start_date(value: date) -> None | str:
if value is None:
return None
return serialize_date(value)


Expand Down

0 comments on commit a133a64

Please sign in to comment.