Commit
refactor: 🚚 Make get() functions into properties, and bug fix.
rhoadesScholar committed May 9, 2024
1 parent 1f6674a commit 176d568
Showing 4 changed files with 118 additions and 107 deletions.
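The pattern repeated in each file below is a classic Python refactor: a zero-argument get_*() method becomes a read-only @property, and every call site drops the trailing parentheses. A minimal sketch of the pattern (class and values here are illustrative, not from the repository):

class Example:
    def __init__(self):
        self._weights = {"a": 0.5, "b": 0.5}

    # Before: callers wrote example.get_weights()
    def get_weights(self):
        return self._weights

    # After: callers write example.weights, with no parentheses
    @property
    def weights(self):
        return self._weights

The "bug fix" in the title is visible in the class_weights bodies below: a class whose count equals the whole count sum now receives a floor weight of 0.1 instead of a weight of exactly 0.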
57 changes: 32 additions & 25 deletions src/cellmap_data/dataset.py
@@ -306,6 +306,30 @@ def sampling_box_shape(self):
self._sampling_box_shape = sampling_box_shape
return self._sampling_box_shape

@property
def class_weights(self):
"""
Returns the class weights for the dataset based on the number of samples in each class.
"""
if len(self.classes) > 1:
class_counts = {c: 0 for c in self.classes}
class_count_sum = 0
for c in self.classes:
class_counts[c] += self.class_counts["totals"][c]
class_count_sum += self.class_counts["totals"][c]

class_weights = {
c: (
1 - (class_counts[c] / class_count_sum)
if class_counts[c] != class_count_sum
else 0.1
)
for c in self.classes
}
else:
class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow
return class_weights

@property
def class_counts(self) -> Dict[str, Dict[str, int]]:
"""Returns the number of pixels for each class in the ground truth data, normalized by the resolution."""
@@ -319,6 +343,14 @@ def class_counts(self) -> Dict[str, Dict[str, int]]:
self._class_counts = class_counts
return self._class_counts

@property
def validation_indices(self) -> Sequence[int]:
"""Returns the indices of the dataset that will tile the dataset for validation."""
chunk_size = {}
for c, size in self.bounding_box_shape.items():
chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int)
return self.get_indices(chunk_size)

def _get_box(
self, source_box: dict[str, list[int]], current_box: dict[str, list[int]]
) -> dict[str, list[int]]:
@@ -356,13 +388,6 @@ def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]:

return indices

def get_validation_indices(self) -> Sequence[int]:
"""Returns the indices of the dataset that will tile the dataset for validation."""
chunk_size = {}
for c, size in self.bounding_box_shape.items():
chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int)
return self.get_indices(chunk_size)

def to(self, device):
"""Sets the device for the dataset."""
for source in list(self.input_sources.values()) + list(
@@ -425,24 +450,6 @@ def set_target_value_transforms(self, transforms: Callable):
if isinstance(source, CellMapImage):
source.value_transform = transforms

def get_class_weights(self):
"""
Returns the class weights for the multi-dataset based on the number of samples in each class.
"""
if len(self.classes) > 1:
class_counts = {c: 0 for c in self.classes}
class_count_sum = 0
for c in self.classes:
class_counts[c] += self.class_counts["totals"][c]
class_count_sum += self.class_counts["totals"][c]

class_weights = {
c: 1 - (class_counts[c] / class_count_sum) for c in self.classes
}
else:
class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow
return class_weights


# Example input arrays:
# {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)},
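The new class_weights property weights each class by 1 minus its share of the total count, with the bug-fix branch assigning 0.1 when one class holds the entire count. A worked example with made-up counts:

# Hypothetical counts, not taken from the repository
counts = {"mito": 90, "er": 10}
total = sum(counts.values())  # 100
weights = {
    c: 1 - counts[c] / total if counts[c] != total else 0.1
    for c in counts
}
# weights ≈ {"mito": 0.1, "er": 0.9}: rarer classes get larger weights.
# With counts = {"mito": 100, "er": 0}, "mito" hits the equality branch
# and gets 0.1 rather than the zero weight the old code produced.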
2 changes: 1 addition & 1 deletion src/cellmap_data/datasplit.py
@@ -171,7 +171,7 @@ def validation_blocks(self):
if not hasattr(self, "_validation_blocks"):
self._validation_blocks = CellMapSubset(
self.validation_datasets_combined,
self.validation_datasets_combined.get_validation_indices(),
self.validation_datasets_combined.validation_indices,
)
return self._validation_blocks

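validation_blocks also shows the lazy-caching idiom these properties rely on: compute on first access, store on the instance, and return the stored value thereafter. A stripped-down, runnable sketch of that idiom (names are illustrative):

def build_blocks():
    # Stand-in for constructing CellMapSubset(...) as above
    return list(range(3))

class LazyHolder:
    @property
    def blocks(self):
        if not hasattr(self, "_blocks"):
            self._blocks = build_blocks()  # computed once
        return self._blocks

holder = LazyHolder()
assert holder.blocks is holder.blocks  # second access reuses the cache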
130 changes: 69 additions & 61 deletions src/cellmap_data/multidataset.py
@@ -16,7 +16,7 @@ class CellMapMultiDataset(ConcatDataset):
target_arrays: dict[str, dict[str, Sequence[int | float]]]
datasets: Iterable[CellMapDataset]
_weighted_sampler: Optional[WeightedRandomSampler]
_class_counts: dict[str, int] = {}
_class_counts: dict[str, dict[str, int]]

def __init__(
self,
@@ -53,6 +53,72 @@ def class_counts(self):
self._class_counts = class_counts
return self._class_counts

@property
def class_weights(self):
"""
Returns the class weights for the multi-dataset based on the number of samples in each class.
"""
if len(self.classes) > 1:
class_counts = {}
class_count_sum = 0
for c in self.classes:
class_counts[c] = self.class_counts[c]["total"]
class_count_sum += class_counts[c]

class_weights = {
c: (
1 - (class_counts[c] / class_count_sum)
if class_counts[c] != class_count_sum
else 0.1
)
for c in self.classes
}
else:
class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow
return class_weights

@property
def dataset_weights(self):
"""
Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset.
"""
class_weights = self.class_weights

dataset_weights = {}
for dataset in self.datasets:
dataset_weight = np.sum(
[
dataset.class_counts["totals"][c] * class_weights[c]
for c in self.classes
]
)
dataset_weights[dataset] = dataset_weight
return dataset_weights

@property
def sample_weights(self):
"""
Returns the weights for each sample in the multi-dataset based on the number of samples in each dataset.
"""

dataset_weights = self.dataset_weights
sample_weights = []
for dataset, dataset_weight in dataset_weights.items():
sample_weights += [dataset_weight] * len(dataset)
return sample_weights

@property
def validation_indices(self) -> Sequence[int]:
"""
Returns the indices of the validation set for each dataset in the multi-dataset.
"""
validation_indices = []
index_offset = 0
for dataset in self.datasets:
validation_indices.extend(dataset.validation_indices)
index_offset += len(dataset)
return validation_indices

def to(self, device: str):
for dataset in self.datasets:
dataset.to(device)
@@ -63,7 +129,7 @@ def weighted_sampler(
):
if self._weighted_sampler is None:
# TODO: calculate weights for each sample
sample_weights = self.get_sample_weights()
sample_weights = self.sample_weights

self._weighted_sampler = WeightedRandomSampler(
sample_weights, batch_size, replacement=False, generator=rng
@@ -85,7 +151,7 @@ def get_subset_random_sampler(
generator=rng,
)
else:
dataset_weights = list(self.get_dataset_weights().values())
dataset_weights = list(self.dataset_weights.values())

datasets_sampled = torch.multinomial(
torch.tensor(dataset_weights), num_samples, replacement=True
@@ -105,64 +171,6 @@ def get_subset_random_sampler(
indices = indices[torch.randperm(len(indices), generator=rng)]
return torch.utils.data.SubsetRandomSampler(indices, generator=rng)

def get_class_weights(self):
"""
Returns the class weights for the multi-dataset based on the number of samples in each class.
"""
if len(self.classes) > 1:
class_counts = {c: 0 for c in self.classes}
class_count_sum = 0
for dataset in self.datasets:
for c in self.classes:
class_counts[c] += dataset.class_counts["totals"][c]
class_count_sum += dataset.class_counts["totals"][c]

class_weights = {
c: 1 - (class_counts[c] / class_count_sum) for c in self.classes
}
else:
class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow
return class_weights

def get_dataset_weights(self):
"""
Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset.
"""
class_weights = self.get_class_weights()

dataset_weights = {}
for dataset in self.datasets:
dataset_weight = np.sum(
[
dataset.class_counts["totals"][c] * class_weights[c]
for c in self.classes
]
)
dataset_weights[dataset] = dataset_weight
return dataset_weights

def get_sample_weights(self):
"""
Returns the weights for each sample in the multi-dataset based on the number of samples in each dataset.
"""

dataset_weights = self.get_dataset_weights()
sample_weights = []
for dataset, dataset_weight in dataset_weights.items():
sample_weights += [dataset_weight] * len(dataset)
return sample_weights

def get_validation_indices(self) -> Sequence[int]:
"""
Returns the indices of the validation set for each dataset in the multi-dataset.
"""
validation_indices = []
index_offset = 0
for dataset in self.datasets:
validation_indices.extend(dataset.get_validation_indices())
index_offset += len(dataset)
return validation_indices

def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]:
"""Returns the indices of the dataset that will tile the dataset according to the chunk_size."""
indices = []
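The sample_weights property exists to feed torch's WeightedRandomSampler in weighted_sampler above: every sample inherits the weight of the dataset it belongs to. A self-contained sketch of that flow with toy lengths and weights (the numbers are made up):

import torch
from torch.utils.data import WeightedRandomSampler

dataset_weights = [0.9, 0.1]  # per-dataset weights, as dataset_weights would return
dataset_lengths = [3, 2]      # len(dataset) for each dataset

# Mirror sample_weights: repeat each dataset's weight once per sample.
sample_weights = []
for weight, length in zip(dataset_weights, dataset_lengths):
    sample_weights += [weight] * length

rng = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(sample_weights, 4, replacement=False, generator=rng)
print(list(sampler))  # four indices, biased toward the first dataset's samples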
36 changes: 16 additions & 20 deletions src/cellmap_data/subdataset.py
@@ -14,6 +14,22 @@ def __getitem__(self, idx):
def __len__(self):
return len(self.indices)

@property
def classes(self):
return self.dataset.classes

@property
def class_counts(self):
return self.dataset.class_counts

@property
def class_weights(self):
return self.dataset.class_weights

@property
def validation_indices(self):
return self.dataset.validation_indices

def to(self, device):
self.dataset.to(device)
return self
@@ -25,23 +41,3 @@ def set_raw_value_transforms(self, transforms: Callable):
def set_target_value_transforms(self, transforms: Callable):
"""Sets the target value transforms for the subset dataset."""
self.dataset.set_target_value_transforms(transforms)

def get_class_weights(self):
"""
Returns the class weights for the multi-dataset based on the number of samples in each class.
"""
if len(self.dataset.classes) > 1:
class_counts = {c: 0 for c in self.dataset.classes}
class_count_sum = 0
for c in self.dataset.classes:
class_counts[c] += self.dataset.class_counts["totals"][c]
class_count_sum += self.dataset.class_counts["totals"][c]

class_weights = {
c: 1 - (class_counts[c] / class_count_sum) for c in self.dataset.classes
}
else:
class_weights = {
self.dataset.classes[0]: 0.1
} # less than 1 to avoid overflow
return class_weights
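
The four new CellMapSubset properties are pure delegation: a subset answers class-level questions exactly as its wrapped dataset would, which is what lets the removed get_class_weights() body disappear entirely. A toy illustration of the pattern (not the real API):

class WholeDataset:
    @property
    def class_weights(self):
        return {"mito": 0.1, "er": 0.9}

class Subset:
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    @property
    def class_weights(self):
        # Forward to the wrapped dataset, as CellMapSubset now does
        return self.dataset.class_weights

sub = Subset(WholeDataset(), [0, 2, 4])
assert sub.class_weights == {"mito": 0.1, "er": 0.9}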
