diff --git a/pyproject.toml b/pyproject.toml index 2e6dfa9..b27b5c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,9 @@ dependencies = [ "torchvision", "numpy", "matplotlib", - "fibsem-tools", - "cellmap_schemas", + "fibsem-tools==6.3.2", + # "fibsem-tools", + "cellmap-schemas", "tensorstore", "xarray-tensorstore", "cellpose", diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 55cdd88..b3be77f 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -425,6 +425,24 @@ def set_target_value_transforms(self, transforms: Callable): if isinstance(source, CellMapImage): source.value_transform = transforms + def get_class_weights(self): + """ + Returns the class weights for the multi-dataset based on the number of samples in each class. + """ + if len(self.classes) > 1: + class_counts = {c: 0 for c in self.classes} + class_count_sum = 0 + for c in self.classes: + class_counts[c] += self.class_counts["totals"][c] + class_count_sum += self.class_counts["totals"][c] + + class_weights = { + c: 1 - (class_counts[c] / class_count_sum) for c in self.classes + } + else: + class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow + return class_weights + # Example input arrays: # {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)}, diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index 43f455f..0446b28 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -1,3 +1,4 @@ +from typing import Callable from torch.utils.data import Dataset @@ -16,3 +17,31 @@ def __len__(self): def to(self, device): self.dataset.to(device) return self + + def set_raw_value_transforms(self, transforms: Callable): + """Sets the raw value transforms for the subset dataset.""" + self.dataset.set_raw_value_transforms(transforms) + + def set_target_value_transforms(self, transforms: Callable): + """Sets the target value transforms for the subset dataset.""" + self.dataset.set_target_value_transforms(transforms) + + def get_class_weights(self): + """ + Returns the class weights for the multi-dataset based on the number of samples in each class. + """ + if len(self.dataset.classes) > 1: + class_counts = {c: 0 for c in self.dataset.classes} + class_count_sum = 0 + for c in self.dataset.classes: + class_counts[c] += self.dataset.class_counts["totals"][c] + class_count_sum += self.dataset.class_counts["totals"][c] + + class_weights = { + c: 1 - (class_counts[c] / class_count_sum) for c in self.dataset.classes + } + else: + class_weights = { + self.dataset.classes[0]: 0.1 + } # less than 1 to avoid overflow + return class_weights