Commit

refactor: ♻️ Switch from "if hasattr" to try/except
rhoadesScholar committed Aug 9, 2024
1 parent d78764f commit c4c07ca
Showing 13 changed files with 298 additions and 182 deletions.
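
Every change below applies the same refactor: lazily computed, cached properties that previously guarded initialization with "if not hasattr(...)" now follow the EAFP idiom (try/except AttributeError). A minimal sketch of the pattern, not code from this repository (expensive_compute is a stand-in):

    def expensive_compute() -> int:
        # Stand-in for a costly derivation (I/O, a large reduction, etc.)
        return 42

    class Cached:
        @property
        def value(self) -> int:
            try:
                # Fast path: the attribute exists after the first access.
                return self._value
            except AttributeError:
                # First access: compute once and cache on the instance.
                self._value = expensive_compute()
                return self._value

On the common fast path this saves a redundant lookup, since hasattr() performs the same attribute access and discards the result; the standard library's functools.cached_property offers equivalent compute-once semantics.
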
6 changes: 5 additions & 1 deletion src/cellmap_data/dataloader.py
@@ -3,7 +3,7 @@
 from .dataset import CellMapDataset
 from .multidataset import CellMapMultiDataset
 from .subdataset import CellMapSubset
-from typing import Callable, Iterable, Optional
+from typing import Callable, Iterable, Optional, Sequence
 
 
 class CellMapDataLoader:
@@ -91,6 +91,10 @@ def __init__(
         # TODO: Try persistent workers
         self.loader = DataLoader(**kwargs)
 
+    def __getitem__(self, indices: Sequence[int]) -> dict:
+        """Get an item from the DataLoader."""
+        return self.collate_fn([self.loader.dataset[index] for index in indices])
+
     def refresh(self):
         """If the sampler is a Callable, refresh the DataLoader with the current sampler."""
         if isinstance(self.sampler, Callable):
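
The new __getitem__ assembles a collated batch directly from a sequence of dataset indices. A self-contained analogue of that behavior (toy dataset and collate function, not the repository's classes):

    from typing import Sequence

    class ToyLoader:
        def __init__(self, dataset, collate_fn):
            self.dataset = dataset
            self.collate_fn = collate_fn

        def __getitem__(self, indices: Sequence[int]) -> dict:
            # Fetch each sample by index, then merge them into one batch.
            return self.collate_fn([self.dataset[i] for i in indices])

    loader = ToyLoader(
        [{"x": i} for i in range(10)],
        lambda batch: {"x": [sample["x"] for sample in batch]},
    )
    assert loader[[0, 5, 2]] == {"x": [0, 5, 2]}
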
164 changes: 102 additions & 62 deletions src/cellmap_data/dataset.py

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions src/cellmap_data/datasplit.py
@@ -1,6 +1,6 @@
 import csv
 import os
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence
+from typing import Any, Callable, Mapping, Optional, Sequence
 import tensorstore
 import torch
 from .dataset import CellMapDataset
@@ -208,7 +208,9 @@ def __repr__(self) -> str:
     @property
     def train_datasets_combined(self) -> CellMapMultiDataset:
         """A multi-dataset from the combination of all training datasets."""
-        if not hasattr(self, "_train_datasets_combined"):
+        try:
+            return self._train_datasets_combined
+        except AttributeError:
             self._train_datasets_combined = CellMapMultiDataset(
                 self.classes,
                 self.input_arrays,
@@ -219,13 +221,15 @@ def train_datasets_combined(self) -> CellMapMultiDataset:
                     if self.force_has_data or ds.has_data
                 ],
             )
-        return self._train_datasets_combined
+            return self._train_datasets_combined
 
     @property
     def validation_datasets_combined(self) -> CellMapMultiDataset:
         """A multi-dataset from the combination of all validation datasets."""
         assert len(self.validation_datasets) > 0, "Validation datasets not loaded."
-        if not hasattr(self, "_validation_datasets_combined"):
+        try:
+            return self._validation_datasets_combined
+        except AttributeError:
             self._validation_datasets_combined = CellMapMultiDataset(
                 self.classes,
                 self.input_arrays,
@@ -236,29 +240,33 @@ def validation_datasets_combined(self) -> CellMapMultiDataset:
                     if self.force_has_data or ds.has_data
                 ],
             )
-        return self._validation_datasets_combined
+            return self._validation_datasets_combined
 
     @property
     def validation_blocks(self) -> CellMapSubset:
         """A subset of the validation datasets, tiling the validation datasets with non-overlapping blocks."""
-        if not hasattr(self, "_validation_blocks"):
+        try:
+            return self._validation_blocks
+        except AttributeError:
             self._validation_blocks = CellMapSubset(
                 self.validation_datasets_combined,
                 self.validation_datasets_combined.validation_indices,
             )
-        return self._validation_blocks
+            return self._validation_blocks
 
     @property
-    def class_counts(self) -> Dict[str, Dict[str, int]]:
+    def class_counts(self) -> dict[str, dict[str, float]]:
         """A dictionary containing the class counts for the training and validation datasets."""
-        if not hasattr(self, "_class_counts"):
+        try:
+            return self._class_counts
+        except AttributeError:
             self._class_counts = {
                 "train": self.train_datasets_combined.class_counts,
                 "validate": self.validation_datasets_combined.class_counts,
             }
-        return self._class_counts
+            return self._class_counts
 
-    def from_csv(self, csv_path) -> Dict[str, Sequence[Dict[str, str]]]:
+    def from_csv(self, csv_path) -> dict[str, Sequence[dict[str, str]]]:
         """Loads the dataset_dict data from a csv file."""
         dataset_dict = {}
         with open(csv_path, "r") as f:
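
All four properties above share the compute-once shape that functools.cached_property provides out of the box. A sketch of that alternative (not part of this change, and unsuitable for classes that define __slots__ without __dict__):

    from functools import cached_property

    class Example:
        @cached_property
        def class_counts(self) -> dict[str, float]:
            print("computing once")
            return {"background": 0.1}

    e = Example()
    e.class_counts  # prints "computing once", then caches on the instance
    e.class_counts  # served from the cache; no recomputation
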
103 changes: 68 additions & 35 deletions src/cellmap_data/image.py
@@ -145,36 +145,44 @@ def __repr__(self) -> str:
     @property
     def shape(self) -> Mapping[str, int]:
         """Returns the shape of the image."""
-        if not hasattr(self, "_shape"):
+        try:
+            return self._shape
+        except AttributeError:
             self._shape = {
                 c: self.group[self.scale_level].shape[i]
                 for i, c in enumerate(self.axes)
             }
-        return self._shape
+            return self._shape
 
     @property
     def center(self) -> Mapping[str, float]:
         """Returns the center of the image in world units."""
-        if not hasattr(self, "_center"):
+        try:
+            return self._center
+        except AttributeError:
             center = {}
             for c, (start, stop) in self.bounding_box.items():
                 center[c] = start + (stop - start) / 2
             self._center = center
-        return self._center
+            return self._center
 
     @property
     def multiscale_attrs(self) -> GroupAttrs:
         """Returns the multiscale metadata of the image."""
-        if not hasattr(self, "_multiscale_attrs"):
+        try:
+            return self._multiscale_attrs
+        except AttributeError:
             self._multiscale_attrs = GroupAttrs(
                 multiscales=self.group.attrs["multiscales"]
             ).multiscales[0]
-        return self._multiscale_attrs
+            return self._multiscale_attrs
 
     @property
     def coordinateTransformations(self) -> list[Mapping[str, Any]]:
         """Returns the coordinate transformations of the image, based on the multiscale metadata."""
-        if not hasattr(self, "_coordinateTransformations"):
+        try:
+            return self._coordinateTransformations
+        except AttributeError:
             # multi_tx = multi.coordinateTransformations
             dset = [
                 ds
@@ -183,77 +191,95 @@ def coordinateTransformations(self) -> list[Mapping[str, Any]]:
             ][0]
             # tx_fused = normalize_transforms(multi_tx, dset.coordinateTransformations)
             self._coordinateTransformations = dset.coordinateTransformations
-        return self._coordinateTransformations
+            return self._coordinateTransformations
 
     @property
     def full_coords(self) -> Mapping[str, xarray.DataArray]:
         """Returns the full coordinates of the image's axes in world units."""
-        if not hasattr(self, "_full_coords"):
+        try:
+            return self._full_coords
+        except AttributeError:
             self._full_coords = coords_from_transforms(
                 axes=self.multiscale_attrs.axes,
                 transforms=self.coordinateTransformations,
                 # transforms=tx_fused,
                 shape=self.group[self.scale_level].shape,
             )
-        return self._full_coords
+            return self._full_coords
 
     @property
     def scale_level(self) -> str:
         """Returns the multiscale level of the image."""
-        if not hasattr(self, "_scale_level"):
+        try:
+            return self._scale_level
+        except AttributeError:
             self._scale_level = self.find_level(self.scale)
-        return self._scale_level
+            return self._scale_level
 
     @property
     def group(self) -> zarr.Group:
         """Returns the zarr group object for the multiscale image."""
-        if not hasattr(self, "_group"):
-            self._group = zarr.open_group(self.path)
-        return self._group  # type: ignore
+        try:
+            return self._group
+        except AttributeError:
+            if self.path[:5] == "s3://":
+                self._group = zarr.open_group(zarr.N5FSStore(self.path, anon=True))
+            else:
+                self._group = zarr.open_group(self.path)
+            return self._group
 
     @property
     def array_path(self) -> str:
         """Returns the path to the single-scale image array."""
-        if not hasattr(self, "_array_path"):
+        try:
+            return self._array_path
+        except AttributeError:
             self._array_path = os.path.join(self.path, self.scale_level)
-        return self._array_path
+            return self._array_path
 
     @property
     def array(self) -> xarray.DataArray:
         """Returns the image data as an xarray DataArray."""
-        # TODO: Would it be faster to do Try/Except instead of hasattr?
-        if not hasattr(self, "_array"):
+        try:
+            return self._array
+        except AttributeError:
             # Construct an xarray with Tensorstore backend
             spec = xt._zarr_spec_from_path(self.array_path)
-            array_future = tensorstore.open(  # type: ignore
+            array_future = tensorstore.open(
                 spec, read=True, write=False, context=self.context
             )
             array = array_future.result()
             data = xt._TensorStoreAdapter(array)
             self._array = xarray.DataArray(data=data, coords=self.full_coords)
-        return self._array
+            return self._array
 
     @property
     def translation(self) -> Mapping[str, float]:
         """Returns the translation of the image."""
-        if not hasattr(self, "_translation"):
+        try:
+            return self._translation
+        except AttributeError:
             # Get the translation of the image
             self._translation = {c: self.bounding_box[c][0] for c in self.axes}
-        return self._translation
+            return self._translation
 
     @property
     def bounding_box(self) -> Mapping[str, list[float]]:
         """Returns the bounding box of the dataset in world units."""
-        if not hasattr(self, "_bounding_box"):
+        try:
+            return self._bounding_box
+        except AttributeError:
             self._bounding_box = {}
             for coord in self.full_coords:
                 self._bounding_box[coord.dims[0]] = [coord.data.min(), coord.data.max()]
-        return self._bounding_box
+            return self._bounding_box
 
     @property
     def sampling_box(self) -> Mapping[str, list[float]]:
         """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box), in world units."""
-        if not hasattr(self, "_sampling_box") or self._sampling_box is None:
+        try:
+            return self._sampling_box
+        except AttributeError:
             self._sampling_box = {}
             output_padding = {c: np.ceil(s / 2) for c, s in self.output_size.items()}
             for c, (start, stop) in self.bounding_box.items():
@@ -274,35 +300,39 @@ def sampling_box(self) -> Mapping[str, list[float]]:
                 else:
                     self._sampling_box = None
                     raise e
-        return self._sampling_box
+            return self._sampling_box
 
     @property
     def bg_count(self) -> float:
         """Returns the number of background pixels in the ground truth data, normalized by the resolution."""
-        if not hasattr(self, "_bg_count"):
+        try:
+            return self._bg_count
+        except AttributeError:
             # Get from cellmap-schemas metadata, then normalize by resolution - get class counts at same time
             _ = self.class_counts
-        return self._bg_count
+            return self._bg_count
 
     @property
     def class_counts(self) -> float:
         """Returns the number of pixels for the contained class in the ground truth data, normalized by the resolution."""
-        if not hasattr(self, "_class_counts"):
+        try:
+            return self._class_counts  # type: ignore
+        except AttributeError:
             # Get from cellmap-schemas metadata, then normalize by resolution
             try:
                 # TODO: Make work with HDF5 files
                 bg_count = self.group[self.scale_level].attrs["cellmap"]["annotation"][
                     "complement_counts"
                 ]["absent"]
                 self._class_counts = (
-                    np.prod(self.group[self.scale_level].shape) - bg_count  # type: ignore
+                    np.prod(self.group[self.scale_level].shape) - bg_count
                 ) * np.prod(list(self.scale.values()))
                 self._bg_count = bg_count * np.prod(list(self.scale.values()))
             except Exception as e:
                 print(e)
                 self._class_counts = 0.1
                 self._bg_count = 0.1
-        return self._class_counts  # type: ignore
+            return self._class_counts  # type: ignore
@@ -434,19 +464,22 @@ def return_data(
                 method=self.interpolation,  # type: ignore
             )
         elif self.pad:
-            if not hasattr(self, "_tolerance"):
+            try:
+                tolerance = self._tolerance
+            except AttributeError:
                 self._tolerance = np.ones(coords[self.axes[0]].shape) * np.max(
                     list(self.scale.values())
                 )
+                tolerance = self._tolerance
             data = self.array.reindex(
                 **coords,
                 method="nearest",
-                tolerance=self._tolerance,
+                tolerance=tolerance,
                 fill_value=self.pad_value,
             )
         else:
             data = self.array.sel(
-                **coords,  # type: ignore
+                **coords,
                 method="nearest",
             )
         return data
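
The padding branch above leans on xarray's reindex semantics: method="nearest" snaps each requested coordinate to the closest existing one, tolerance caps the snap distance, and fill_value pads requests that match nothing within tolerance. A small generic demonstration of those semantics (plain xarray, not the repository's data):

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(np.arange(4.0), coords={"x": [0.0, 1.0, 2.0, 3.0]}, dims="x")
    padded = arr.reindex(
        x=[-1.0, 0.1, 2.9, 5.0],  # -1.0 and 5.0 lie outside the grid
        method="nearest",
        tolerance=0.5,
        fill_value=0.0,
    )
    # padded.values -> [0.0, 0.0, 3.0, 0.0]; only 0.1 and 2.9 snap to real values
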
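Beyond the caching refactor, the group property picks up one behavioral change: paths beginning with s3:// are now opened through zarr's N5FSStore with anonymous credentials, while other paths still go through zarr.open_group directly. A standalone sketch of that branch (zarr-python 2.x API; the bucket path is a made-up example):

    import zarr

    def open_image_group(path: str) -> zarr.Group:
        if path.startswith("s3://"):
            # N5 container on S3, read anonymously via fsspec
            return zarr.open_group(zarr.N5FSStore(path, anon=True))
        # Local (or otherwise fsspec-mappable) zarr opens directly
        return zarr.open_group(path)

    # group = open_image_group("s3://example-bucket/volume.n5")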
