From 2a1134fd24d39f5948e88adf9e8c37c0750d4002 Mon Sep 17 00:00:00 2001
From: rhoadesScholar
Date: Mon, 29 Apr 2024 17:36:48 -0400
Subject: [PATCH] fix: 🐛 Debug for many datasets
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/cellmap_data/dataloader.py   |   3 -
 src/cellmap_data/dataset.py      | 207 +++++++++++++++---------------
 src/cellmap_data/datasplit.py    |  73 ++++++-----
 src/cellmap_data/image.py        | 209 +++++++++++++++++--------------
 src/cellmap_data/multidataset.py |   5 +-
 5 files changed, 253 insertions(+), 244 deletions(-)

diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py
index bb5a0e0..b87a6f3 100644
--- a/src/cellmap_data/dataloader.py
+++ b/src/cellmap_data/dataloader.py
@@ -39,9 +39,6 @@ def __init__(
         self.sampler = sampler
         self.is_train = is_train
         self.rng = rng
-        self.construct()
-
-    def construct(self):
         if self.sampler is None and self.weighted_sampler:
             assert isinstance(
                 self.dataset, CellMapMultiDataset
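The dataloader change above removes the one-shot construct() step: the sampler wiring now runs inside __init__, so a loader can never be observed half-built. The same pattern in miniature (a sketch only, assuming a weighted_sampler() helper like the one on CellMapMultiDataset; this is not the full CellMapDataLoader API):

    class Loader:
        def __init__(self, dataset, sampler=None, weighted_sampler=False):
            self.dataset = dataset
            self.sampler = sampler
            self.weighted_sampler = weighted_sampler
            # Formerly deferred to construct(); now runs at instantiation,
            # so every field is final once __init__ returns.
            if self.sampler is None and self.weighted_sampler:
                self.sampler = dataset.weighted_sampler()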
diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py
index 91b927a..edf93a3 100644
--- a/src/cellmap_data/dataset.py
+++ b/src/cellmap_data/dataset.py
@@ -1,12 +1,10 @@
 # %%
-import math
 import os
-from typing import Callable, Dict, Sequence, Optional
+from typing import Any, Callable, Dict, Sequence, Optional
 import numpy as np
 import torch
-from torch.utils.data import Dataset, get_worker_info
+from torch.utils.data import Dataset
 import tensorstore
-from fibsem_tools.io.core import read, read_xarray
 
 from .image import CellMapImage, EmptyImage
 
@@ -113,8 +111,57 @@ def __init__(
         self.axis_order = axis_order
         self.context = context
         self._rng = rng
-        self.construct()
         self.force_has_data = force_has_data
+        self._bounding_box = None
+        self._bounding_box_shape = None
+        self._sampling_box = None
+        self._sampling_box_shape = None
+        self._class_counts = None
+        self._largest_voxel_sizes = None
+        self._len = None
+        self._iter_coords = None
+        self._current_center = None
+        self._current_spatial_transforms = None
+        self.input_sources = {}
+        for array_name, array_info in self.input_arrays.items():
+            self.input_sources[array_name] = CellMapImage(
+                self.raw_path,
+                "raw",
+                array_info["scale"],
+                array_info["shape"],  # type: ignore
+                value_transform=self.raw_value_transforms,
+                context=self.context,
+            )
+        self.target_sources = {}
+        self.has_data = False
+        for array_name, array_info in self.target_arrays.items():
+            self.target_sources[array_name] = {}
+            empty_store = torch.zeros(array_info["shape"])  # type: ignore
+            for i, label in enumerate(self.classes):  # type: ignore
+                if label in self.classes_with_path:
+                    if isinstance(self.gt_value_transforms, dict):
+                        value_transform: Callable = self.gt_value_transforms[label]
+                    elif isinstance(self.gt_value_transforms, list):
+                        value_transform: Callable = self.gt_value_transforms[i]
+                    else:
+                        value_transform: Callable = self.gt_value_transforms  # type: ignore
+                    self.target_sources[array_name][label] = CellMapImage(
+                        self.gt_path_str.format(label=label),
+                        label,
+                        array_info["scale"],
+                        array_info["shape"],  # type: ignore
+                        value_transform=value_transform,
+                        context=self.context,
+                    )
+                    if not self.has_data:
+                        self.has_data = (
+                            self.has_data
+                            or self.target_sources[array_name][label].class_counts != 0
+                        )
+                else:
+                    self.target_sources[array_name][label] = EmptyImage(
+                        label, array_info["shape"], empty_store  # type: ignore
+                    )
 
     def __len__(self):
         """Returns the length of the dataset, determined by the number of
         coordinates that could be sampled as the center for a cube."""
@@ -123,7 +170,7 @@ def __len__(self):
         if self._len is None:
             size = 1
             for _, (start, stop) in self.sampling_box.items():
-                size *= stop - start
+                size *= abs(stop - start)
             size /= np.prod(list(self.largest_voxel_sizes.values()))
             self._len = int(size)
         return self._len
@@ -167,106 +214,6 @@ def __repr__(self):
         """Returns a string representation of the dataset."""
         return f"CellMapDataset(\n\tRaw path: {self.raw_path}\n\tGT path(s): {self.gt_paths}\n\tClasses: {self.classes})"
 
-    def to(self, device):
-        """Sets the device for the dataset."""
-        for source in list(self.input_sources.values()) + list(
-            self.target_sources.values()
-        ):
-            if isinstance(source, dict):
-                for source in source.values():
-                    source.to(device)
-            else:
-                source.to(device)
-        return self
-
-    def construct(self):
-        """Constructs the input and target sources for the dataset."""
-        self._bounding_box = None
-        self._bounding_box_shape = None
-        self._sampling_box = None
-        self._sampling_box_shape = None
-        self._class_counts = None
-        self._largest_voxel_sizes = None
-        self._len = None
-        self._iter_coords = None
-        self._current_center = None
-        self._current_spatial_transforms = None
-        self.input_sources = {}
-        for array_name, array_info in self.input_arrays.items():
-            self.input_sources[array_name] = CellMapImage(
-                self.raw_path,
-                "raw",
-                array_info["scale"],
-                array_info["shape"],  # type: ignore
-                value_transform=self.raw_value_transforms,
-                context=self.context,
-            )
-        self.target_sources = {}
-        self.has_data = False
-        for array_name, array_info in self.target_arrays.items():
-            self.target_sources[array_name] = {}
-            empty_store = torch.zeros(array_info["shape"])  # type: ignore
-            for i, label in enumerate(self.classes):  # type: ignore
-                if label in self.classes_with_path:
-                    if isinstance(self.gt_value_transforms, dict):
-                        value_transform: Callable = self.gt_value_transforms[label]
-                    elif isinstance(self.gt_value_transforms, list):
-                        value_transform: Callable = self.gt_value_transforms[i]
-                    else:
-                        value_transform: Callable = self.gt_value_transforms  # type: ignore
-                    self.target_sources[array_name][label] = CellMapImage(
-                        self.gt_path_str.format(label=label),
-                        label,
-                        array_info["scale"],
-                        array_info["shape"],  # type: ignore
-                        value_transform=value_transform,
-                        context=self.context,
-                    )
-                    self.has_data = (
-                        self.has_data
-                        or self.target_sources[array_name][label].class_counts != 0
-                    )
-                else:
-                    self.target_sources[array_name][label] = EmptyImage(
-                        label, array_info["shape"], empty_store  # type: ignore
-                    )
-
-    def generate_spatial_transforms(self) -> Optional[dict[str, any]]:
-        """Generates spatial transforms for the dataset."""
-        # TODO: use torch random number generator so accerlerators can synchronize across workers
-        if self._rng is None:
-            self._rng = np.random.default_rng()
-        rng = self._rng
-
-        if not self.is_train or self.spatial_transforms is None:
-            return None
-        spatial_transforms = {}
-        for transform, params in self.spatial_transforms.items():
-            if transform == "mirror":
-                # input: "mirror": {"axes": {"x": 0.5, "y": 0.5, "z":0.1}}
-                # output: {"mirror": ["x", "y"]}
-                spatial_transforms[transform] = []
-                for axis, prob in params["axes"].items():
-                    if rng.random() < prob:
-                        spatial_transforms[transform].append(axis)
-            elif transform == "transpose":
-                # only reorder axes specified in params
-                # input: "transpose": {"axes": ["x", "z"]}
-                # output: {"transpose": {"x": 2, "y": 1, "z": 0}}
-                axes = {axis: i for i, axis in enumerate(self.axis_order)}
-                shuffled_axes = rng.permutation(
-                    [axes[a] for a in params["axes"]]
-                )  # shuffle indices
-                shuffled_axes = {
-                    axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])
-                }  # reassign axes
-                axes.update(shuffled_axes)
-                spatial_transforms[transform] = axes
-            else:
-                raise ValueError(f"Unknown spatial transform: {transform}")
-        self._current_spatial_transforms = spatial_transforms
-        return spatial_transforms
-
     @property
     def largest_voxel_sizes(self):
         """Returns the largest voxel size of the dataset."""
@@ -367,6 +314,54 @@ def class_counts(self) -> Dict[str, Dict[str, int]]:
             self._class_counts = class_counts
         return self._class_counts
 
+    def to(self, device):
+        """Sets the device for the dataset."""
+        for source in list(self.input_sources.values()) + list(
+            self.target_sources.values()
+        ):
+            if isinstance(source, dict):
+                for source in source.values():
+                    source.to(device)
+            else:
+                source.to(device)
+        return self
+
+    def generate_spatial_transforms(self) -> Optional[dict[str, Any]]:
+        """Generates spatial transforms for the dataset."""
+        # TODO: use torch random number generator so accelerators can synchronize across workers
+        if self._rng is None:
+            self._rng = np.random.default_rng()
+        rng = self._rng
+
+        if not self.is_train or self.spatial_transforms is None:
+            return None
+        spatial_transforms = {}
+        for transform, params in self.spatial_transforms.items():
+            if transform == "mirror":
+                # input: "mirror": {"axes": {"x": 0.5, "y": 0.5, "z":0.1}}
+                # output: {"mirror": ["x", "y"]}
+                spatial_transforms[transform] = []
+                for axis, prob in params["axes"].items():
+                    if rng.random() < prob:
+                        spatial_transforms[transform].append(axis)
+            elif transform == "transpose":
+                # only reorder axes specified in params
+                # input: "transpose": {"axes": ["x", "z"]}
+                # output: {"transpose": {"x": 2, "y": 1, "z": 0}}
+                axes = {axis: i for i, axis in enumerate(self.axis_order)}
+                shuffled_axes = rng.permutation(
+                    [axes[a] for a in params["axes"]]
+                )  # shuffle indices
+                shuffled_axes = {
+                    axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])
+                }  # reassign axes
+                axes.update(shuffled_axes)
+                spatial_transforms[transform] = axes
+            else:
+                raise ValueError(f"Unknown spatial transform: {transform}")
+        self._current_spatial_transforms = spatial_transforms
+        return spatial_transforms
+
 
 # Example input arrays:
 # {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)},
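generate_spatial_transforms() above is moved, not rewritten, so its contract is unchanged: {"mirror": {"axes": {axis: probability}}} yields {"mirror": [axes that fired]}, and {"transpose": {"axes": [names]}} yields a full axis-to-index map in which only the named axes are permuted. A self-contained sketch of that sampling logic (sample_transforms is a hypothetical stand-in; the "zyx" axis order is an assumption):

    import numpy as np

    def sample_transforms(spec, axis_order="zyx", rng=None):
        rng = rng or np.random.default_rng()
        out = {}
        for name, params in spec.items():
            if name == "mirror":
                # Keep each axis with its configured probability.
                out[name] = [ax for ax, p in params["axes"].items() if rng.random() < p]
            elif name == "transpose":
                # Permute only the listed axes; leave the rest in place.
                axes = {axis: i for i, axis in enumerate(axis_order)}
                shuffled = rng.permutation([axes[a] for a in params["axes"]])
                axes.update({a: int(shuffled[i]) for i, a in enumerate(params["axes"])})
                out[name] = axes
            else:
                raise ValueError(f"Unknown spatial transform: {name}")
        return out

    # Possible output: {'mirror': ['y'], 'transpose': {'z': 2, 'y': 1, 'x': 0}}
    print(sample_transforms({"mirror": {"axes": {"x": 0.5, "y": 0.5}},
                             "transpose": {"axes": ["x", "z"]}}))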
diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py
index 1a2b83e..751bff2 100644
--- a/src/cellmap_data/datasplit.py
+++ b/src/cellmap_data/datasplit.py
@@ -1,6 +1,6 @@
 import csv
 import os
-from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence
 import tensorstore
 from .dataset import CellMapDataset
 from .multidataset import CellMapMultiDataset
@@ -20,7 +20,7 @@ class CellMapDataSplit:
     validate_datasets: Iterable[CellMapDataset]
     train_datasets_combined: CellMapMultiDataset
    validate_datasets_combined: CellMapMultiDataset
-    spatial_transforms: Optional[dict[str, any]] = None
+    spatial_transforms: Optional[dict[str, Any]] = None
     raw_value_transforms: Optional[Callable] = None
     gt_value_transforms: Optional[
         Callable | Sequence[Callable] | dict[str, Callable]
@@ -36,7 +36,7 @@ def __init__(
         datasets: Optional[Dict[str, Iterable[CellMapDataset]]] = None,
         dataset_dict: Optional[Mapping[str, Sequence[Dict[str, str]]]] = None,
         csv_path: Optional[str] = None,
-        spatial_transforms: Optional[dict[str, any]] = None,
+        spatial_transforms: Optional[dict[str, Any]] = None,
         raw_value_transforms: Optional[Callable] = None,
         gt_value_transforms: Optional[
             Callable | Sequence[Callable] | dict[str, Callable]
@@ -81,7 +81,7 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso
                 }. Defaults to None.
             csv_path (Optional[str], optional): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:
                 train | validate, raw path, gt path
-            spatial_transforms (Optional[Sequence[dict[str, any]]], optional): A sequence of dictionaries containing the spatial transformations to apply to the data. The dictionary should have the following structure:
+            spatial_transforms (Optional[Sequence[dict[str, Any]]], optional): A sequence of dictionaries containing the spatial transformations to apply to the data. The dictionary should have the following structure:
                 {transform_name: {transform_args}}
                 Defaults to None.
             raw_value_transforms (Optional[Callable], optional): A function to apply to the raw data. Defaults to None. Example is to normalize the raw data.
@@ -99,9 +99,8 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso
             self.dataset_dict = {}
         elif dataset_dict is not None:
             self.dataset_dict = dataset_dict
-            self.construct(dataset_dict)
         elif csv_path is not None:
-            self.from_csv(csv_path)
+            self.dataset_dict = self.from_csv(csv_path)
         self.spatial_transforms = spatial_transforms
         self.raw_value_transforms = raw_value_transforms
         self.gt_value_transforms = gt_value_transforms
@@ -126,29 +125,31 @@ def from_csv(self, csv_path):
                 }
             )
 
-        self.dataset_dict = dataset_dict
-        self.construct(dataset_dict)
+        return dataset_dict
 
     def construct(self, dataset_dict):
         self._class_counts = None
         self.train_datasets = []
         self.validate_datasets = []
         for data_paths in dataset_dict["train"]:
-            self.train_datasets.append(
-                CellMapDataset(
-                    data_paths["raw"],
-                    data_paths["gt"],
-                    self.classes,
-                    self.input_arrays,
-                    self.target_arrays,
-                    self.spatial_transforms,
-                    self.raw_value_transforms,
-                    self.gt_value_transforms,
-                    is_train=True,
-                    context=self.context,
-                    force_has_data=self.force_has_data,
+            try:
+                self.train_datasets.append(
+                    CellMapDataset(
+                        data_paths["raw"],
+                        data_paths["gt"],
+                        self.classes,
+                        self.input_arrays,
+                        self.target_arrays,
+                        self.spatial_transforms,
+                        self.raw_value_transforms,
+                        self.gt_value_transforms,
+                        is_train=True,
+                        context=self.context,
+                        force_has_data=self.force_has_data,
+                    )
                 )
-            )
+            except ValueError as e:
+                print(f"Error loading dataset: {e}")
 
         self.train_datasets_combined = CellMapMultiDataset(
             self.classes,
@@ -161,19 +162,23 @@
         if "validate" in dataset_dict:
             for data_paths in dataset_dict["validate"]:
-                self.validate_datasets.append(
-                    CellMapDataset(
-                        data_paths["raw"],
-                        data_paths["gt"],
-                        self.classes,
-                        self.input_arrays,
-                        self.target_arrays,
-                        gt_value_transforms=self.gt_value_transforms,
-                        is_train=False,
-                        context=self.context,
-                        force_has_data=self.force_has_data,
+                try:
+                    self.validate_datasets.append(
+                        CellMapDataset(
+                            data_paths["raw"],
+                            data_paths["gt"],
+                            self.classes,
+                            self.input_arrays,
+                            self.target_arrays,
+                            gt_value_transforms=self.gt_value_transforms,
+                            is_train=False,
+                            context=self.context,
+                            force_has_data=self.force_has_data,
+                        )
                     )
-                )
+                except ValueError as e:
+                    print(f"Error loading dataset: {e}")
+
         self.validate_datasets_combined = CellMapMultiDataset(
             self.classes,
             self.input_arrays,
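Two behavioral changes in datasplit.py are easy to miss. First, from_csv() now returns the parsed dataset_dict instead of mutating state and calling construct() itself, so __init__ decides when construction happens. Second, each CellMapDataset is built inside a try/except, so one malformed volume no longer aborts the whole split. The pattern reduced to a runnable sketch (make_dataset and the paths are hypothetical stand-ins for the CellMapDataset call above):

    def make_dataset(raw_path, gt_path):
        # Stand-in: the real constructor raises ValueError on bad arrays.
        if not raw_path:
            raise ValueError(f"missing raw data for {gt_path}")
        return (raw_path, gt_path)

    dataset_dict = {"train": [{"raw": "a.zarr", "gt": "a_gt.zarr"},
                              {"raw": "", "gt": "b_gt.zarr"}]}
    datasets = []
    for data_paths in dataset_dict["train"]:
        try:
            datasets.append(make_dataset(data_paths["raw"], data_paths["gt"]))
        except ValueError as e:
            # Skip datasets that fail to load, but keep the rest.
            print(f"Error loading dataset: {e}")

A logging call would arguably serve better than print() here, since skipped datasets silently shrink the training set.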
diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py
index defb898..1f33486 100644
--- a/src/cellmap_data/image.py
+++ b/src/cellmap_data/image.py
@@ -1,5 +1,5 @@
 import os
-from typing import Callable, Iterable, Optional, Sequence
+from typing import Any, Callable, Iterable, Optional, Sequence
 import torch
 from fibsem_tools.io.core import read_xarray
 import xarray
@@ -12,12 +12,12 @@ class CellMapImage:
     path: str
-    translation: dict[str, float]
+    _translation: dict[str, float]
     scale: dict[str, float]
     output_shape: dict[str, int]
     output_size: dict[str, float]
     label_class: str
-    array: xarray.DataArray
+    _array: xarray.DataArray
     axes: str | Sequence[str]
     post_image_transforms: Sequence[str] = ["transpose"]
     value_transform: Optional[Callable]
@@ -66,7 +66,18 @@ def __init__(
         self.axes = axis_order[: len(target_voxel_shape)]
         self.value_transform = value_transform
         self.context = context
-        self.construct()
+        self._array = None
+        self._bounding_box = None
+        self._sampling_box = None
+        self._class_counts = None
+        self._current_spatial_transforms = None
+        self._last_coords = None
+        self._translation = None
+        self._original_scale = None
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # self.xs = self.array.coords["x"]
+        # self.ys = self.array.coords["y"]
+        # self.zs = self.array.coords["z"]
 
     def __getitem__(self, center: dict[str, float]) -> torch.Tensor:
         """Returns image data centered around the given point, based on the scale and shape of the target output image."""
@@ -95,36 +106,98 @@ def __getitem__(self, center: dict[str, float]) -> torch.Tensor:
             data = self.value_transform(data)
         return data.to(self.device)
 
+    def __repr__(self) -> str:
+        return f"CellMapImage({self.path})"
+
+    @property
+    def array(self) -> xarray.DataArray:
+        if self._array is None:
+            self.group = read_xarray(self.path)
+            # Find correct multiscale level based on target scale
+            self.scale_level = self.find_level(self.scale)
+            self.array_path = os.path.join(self.path, self.scale_level)
+            # Construct an xarray with Tensorstore backend
+            ds = read_xarray(self.array_path)
+            spec = xt._zarr_spec_from_path(self.array_path)
+            array_future = tensorstore.open(  # type: ignore
+                spec, read=True, write=False, context=self.context
+            )
+            array = array_future.result()
+            new_data = xt._TensorStoreAdapter(array)
+            self._array = ds.copy(data=new_data)  # type: ignore
+        return self._array  # type: ignore
+
+    @property
+    def translation(self) -> dict[str, float]:
+        """Returns the translation of the image."""
+        if self._translation is None:
+            # Get the translation of the image
+            self._translation = {c: min(self.array.coords[c].values) for c in self.axes}
+        return self._translation
+
+    @property
+    def original_scale(self) -> dict[str, Any] | None:
+        """Returns the original scale of the image."""
+        if self._original_scale is None:
+            # Get the original scale of the image from poorly formatted metadata
+            for level_data in self.group.attrs["multiscales"][0]["datasets"]:
+                if level_data["path"] == self.scale_level:
+                    level_data = level_data["coordinateTransformations"]
+                    for transform in level_data:
+                        if transform["type"] == "scale":
+                            self._original_scale = {
+                                c: transform["scale"][i]
+                                for i, c in enumerate(self.axes)
+                            }
+                    break
+        return self._original_scale
+
+    @property
+    def bounding_box(self) -> dict[str, list[float]]:
+        """Returns the bounding box of the dataset in world units."""
+        if self._bounding_box is None:
+            self._bounding_box = {
+                c: [self.translation[c], max(self.array.coords[c].values)]
+                for c in self.axes
+            }
+        return self._bounding_box
+
+    @property
+    def sampling_box(self) -> dict[str, list[float]]:
+        """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box), in world units."""
+        if self._sampling_box is None:
+            self._sampling_box = {}
+            output_padding = {c: (s / 2) for c, s in self.output_size.items()}
+            for c, (start, stop) in self.bounding_box.items():
+                self._sampling_box[c] = [
+                    start + output_padding[c],
+                    stop - output_padding[c],
+                ]
+        return self._sampling_box
+
+    @property
+    def class_counts(self) -> int:
+        """Returns the number of pixels for the contained class in the ground truth data, normalized by the resolution."""
+        if self._class_counts is None:
+            # Get from cellmap-schemas metadata, then normalize by resolution
+            try:
+                group = zarr.open(self.path, mode="r")
+                annotation_group = AnnotationGroup.from_zarr(group)  # type: ignore
+                self._class_counts = (
+                    np.prod(self.array.shape)
+                    - annotation_group.members[
+                        self.scale_level
+                    ].attrs.cellmap.annotation.complement_counts["absent"]
+                )
+            except Exception as e:
+                print(e)
+                self._class_counts = 0
+        return self._class_counts
+
     def to(self, device: str) -> None:
         """Sets what device returned image data will be loaded onto."""
         self.device = device
 
-    # TODO: move into __init__
-    def construct(self):
-        self._bounding_box = None
-        self._sampling_box = None
-        self._class_counts = None
-        self._current_spatial_transforms = None
-        self._last_coords = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.group = read_xarray(self.path)
-        # Find correct multiscale level based on target scale
-        self.scale_level = self.find_level(self.scale)
-        self.array_path = os.path.join(self.path, self.scale_level)
-        # Construct an xarray with Tensorstore backend
-        ds = read_xarray(self.array_path)
-        spec = xt._zarr_spec_from_path(self.array_path)
-        array_future = tensorstore.open(  # type: ignore
-            spec, read=True, write=False, context=self.context
-        )
-        array = array_future.result()
-        new_data = xt._TensorStoreAdapter(array)
-        self.array = ds.copy(data=new_data)  # type: ignore
-        self.get_spatial_metadata()
-        # self.xs = self.array.coords["x"]
-        # self.ys = self.array.coords["y"]
-        # self.zs = self.array.coords["z"]
-
     def find_level(self, target_scale: dict[str, float]) -> str:
         """Finds the multiscale level that is closest to the target scale."""
         # Get the order of axes in the image
@@ -133,7 +206,7 @@ def find_level(self, target_scale: dict[str, float]) -> str:
             if axis["type"] == "space":
                 axes.append(axis["name"])
 
-        last_path: str = ""
+        last_path: str | None = None
         scale = {}
         for level in self.group.attrs["multiscales"][0]["datasets"]:
             for transform in level["coordinateTransformations"]:
@@ -149,23 +222,7 @@ def find_level(self, target_scale: dict[str, float]) -> str:
                 last_path = level["path"]
         return last_path
 
-    def get_spatial_metadata(self):
-        """Gets the spatial metadata for the image."""
-        # Get the translation of the image
-        self.translation = {c: min(self.array.coords[c].values) for c in self.axes}
-
-        # Get the original scale of the image from poorly formatted metadata
-        for level_data in self.group.attrs["multiscales"][0]["datasets"]:
-            if level_data["path"] == self.scale_level:
-                level_data = level_data["coordinateTransformations"]
-                for transform in level_data:
-                    if transform["type"] == "scale":
-                        self.original_scale = {
-                            c: transform["scale"][i] for i, c in enumerate(self.axes)
-                        }
-                break
-
-    def set_spatial_transforms(self, transforms: dict[str, any] | None):
+    def set_spatial_transforms(self, transforms: dict[str, Any] | None):
         """Sets spatial transformations for the image data."""
         self._current_spatial_transforms = transforms
 
@@ -213,48 +270,6 @@ def return_data(self, coords: dict[str, Sequence[float]]):
 
         return data
 
-    @property
-    def bounding_box(self) -> dict[str, list[float]]:
-        """Returns the bounding box of the dataset in world units."""
-        if self._bounding_box is None:
-            self._bounding_box = {
-                c: [self.translation[c], max(self.array.coords[c].values)]
-                for c in self.axes
-            }
-        return self._bounding_box
-
-    @property
-    def sampling_box(self) -> dict[str, list[float]]:
-        """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box), in world units."""
-        if self._sampling_box is None:
-            self._sampling_box = {}
-            output_padding = {c: (s / 2) for c, s in self.output_size.items()}
-            for c, (start, stop) in self.bounding_box.items():
-                self._sampling_box[c] = [
-                    start + output_padding[c],
-                    stop - output_padding[c],
-                ]
-        return self._sampling_box
-
-    @property
-    def class_counts(self) -> int:
-        """Returns the number of pixels for the contained class in the ground truth data, normalized by the resolution."""
-        if self._class_counts is None:
-            # Get from cellmap-schemas metadata, then normalize by resolution
-            try:
-                group = zarr.open(self.path, mode="r")
-                annotation_group = AnnotationGroup.from_zarr(group)  # type: ignore
-                self._class_counts = (
-                    np.prod(self.array.shape)
-                    - annotation_group.members[
-                        self.scale_level
-                    ].attrs.cellmap.annotation.complement_counts["absent"]
-                )
-            except Exception as e:
-                print(e)
-                self._class_counts = -1
-        return self._class_counts
-
 
 class EmptyImage:
     label_class: str
@@ -291,13 +306,6 @@ def __getitem__(self, center: dict[str, float]) -> torch.Tensor:
         """Returns image data centered around the given point, based on the scale and shape of the target output image."""
         return self.store
 
-    def to(self, device: str) -> None:
-        """Moves the image data to the given device."""
-        self.store = self.store.to(device)
-
-    def set_spatial_transforms(self, transforms: dict[str, any] | None):
-        pass
-
     @property
     def bounding_box(self) -> None:
         """Returns the bounding box of the dataset."""
@@ -312,3 +320,10 @@ def sampling_box(self) -> None:
     def class_counts(self) -> int:
         """Returns the number of pixels for the contained class in the ground truth data, normalized by the resolution."""
         return self._class_counts
+
+    def to(self, device: str) -> None:
+        """Moves the image data to the given device."""
+        self.store = self.store.to(device)
+
+    def set_spatial_transforms(self, transforms: dict[str, Any] | None):
+        pass
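The image.py refactor replaces the eager construct()/get_spatial_metadata() pair with cached, lazily evaluated properties: array, translation, original_scale, bounding_box, sampling_box, and class_counts are each computed on first access and memoized in an underscore-prefixed field, so a CellMapImage is now cheap to create for volumes that are never actually read. The idiom in isolation (stand-in code; the real array property opens a tensorstore-backed xarray):

    class LazyImage:
        def __init__(self, path):
            self.path = path
            self._array = None  # no I/O at construction time

        @property
        def array(self):
            if self._array is None:
                # The expensive open happens once, on first access.
                print(f"opening {self.path}")
                self._array = list(range(8))  # stand-in for the real store
            return self._array

    img = LazyImage("example.zarr")  # hypothetical path; nothing opened yet
    first = img.array   # triggers the open
    second = img.array  # served from the cache

functools.cached_property would express the same idiom more compactly, but the explicit None-check keeps the cache fields visible for __init__ to initialize, matching the style used throughout this patch. Two semantic details ride along with the move: the sampling box shrinks each axis by half the output size at both ends, so every sampled center yields a full-sized crop inside the bounding box, and the class_counts failure value changes from -1 to 0, which keeps the has_data check in dataset.py (class_counts != 0) from counting an unreadable label as present.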
diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py
index 0a80350..7f926c0 100644
--- a/src/cellmap_data/multidataset.py
+++ b/src/cellmap_data/multidataset.py
@@ -30,7 +30,7 @@ def __init__(
         self.target_arrays = target_arrays
         self.classes = classes
         self.datasets = datasets
-        self.construct()
+        self._weighted_sampler = None
 
     def __repr__(self) -> str:
         out_string = f"CellMapMultiDataset(["
@@ -44,9 +44,6 @@ def to(self, device: str):
             dataset.to(device)
         return self
 
-    def construct(self):
-        self._weighted_sampler = None
-
     def weighted_sampler(
         self, batch_size: int = 1, rng: Optional[torch.Generator] = None
     ):
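multidataset.py gets the same treatment as dataloader.py: construct() collapses into a single cached-field initialization, with weighted_sampler() left to build the sampler on demand. The body of weighted_sampler() is not shown in this hunk, so the following is a sketch of the caching idiom the new __init__ line implies, with an invented weighting scheme (per-dataset lengths) rather than the library's:

    import torch
    from torch.utils.data import WeightedRandomSampler

    class MultiDatasetSketch:
        def __init__(self, lengths):
            self.lengths = lengths
            self._weighted_sampler = None  # previously construct()'s only job

        def weighted_sampler(self, batch_size=1, rng=None):
            if self._weighted_sampler is None:
                weights = torch.tensor(self.lengths, dtype=torch.float)
                self._weighted_sampler = WeightedRandomSampler(
                    weights, num_samples=batch_size, generator=rng
                )
            return self._weighted_sampler

    rng = torch.Generator().manual_seed(0)  # reproducible sampling
    sampler = MultiDatasetSketch([10, 3, 7]).weighted_sampler(batch_size=2, rng=rng)

Taken together, the patch replaces every construct() entry point with either eager __init__ logic or lazy cached properties, and wraps per-volume loading in try/except so that one bad dataset degrades rather than kills a many-dataset run.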