fix: 🐛 Fix for validation Subsets.
rhoadesScholar committed May 3, 2024
1 parent 73ffdf7 commit 7422143
Showing 5 changed files with 124 additions and 37 deletions.
26 changes: 19 additions & 7 deletions src/cellmap_data/dataloader.py
@@ -1,16 +1,20 @@
import torch
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler, Sampler
from torch.utils.data import DataLoader, Sampler, Subset
from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset

from typing import Callable, Iterable, Optional
from typing import Iterable, Optional


class CellMapDataLoader:
# TODO: docstring corrections
"""This subclasses PyTorch DataLoader to load CellMap data for training. It maintains the same API as the DataLoader class. This includes applying augmentations to the data and returning the data in the correct format for training, such as generating the target arrays (e.g. signed distance transform of labels). It retrieves raw and groundtruth data from a CellMapDataSplit object, which is a subclass of PyTorch Dataset. Training and validation data are split using the CellMapDataSplit object, and separate dataloaders are maintained as `train_loader` and `validate_loader` respectively."""

dataset: CellMapMultiDataset | CellMapDataset
dataset: (
CellMapMultiDataset
| CellMapDataset
| Subset[CellMapDataset | CellMapMultiDataset]
)
classes: Iterable[str]
loader = DataLoader
batch_size: int
@@ -22,7 +26,11 @@ class CellMapDataLoader:

def __init__(
self,
dataset: CellMapMultiDataset | CellMapDataset,
dataset: (
CellMapMultiDataset
| CellMapDataset
| Subset[CellMapDataset | CellMapMultiDataset]
),
classes: Iterable[str],
batch_size: int = 1,
num_workers: int = 0,
@@ -44,10 +52,14 @@ def __init__(
self.dataset, CellMapMultiDataset
), "Weighted sampler only relevant for CellMapMultiDataset"
self.sampler = self.dataset.weighted_sampler(self.batch_size, self.rng)
if torch.cuda.is_available():
if isinstance(self.dataset, Subset):
self.dataset.dataset.to("cuda") # type: ignore
else:
self.dataset.to("cuda")
kwargs = {
"dataset": (
self.dataset.to("cuda") if torch.cuda.is_available() else self.dataset
),
"dataset": self.dataset,
"dataset": self.dataset,
"batch_size": self.batch_size,
"num_workers": self.num_workers,
"collate_fn": self.collate_fn,
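
A quick sketch of the issue this commit fixes: `torch.utils.data.Subset` only proxies `__getitem__` and `__len__`, so a custom method like `.to(device)` must be called on the wrapped dataset via `Subset.dataset`, exactly as the new branch above does. A minimal, hypothetical reproduction (the toy dataset is illustrative, not part of the library):

```python
import torch
from torch.utils.data import Dataset, Subset

class ToyDataset(Dataset):
    """Stand-in for CellMapDataset: holds a tensor and supports .to(device)."""

    def __init__(self) -> None:
        self.data = torch.arange(10, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

    def to(self, device: str) -> "ToyDataset":
        self.data = self.data.to(device)
        return self

subset = Subset(ToyDataset(), indices=[1, 3, 5])
# subset.to("cuda")  # AttributeError: 'Subset' object has no attribute 'to'
if torch.cuda.is_available():
    subset.dataset.to("cuda")  # delegate to the underlying dataset, as the fix does
```
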
65 changes: 46 additions & 19 deletions src/cellmap_data/dataset.py
@@ -8,7 +8,7 @@
from .image import CellMapImage, EmptyImage


def split_gt_path(path: str) -> tuple[str, list[str]]:
def split_target_path(path: str) -> tuple[str, list[str]]:
"""Splits a path to groundtruth data into the main path string, and the classes supplied for it."""
try:
path_prefix, path_rem = path.split("[")
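
The body of `split_target_path` is collapsed above; the following is a plausible reconstruction of its behavior, inferred from the docstring and the later `self.target_path_str.format(label=label)` calls. The bracket syntax and the fallback branch are assumptions, not the library's confirmed implementation:

```python
def split_target_path_sketch(path: str) -> tuple[str, list[str]]:
    """Hypothetical sketch: 'gt.zarr/[mito,er]' -> ('gt.zarr/{label}', ['mito', 'er'])."""
    try:
        path_prefix, path_rem = path.split("[")
        classes_str, path_suffix = path_rem.split("]")
        classes = classes_str.split(",")
        # Rebuild the path as a template for later .format(label=...) calls.
        path_string = path_prefix + "{label}" + path_suffix
    except ValueError:
        # No bracketed class list: the path itself is the template.
        path_string = path
        classes = []
    return path_string, classes

print(split_target_path_sketch("data/gt.zarr/[mito,er]"))
# ('data/gt.zarr/{label}', ['mito', 'er'])
```
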
@@ -26,15 +26,17 @@ class CellMapDataset(Dataset):
"""This subclasses PyTorch Dataset to load CellMap data for training. It maintains the same API as the Dataset class. Importantly, it maintains information about and handles for the sources for raw and groundtruth data. This information includes the path to the data, the classes for segmentation, and the arrays to input to the network and use as targets for the network. The dataset constructs the sources for the raw and groundtruth data, and retrieves the data from the sources. The dataset also provides methods to get the number of pixels for each class in the ground truth data, normalized by the resolution. Additionally, random crops of the data can be generated for training, because the CellMapDataset maintains information about the extents of its source arrays. This object additionally combines images for different classes into a single output array, which is useful for training segmentation networks."""

raw_path: str
gt_path: str
target_path: str
classes: Sequence[str]
input_arrays: dict[str, dict[str, Sequence[int | float]]]
target_arrays: dict[str, dict[str, Sequence[int | float]]]
input_sources: dict[str, CellMapImage]
target_sources: dict[str, dict[str, CellMapImage | EmptyImage]]
spatial_transforms: Optional[dict[str, Any]]
raw_value_transforms: Optional[Callable]
gt_value_transforms: Optional[Callable | Sequence[Callable] | dict[str, Callable]]
target_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
]
has_data: bool
is_train: bool
axis_order: str
@@ -51,13 +53,13 @@ class CellMapDataset(Dataset):
def __init__(
self,
raw_path: str, # TODO: Switch "raw_path" to "input_path"
gt_path: str,
target_path: str,
classes: Sequence[str],
input_arrays: dict[str, dict[str, Sequence[int | float]]],
target_arrays: dict[str, dict[str, Sequence[int | float]]],
spatial_transforms: Optional[dict[str, Any]] = None,
raw_value_transforms: Optional[Callable] = None,
gt_value_transforms: Optional[
target_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None,
is_train: bool = False,
@@ -70,7 +72,7 @@ def __init__(
Args:
raw_path (str): The path to the raw data.
gt_path (str): The path to the ground truth data.
target_path (str): The path to the ground truth data.
classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros.
input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:
{
@@ -92,21 +94,21 @@ def __init__(
{transform_name: {transform_args}}
Defaults to None.
raw_value_transforms (Optional[Callable], optional): A function to apply to the raw data. Defaults to None. Example is to normalize the raw data.
gt_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
target_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
is_train (bool, optional): Whether the dataset is for training. Defaults to False.
context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
rng (Optional[np.random.Generator], optional): A random number generator. Defaults to None.
force_has_data (bool, optional): Whether to force the dataset to report that it has data. Defaults to False.
"""
self.raw_path = raw_path
self.gt_paths = gt_path
self.gt_path_str, self.classes_with_path = split_gt_path(gt_path)
self.target_paths = target_path
self.target_path_str, self.classes_with_path = split_target_path(target_path)
self.classes = classes
self.input_arrays = input_arrays
self.target_arrays = target_arrays
self.spatial_transforms = spatial_transforms
self.raw_value_transforms = raw_value_transforms
self.gt_value_transforms = gt_value_transforms
self.target_value_transforms = target_value_transforms
self.is_train = is_train
self.axis_order = axis_order
self.context = context
@@ -139,14 +141,14 @@ def __init__(
empty_store = torch.zeros(array_info["shape"]) # type: ignore
for i, label in enumerate(self.classes): # type: ignore
if label in self.classes_with_path:
if isinstance(self.gt_value_transforms, dict):
value_transform: Callable = self.gt_value_transforms[label]
elif isinstance(self.gt_value_transforms, list):
value_transform: Callable = self.gt_value_transforms[i]
if isinstance(self.target_value_transforms, dict):
value_transform: Callable = self.target_value_transforms[label]
elif isinstance(self.target_value_transforms, list):
value_transform: Callable = self.target_value_transforms[i]
else:
value_transform: Callable = self.gt_value_transforms # type: ignore
value_transform: Callable = self.target_value_transforms # type: ignore
self.target_sources[array_name][label] = CellMapImage(
self.gt_path_str.format(label=label),
self.target_path_str.format(label=label),
label,
array_info["scale"],
array_info["shape"], # type: ignore
@@ -184,6 +186,7 @@ def __getitem__(self, idx):
c: center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0]
for i, c in enumerate(self.axis_order)
}
self._current_idx = idx
self._current_center = center
spatial_transforms = self.generate_spatial_transforms()
outputs = {}
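
A worked example of the index-to-world mapping in `__getitem__` (numbers are hypothetical, and the collapsed lines above are assumed to unravel `idx` over the sampling box with `np.unravel_index`):

```python
import numpy as np

axis_order = "zyx"
sampling_box_shape = (4, 6, 8)                   # hypothetical per-axis index counts
largest_voxel_sizes = {"z": 8, "y": 8, "x": 8}   # world units per index step
sampling_box = {"z": (16, 48), "y": (0, 48), "x": (0, 64)}

idx = 67
center_index = np.unravel_index(idx, sampling_box_shape)  # -> (1, 2, 3)
center = {
    c: center_index[i] * largest_voxel_sizes[c] + sampling_box[c][0]
    for i, c in enumerate(axis_order)
}
print(center)  # {'z': 24, 'y': 16, 'x': 24}
```
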
@@ -210,7 +213,7 @@

def __repr__(self):
"""Returns a string representation of the dataset."""
return f"CellMapDataset(\n\tRaw path: {self.raw_path}\n\tGT path(s): {self.gt_paths}\n\tClasses: {self.classes})"
return f"CellMapDataset(\n\tRaw path: {self.raw_path}\n\tGT path(s): {self.target_paths}\n\tClasses: {self.classes})"

@property
def largest_voxel_sizes(self):
@@ -314,6 +317,16 @@ def _get_box(
current_box[c][1] = min(current_box[c][1], stop)
return current_box

def verify(self):
"""Verifies that the dataset is valid."""
# TODO: make more robust
        try:
            # The dataset is valid if its length can be computed without error.
            len(self)
            return True
        except Exception:
            return False

def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]:
# TODO: ADD TEST
"""Returns the indices of the dataset that will tile the dataset according to the chunk_size."""
@@ -326,7 +339,7 @@ def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]:
# Generate linear indices by unraveling all combinations of axes indices
for i in np.ndindex(*[len(indices_dict[c]) for c in self.axis_order]):
index = [indices_dict[c][j] for c, j in zip(self.axis_order, i)]
index = np.ravel_multi_index(index, self.sampling_box_shape)
index = np.ravel_multi_index(index, list(self.sampling_box_shape.values()))
indices.append(index)

return indices
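
Why the one-line change in `get_indices` matters: `np.ravel_multi_index` expects a sequence of dimension sizes, while `sampling_box_shape` is a dict keyed by axis name. A minimal sketch with hypothetical shapes:

```python
import numpy as np

sampling_box_shape = {"z": 4, "y": 6, "x": 8}
index = (1, 2, 3)  # one combination of per-axis indices

# np.ravel_multi_index(index, sampling_box_shape)  # fails: a dict is not a shape
flat = np.ravel_multi_index(index, list(sampling_box_shape.values()))
print(flat)  # 1*(6*8) + 2*8 + 3 = 67
```
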
@@ -335,7 +348,7 @@ def get_validation_indices(self) -> Sequence[int]:
"""Returns the indices of the dataset that will tile the dataset for validation."""
chunk_size = {}
for c, size in self.bounding_box_shape.items():
chunk_size[c] = np.ceil(size - self.sampling_box_shape[c], dtype=int)
chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int)
return self.get_indices(chunk_size)
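
The `get_validation_indices` fix works around `np.ceil` not accepting an integer output dtype (it is a floating-point ufunc, so `dtype=int` raises a TypeError); rounding up and casting afterwards is the idiomatic form. Hypothetical sizes:

```python
import numpy as np

bounding_size, sampling_size = 10.0, 3.5
# np.ceil(bounding_size - sampling_size, dtype=int)        # TypeError
chunk = np.ceil(bounding_size - sampling_size).astype(int)  # -> 7
print(chunk)
```
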

def to(self, device):
@@ -386,6 +399,20 @@ def generate_spatial_transforms(self) -> Optional[dict[str, Any]]:
self._current_spatial_transforms = spatial_transforms
return spatial_transforms

def set_raw_value_transforms(self, transforms: Callable):
"""Sets the raw value transforms for the dataset."""
self.raw_value_transforms = transforms
for source in self.input_sources.values():
source.value_transform = transforms

def set_target_value_transforms(self, transforms: Callable):
"""Sets the ground truth value transforms for the dataset."""
self.target_value_transforms = transforms
for sources in self.target_sources.values():
for source in sources.values():
if isinstance(source, CellMapImage):
source.value_transform = transforms
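
A standalone sketch of the propagation pattern used by the two setters above (all names here are illustrative stand-ins, not library classes): the callable is stored on the dataset and pushed into every image source, so transforms can be swapped after construction, e.g. between training and validation phases. `CellMapDataSplit` exposes the same setters at the split level (see datasplit.py below).

```python
from typing import Callable, Optional

class _ImageSketch:
    value_transform: Optional[Callable] = None

class _DatasetSketch:
    def __init__(self) -> None:
        self.input_sources = {"0_input": _ImageSketch()}

    def set_raw_value_transforms(self, transforms: Callable) -> None:
        # Store the callable and push it into every input source.
        self.raw_value_transforms = transforms
        for source in self.input_sources.values():
            source.value_transform = transforms

ds = _DatasetSketch()
ds.set_raw_value_transforms(lambda x: x / 255.0)
print(ds.input_sources["0_input"].value_transform is ds.raw_value_transforms)  # True
```
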


# Example input arrays:
# {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)},
47 changes: 40 additions & 7 deletions src/cellmap_data/datasplit.py
@@ -21,7 +21,7 @@ class CellMapDataSplit:
validation_datasets: Sequence[CellMapDataset]
spatial_transforms: Optional[dict[str, Any]] = None
raw_value_transforms: Optional[Callable] = None
gt_value_transforms: Optional[
target_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None
force_has_data: bool = False
@@ -37,7 +37,7 @@ def __init__(
csv_path: Optional[str] = None,
spatial_transforms: Optional[dict[str, Any]] = None,
raw_value_transforms: Optional[Callable] = None,
gt_value_transforms: Optional[
target_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None,
force_has_data: bool = False,
@@ -84,7 +84,7 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
{transform_name: {transform_args}}
Defaults to None.
raw_value_transforms (Optional[Callable], optional): A function to apply to the raw data. Defaults to None. Example is to normalize the raw data.
gt_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
target_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
force_has_data (bool, optional): Whether to force the dataset to have data. Defaults to False.
context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
"""
@@ -106,13 +106,14 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
self.dataset_dict = self.from_csv(csv_path)
self.spatial_transforms = spatial_transforms
self.raw_value_transforms = raw_value_transforms
self.gt_value_transforms = gt_value_transforms
self.target_value_transforms = target_value_transforms
self.context = context
if self.dataset_dict is not None:
self.construct(self.dataset_dict)
self.verify_datasets()

def __repr__(self):
return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.raw_value_transforms}\n\tGT value transforms: {self.gt_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})"
return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})"

def from_csv(self, csv_path):
# Load file data from csv file
@@ -195,7 +196,7 @@ def construct(self, dataset_dict):
self.target_arrays,
self.spatial_transforms,
self.raw_value_transforms,
self.gt_value_transforms,
self.target_value_transforms,
is_train=True,
context=self.context,
force_has_data=self.force_has_data,
@@ -218,7 +219,7 @@ def construct(self, dataset_dict):
self.classes,
self.input_arrays,
self.target_arrays,
gt_value_transforms=self.gt_value_transforms,
target_value_transforms=self.target_value_transforms,
is_train=False,
context=self.context,
force_has_data=self.force_has_data,
@@ -229,6 +230,38 @@

self.datasets["validate"] = self.validation_datasets

def verify_datasets(self):
verified_datasets = []
for ds in self.train_datasets:
if ds.verify():
verified_datasets.append(ds)
self.train_datasets = verified_datasets

verified_datasets = []
for ds in self.validation_datasets:
if ds.verify():
verified_datasets.append(ds)
self.validation_datasets = verified_datasets
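
A standalone sketch of the filtering pattern in `verify_datasets` (the toy class is illustrative, mirroring `CellMapDataset.verify()` above): datasets whose length cannot be computed, e.g. because their sampling box is invalid, are dropped up front rather than failing later during training.

```python
class _DatasetSketch:
    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        if self.size <= 0:
            raise ValueError("invalid sampling box")
        return self.size

    def verify(self) -> bool:
        try:
            len(self)
            return True
        except Exception:
            return False

datasets = [_DatasetSketch(10), _DatasetSketch(0), _DatasetSketch(5)]
verified = [ds for ds in datasets if ds.verify()]
print(len(verified))  # 2 -> the invalid dataset was dropped
```
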

def set_raw_value_transforms(self, transforms: Callable):
"""Sets the raw value transforms for each dataset in the training multi-dataset."""
for dataset in self.train_datasets:
dataset.set_raw_value_transforms(transforms)

def set_target_value_transforms(self, transforms: Callable):
"""Sets the target value transforms for each dataset in the multi-datasets."""
for dataset in self.train_datasets:
dataset.set_target_value_transforms(transforms)
if hasattr(self, "_train_datasets_combined"):
self._train_datasets_combined.set_target_value_transforms(transforms)

for dataset in self.validation_datasets:
dataset.set_target_value_transforms(transforms)
if hasattr(self, "_validation_datasets_combined"):
self._validation_datasets_combined.set_target_value_transforms(transforms)
if hasattr(self, "_validation_blocks"):
self._validation_blocks.set_target_value_transforms(transforms)


# Example input arrays:
# {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)},
2 changes: 1 addition & 1 deletion src/cellmap_data/image.py
@@ -194,7 +194,7 @@ def class_counts(self) -> int:
- annotation_group.members[
self.scale_level
].attrs.cellmap.annotation.complement_counts["absent"]
)
) / np.prod(list(self.scale.values()))
except Exception as e:
print(e)
self._class_counts = 0
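
A worked example of the `class_counts` change (values are hypothetical): the raw voxel count for a class is now divided by the voxel volume implied by the scale, so that counts reported by sources at different resolutions are weighted comparably.

```python
import numpy as np

scale = {"z": 8, "y": 8, "x": 8}   # e.g. nm per voxel along each axis
present_voxels = 1_000_000         # hypothetical count of annotated voxels
normalized = present_voxels / np.prod(list(scale.values()))
print(normalized)  # 1000000 / 512 = 1953.125
```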