fix: 🐛 Debug for many datasets
rhoadesScholar committed Apr 29, 2024
1 parent 1674993 commit 2a1134f
Showing 5 changed files with 253 additions and 244 deletions.
3 changes: 0 additions & 3 deletions src/cellmap_data/dataloader.py
@@ -39,9 +39,6 @@ def __init__(
self.sampler = sampler
self.is_train = is_train
self.rng = rng
self.construct()

def construct(self):
if self.sampler is None and self.weighted_sampler:
assert isinstance(
self.dataset, CellMapMultiDataset
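The dataloader change above folds the one-off construct() step into __init__, so a loader is fully set up the moment it is created. A minimal sketch of that pattern, with hypothetical names (Loader and the class_counts check are illustrative, not the library's API):

# Hypothetical sketch: initialization runs eagerly in __init__ instead of
# a separate construct() method that __init__ immediately invoked anyway.
class Loader:
    def __init__(self, dataset, sampler=None, weighted_sampler=False):
        self.dataset = dataset
        self.sampler = sampler
        if self.sampler is None and weighted_sampler:
            # Weighted sampling needs per-class statistics, so it only makes
            # sense for a dataset wrapper that can report them.
            assert hasattr(dataset, "class_counts"), "weighted sampling needs class counts"

Dropping the extra method also removes the window in which a half-initialized loader could be used between __init__ and construct().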
207 changes: 101 additions & 106 deletions src/cellmap_data/dataset.py
@@ -1,12 +1,10 @@
# %%
import math
import os
from typing import Callable, Dict, Sequence, Optional
from typing import Any, Callable, Dict, Sequence, Optional
import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info
from torch.utils.data import Dataset
import tensorstore
from fibsem_tools.io.core import read, read_xarray
from .image import CellMapImage, EmptyImage


@@ -113,8 +111,57 @@ def __init__(
self.axis_order = axis_order
self.context = context
self._rng = rng
self.construct()
self.force_has_data = force_has_data
self._bounding_box = None
self._bounding_box_shape = None
self._sampling_box = None
self._sampling_box_shape = None
self._class_counts = None
self._largest_voxel_sizes = None
self._len = None
self._iter_coords = None
self._current_center = None
self._current_spatial_transforms = None
self.input_sources = {}
for array_name, array_info in self.input_arrays.items():
self.input_sources[array_name] = CellMapImage(
self.raw_path,
"raw",
array_info["scale"],
array_info["shape"], # type: ignore
value_transform=self.raw_value_transforms,
context=self.context,
)
self.target_sources = {}
self.has_data = False
for array_name, array_info in self.target_arrays.items():
self.target_sources[array_name] = {}
empty_store = torch.zeros(array_info["shape"]) # type: ignore
for i, label in enumerate(self.classes): # type: ignore
if label in self.classes_with_path:
if isinstance(self.gt_value_transforms, dict):
value_transform: Callable = self.gt_value_transforms[label]
elif isinstance(self.gt_value_transforms, list):
value_transform: Callable = self.gt_value_transforms[i]
else:
value_transform: Callable = self.gt_value_transforms # type: ignore
self.target_sources[array_name][label] = CellMapImage(
self.gt_path_str.format(label=label),
label,
array_info["scale"],
array_info["shape"], # type: ignore
value_transform=value_transform,
context=self.context,
)
if not self.has_data:
self.has_data = (
self.has_data
or self.target_sources[array_name][label].class_counts != 0
)
else:
self.target_sources[array_name][label] = EmptyImage(
label, array_info["shape"], empty_store # type: ignore
)

def __len__(self):
"""Returns the length of the dataset, determined by the number of coordinates that could be sampled as the center for a cube."""
@@ -123,7 +170,7 @@ def __len__(self):
if self._len is None:
size = 1
for _, (start, stop) in self.sampling_box.items():
size *= stop - start
size *= abs(stop - start)
size /= np.prod(list(self.largest_voxel_sizes.values()))
self._len = int(size)
return self._len
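The abs() guard in this hunk keeps the length estimate positive when an axis stores its (start, stop) bounds inverted; previously such an axis made size negative and the reported dataset length bogus. A worked sketch with hypothetical numbers:

import numpy as np

# Hypothetical sampling box (world units) and voxel sizes; note the
# inverted y bounds, which stop - start alone would count as negative.
sampling_box = {"x": (0, 128), "y": (128, 0), "z": (0, 64)}
largest_voxel_sizes = {"x": 8.0, "y": 8.0, "z": 8.0}

size = 1
for _, (start, stop) in sampling_box.items():
    size *= abs(stop - start)  # 128 * 128 * 64
size /= np.prod(list(largest_voxel_sizes.values()))  # / (8 * 8 * 8)
print(int(size))  # 2048 candidate center coordinates; negative without abs()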
@@ -167,106 +214,6 @@ def __repr__(self):
"""Returns a string representation of the dataset."""
return f"CellMapDataset(\n\tRaw path: {self.raw_path}\n\tGT path(s): {self.gt_paths}\n\tClasses: {self.classes})"

def to(self, device):
"""Sets the device for the dataset."""
for source in list(self.input_sources.values()) + list(
self.target_sources.values()
):
if isinstance(source, dict):
for source in source.values():
source.to(device)
else:
source.to(device)
return self

def construct(self):
"""Constructs the input and target sources for the dataset."""
self._bounding_box = None
self._bounding_box_shape = None
self._sampling_box = None
self._sampling_box_shape = None
self._class_counts = None
self._largest_voxel_sizes = None
self._len = None
self._iter_coords = None
self._current_center = None
self._current_spatial_transforms = None
self.input_sources = {}
for array_name, array_info in self.input_arrays.items():
self.input_sources[array_name] = CellMapImage(
self.raw_path,
"raw",
array_info["scale"],
array_info["shape"], # type: ignore
value_transform=self.raw_value_transforms,
context=self.context,
)
self.target_sources = {}
self.has_data = False
for array_name, array_info in self.target_arrays.items():
self.target_sources[array_name] = {}
empty_store = torch.zeros(array_info["shape"]) # type: ignore
for i, label in enumerate(self.classes): # type: ignore
if label in self.classes_with_path:
if isinstance(self.gt_value_transforms, dict):
value_transform: Callable = self.gt_value_transforms[label]
elif isinstance(self.gt_value_transforms, list):
value_transform: Callable = self.gt_value_transforms[i]
else:
value_transform: Callable = self.gt_value_transforms # type: ignore
self.target_sources[array_name][label] = CellMapImage(
self.gt_path_str.format(label=label),
label,
array_info["scale"],
array_info["shape"], # type: ignore
value_transform=value_transform,
context=self.context,
)
self.has_data = (
self.has_data
or self.target_sources[array_name][label].class_counts != 0
)
else:
self.target_sources[array_name][label] = EmptyImage(
label, array_info["shape"], empty_store # type: ignore
)

def generate_spatial_transforms(self) -> Optional[dict[str, any]]:
"""Generates spatial transforms for the dataset."""
# TODO: use torch random number generator so accelerators can synchronize across workers
if self._rng is None:
self._rng = np.random.default_rng()
rng = self._rng

if not self.is_train or self.spatial_transforms is None:
return None
spatial_transforms = {}
for transform, params in self.spatial_transforms.items():
if transform == "mirror":
# input: "mirror": {"axes": {"x": 0.5, "y": 0.5, "z":0.1}}
# output: {"mirror": ["x", "y"]}
spatial_transforms[transform] = []
for axis, prob in params["axes"].items():
if rng.random() < prob:
spatial_transforms[transform].append(axis)
elif transform == "transpose":
# only reorder axes specified in params
# input: "transpose": {"axes": ["x", "z"]}
# output: {"transpose": {"x": 2, "y": 1, "z": 0}}
axes = {axis: i for i, axis in enumerate(self.axis_order)}
shuffled_axes = rng.permutation(
[axes[a] for a in params["axes"]]
) # shuffle indices
shuffled_axes = {
axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])
} # reassign axes
axes.update(shuffled_axes)
spatial_transforms[transform] = axes
else:
raise ValueError(f"Unknown spatial transform: {transform}")
self._current_spatial_transforms = spatial_transforms
return spatial_transforms

@property
def largest_voxel_sizes(self):
"""Returns the largest voxel size of the dataset."""
@@ -367,6 +314,54 @@ def class_counts(self) -> Dict[str, Dict[str, int]]:
self._class_counts = class_counts
return self._class_counts

def to(self, device):
"""Sets the device for the dataset."""
for source in list(self.input_sources.values()) + list(
self.target_sources.values()
):
if isinstance(source, dict):
for source in source.values():
source.to(device)
else:
source.to(device)
return self

def generate_spatial_transforms(self) -> Optional[dict[str, Any]]:
"""Generates spatial transforms for the dataset."""
# TODO: use torch random number generator so accelerators can synchronize across workers
if self._rng is None:
self._rng = np.random.default_rng()
rng = self._rng

if not self.is_train or self.spatial_transforms is None:
return None
spatial_transforms = {}
for transform, params in self.spatial_transforms.items():
if transform == "mirror":
# input: "mirror": {"axes": {"x": 0.5, "y": 0.5, "z":0.1}}
# output: {"mirror": ["x", "y"]}
spatial_transforms[transform] = []
for axis, prob in params["axes"].items():
if rng.random() < prob:
spatial_transforms[transform].append(axis)
elif transform == "transpose":
# only reorder axes specified in params
# input: "transpose": {"axes": ["x", "z"]}
# output: {"transpose": {"x": 2, "y": 1, "z": 0}}
axes = {axis: i for i, axis in enumerate(self.axis_order)}
shuffled_axes = rng.permutation(
[axes[a] for a in params["axes"]]
) # shuffle indices
shuffled_axes = {
axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])
} # reassign axes
axes.update(shuffled_axes)
spatial_transforms[transform] = axes
else:
raise ValueError(f"Unknown spatial transform: {transform}")
self._current_spatial_transforms = spatial_transforms
return spatial_transforms


# Example input arrays:
# {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)},
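For reference, a minimal sketch of what generate_spatial_transforms computes for a typical configuration; the exact result depends on the RNG draws, so the printed output below is only one possibility:

import numpy as np

rng = np.random.default_rng()
spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.1}},
    "transpose": {"axes": ["x", "z"]},
}
axis_order = "zyx"

# Mirror: include each axis with its configured probability.
mirrored = [ax for ax, p in spatial_transforms["mirror"]["axes"].items() if rng.random() < p]

# Transpose: shuffle only the listed axes, leaving the others in place.
axes = {axis: i for i, axis in enumerate(axis_order)}
picked = [axes[a] for a in spatial_transforms["transpose"]["axes"]]
shuffled = rng.permutation(picked)
axes.update({a: int(shuffled[i]) for i, a in enumerate(spatial_transforms["transpose"]["axes"])})

print({"mirror": mirrored, "transpose": axes})
# One possible draw: {'mirror': ['y'], 'transpose': {'z': 2, 'y': 1, 'x': 0}}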
73 changes: 39 additions & 34 deletions src/cellmap_data/datasplit.py
@@ -1,6 +1,6 @@
import csv
import os
from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence
import tensorstore
from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
@@ -20,7 +20,7 @@ class CellMapDataSplit:
validate_datasets: Iterable[CellMapDataset]
train_datasets_combined: CellMapMultiDataset
validate_datasets_combined: CellMapMultiDataset
spatial_transforms: Optional[dict[str, any]] = None
spatial_transforms: Optional[dict[str, Any]] = None
raw_value_transforms: Optional[Callable] = None
gt_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
@@ -36,7 +36,7 @@ def __init__(
datasets: Optional[Dict[str, Iterable[CellMapDataset]]] = None,
dataset_dict: Optional[Mapping[str, Sequence[Dict[str, str]]]] = None,
csv_path: Optional[str] = None,
spatial_transforms: Optional[dict[str, any]] = None,
spatial_transforms: Optional[dict[str, Any]] = None,
raw_value_transforms: Optional[Callable] = None,
gt_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
@@ -81,7 +81,7 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
}. Defaults to None.
csv_path (Optional[str], optional): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:
train | validate, raw path, gt path
spatial_transforms (Optional[Sequence[dict[str, any]]], optional): A sequence of dictionaries containing the spatial transformations to apply to the data. The dictionary should have the following structure:
spatial_transforms (Optional[Sequence[dict[str, Any]]], optional): A sequence of dictionaries containing the spatial transformations to apply to the data. The dictionary should have the following structure:
{transform_name: {transform_args}}
Defaults to None.
raw_value_transforms (Optional[Callable], optional): A function to apply to the raw data. Defaults to None. Example is to normalize the raw data.
@@ -99,9 +99,8 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
self.dataset_dict = {}
elif dataset_dict is not None:
self.dataset_dict = dataset_dict
self.construct(dataset_dict)
elif csv_path is not None:
self.from_csv(csv_path)
self.dataset_dict = self.from_csv(csv_path)
self.spatial_transforms = spatial_transforms
self.raw_value_transforms = raw_value_transforms
self.gt_value_transforms = gt_value_transforms
@@ -126,29 +125,31 @@ def from_csv(self, csv_path):
}
)

self.dataset_dict = dataset_dict
self.construct(dataset_dict)
return dataset_dict

def construct(self, dataset_dict):
self._class_counts = None
self.train_datasets = []
self.validate_datasets = []
for data_paths in dataset_dict["train"]:
self.train_datasets.append(
CellMapDataset(
data_paths["raw"],
data_paths["gt"],
self.classes,
self.input_arrays,
self.target_arrays,
self.spatial_transforms,
self.raw_value_transforms,
self.gt_value_transforms,
is_train=True,
context=self.context,
force_has_data=self.force_has_data,
try:
self.train_datasets.append(
CellMapDataset(
data_paths["raw"],
data_paths["gt"],
self.classes,
self.input_arrays,
self.target_arrays,
self.spatial_transforms,
self.raw_value_transforms,
self.gt_value_transforms,
is_train=True,
context=self.context,
force_has_data=self.force_has_data,
)
)
)
except ValueError as e:
print(f"Error loading dataset: {e}")

self.train_datasets_combined = CellMapMultiDataset(
self.classes,
@@ -161,19 +162,23 @@ def construct(self, dataset_dict):

if "validate" in dataset_dict:
for data_paths in dataset_dict["validate"]:
self.validate_datasets.append(
CellMapDataset(
data_paths["raw"],
data_paths["gt"],
self.classes,
self.input_arrays,
self.target_arrays,
gt_value_transforms=self.gt_value_transforms,
is_train=False,
context=self.context,
force_has_data=self.force_has_data,
try:
self.validate_datasets.append(
CellMapDataset(
data_paths["raw"],
data_paths["gt"],
self.classes,
self.input_arrays,
self.target_arrays,
gt_value_transforms=self.gt_value_transforms,
is_train=False,
context=self.context,
force_has_data=self.force_has_data,
)
)
)
except ValueError as e:
print(f"Error loading dataset: {e}")

self.validate_datasets_combined = CellMapMultiDataset(
self.classes,
self.input_arrays,
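Two patterns in the datasplit changes are worth calling out: from_csv now returns the parsed dictionary instead of storing it and calling construct() as a side effect, and each CellMapDataset construction is wrapped in try/except ValueError so one unreadable dataset is skipped rather than aborting the whole split. A sketch of the documented CSV layout and the fault-tolerant loading loop, with hypothetical paths and a stand-in make_dataset factory:

import csv

# Hypothetical datasplit.csv, one row per dataset, following the documented
# "train | validate, raw path, gt path" structure:
#   train,/data/sample1/raw.zarr,/data/sample1/gt/{label}.zarr
#   validate,/data/sample2/raw.zarr,/data/sample2/gt/{label}.zarr

def load_split(csv_path: str) -> dict:
    dataset_dict: dict = {}
    with open(csv_path) as f:
        for usage, raw_path, gt_path in csv.reader(f):
            dataset_dict.setdefault(usage, []).append({"raw": raw_path, "gt": gt_path})
    return dataset_dict  # returned to the caller, not stored as a side effect

def build_datasets(paths, make_dataset):
    datasets = []
    for p in paths:
        try:
            datasets.append(make_dataset(p["raw"], p["gt"]))
        except ValueError as e:
            # Skip datasets that fail to load instead of failing the split.
            print(f"Error loading dataset: {e}")
    return datasets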
(Diffs for the remaining two of the five changed files did not load.)
