Skip to content

Commit

Permalink
feat: ⚡️ Add get_class_weights to dataset and subdataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 8, 2024
1 parent c2a5be8 commit 1f6674a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ dependencies = [
"torchvision",
"numpy",
"matplotlib",
"fibsem-tools",
"cellmap_schemas",
"fibsem-tools==6.3.2",
# "fibsem-tools",
"cellmap-schemas",
"tensorstore",
"xarray-tensorstore",
"cellpose",
Expand Down
18 changes: 18 additions & 0 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,24 @@ def set_target_value_transforms(self, transforms: Callable):
if isinstance(source, CellMapImage):
source.value_transform = transforms

def get_class_weights(self):
    """
    Returns a weight for each class in this dataset, based on the number of samples in each class.

    With more than one class, the weight of class ``c`` is ``1 - count / total``, where
    ``count`` is ``self.class_counts["totals"][c]`` and ``total`` is the sum of those
    counts over ``self.classes``, so classes with more samples get smaller weights.
    With exactly one class, that class gets the fixed weight ``0.1``.

    Edge cases:
        - No classes: an empty dict is returned.
        - All counts are zero: the counts give nothing to rank the classes by, so every
          class gets the weight it would have if all counts were equal
          (``1 - 1 / number_of_classes``) instead of raising ``ZeroDivisionError``.

    Returns:
        dict mapping each class in ``self.classes`` to its weight.
    """
    classes = list(self.classes)
    if not classes:
        return {}
    if len(classes) == 1:
        # ``class_counts`` is deliberately not read on this path.
        return {classes[0]: 0.1}  # less than 1 to avoid overflow

    totals = self.class_counts["totals"]
    class_counts = dict.fromkeys(classes, 0)
    for c in classes:
        class_counts[c] += totals[c]
    class_count_sum = sum(totals[c] for c in classes)

    if class_count_sum == 0:
        # Nothing has been counted: fall back to the equal-count weights.
        uniform = 1 - 1 / len(class_counts)
        return dict.fromkeys(class_counts, uniform)

    return {c: 1 - (count / class_count_sum) for c, count in class_counts.items()}


# Example input arrays:
# {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)},
Expand Down
29 changes: 29 additions & 0 deletions src/cellmap_data/subdataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Callable
from torch.utils.data import Dataset


Expand All @@ -16,3 +17,31 @@ def __len__(self):
def to(self, device):
    """Forward ``device`` to the wrapped dataset's ``to`` method and return this object."""
    wrapped = self.dataset
    wrapped.to(device)
    return self

def set_raw_value_transforms(self, transforms: Callable):
    """Pass ``transforms`` on to the wrapped dataset's ``set_raw_value_transforms``."""
    wrapped = self.dataset
    wrapped.set_raw_value_transforms(transforms)

def set_target_value_transforms(self, transforms: Callable):
    """Pass ``transforms`` on to the wrapped dataset's ``set_target_value_transforms``."""
    wrapped = self.dataset
    wrapped.set_target_value_transforms(transforms)

def get_class_weights(self):
    """
    Returns a weight for each class of the wrapped dataset, based on the number of samples in each class.

    Classes and counts are read from ``self.dataset.classes`` and
    ``self.dataset.class_counts["totals"]``. With more than one class, the weight of
    class ``c`` is ``1 - count / total``, where ``total`` is the sum of the counts over
    all classes, so classes with more samples get smaller weights.
    With exactly one class, that class gets the fixed weight ``0.1``.

    Edge cases:
        - No classes: an empty dict is returned.
        - All counts are zero: the counts give nothing to rank the classes by, so every
          class gets the weight it would have if all counts were equal
          (``1 - 1 / number_of_classes``) instead of raising ``ZeroDivisionError``.

    Returns:
        dict mapping each class in ``self.dataset.classes`` to its weight.
    """
    dataset = self.dataset
    classes = list(dataset.classes)
    if not classes:
        return {}
    if len(classes) == 1:
        # ``class_counts`` is deliberately not read on this path.
        return {classes[0]: 0.1}  # less than 1 to avoid overflow

    totals = dataset.class_counts["totals"]
    class_counts = dict.fromkeys(classes, 0)
    for c in classes:
        class_counts[c] += totals[c]
    class_count_sum = sum(totals[c] for c in classes)

    if class_count_sum == 0:
        # Nothing has been counted: fall back to the equal-count weights.
        uniform = 1 - 1 / len(class_counts)
        return dict.fromkeys(class_counts, uniform)

    return {c: 1 - (count / class_count_sum) for c, count in class_counts.items()}

0 comments on commit 1f6674a

Please sign in to comment.