Skip to content

Commit

Permalink
refactor: 🐛 Remove hidden property declarations
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 13, 2024
1 parent e2fd269 commit 68bbdff
Showing 1 changed file with 55 additions and 48 deletions.
103 changes: 55 additions & 48 deletions src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __repr__(self) -> str:
def class_counts(self):
if (
not hasattr(self, "_class_counts")
or self._class_counts is None
or self._class_counts is None # This should be overkill...
or len(self._class_counts) == 0
):
class_counts = {}
Expand All @@ -60,71 +60,78 @@ def class_weights(self):
"""
Returns the class weights for the multi-dataset based on the number of samples in each class.
"""
if len(self.classes) > 1:
class_counts = {}
class_count_sum = 0
for c in self.classes:
class_counts[c] = self.class_counts[c]["total"]
class_count_sum += class_counts[c]

class_weights = {
c: (
1 - (class_counts[c] / class_count_sum)
if class_counts[c] != class_count_sum
else 0.1
)
for c in self.classes
}
else:
class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow
return class_weights
if not hasattr(self, "_class_weights"):
if len(self.classes) > 1:
class_counts = {}
class_count_sum = 0
for c in self.classes:
class_counts[c] = self.class_counts[c]["total"]
class_count_sum += class_counts[c]

class_weights = {
c: (
1 - (class_counts[c] / class_count_sum)
if class_counts[c] != class_count_sum
else 0.1
)
for c in self.classes
}
else:
class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow
self._class_weights = class_weights
return self._class_weights

@property
def dataset_weights(self):
    """
    Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset.

    A dataset's weight is the sum, over ``self.classes``, of that dataset's
    ``class_counts["totals"][c]`` multiplied by the multi-dataset level
    ``self.class_weights[c]``.

    The mapping is computed on first access and cached on the instance as
    ``_dataset_weights``; later accesses return the same dict object and do
    not pick up subsequent changes to ``self.datasets``.

    Returns:
        dict: dataset object -> weight, in ``self.datasets`` order. The
        dataset objects are used as keys, so they must be hashable.
    """
    if not hasattr(self, "_dataset_weights"):
        # Read the class weights once rather than once per dataset.
        class_weights = self.class_weights
        self._dataset_weights = {
            dataset: np.sum(
                [
                    dataset.class_counts["totals"][c] * class_weights[c]
                    for c in self.classes
                ]
            )
            for dataset in self.datasets
        }
    return self._dataset_weights

@property
def sample_weights(self):
    """
    Returns the weights for each sample in the multi-dataset based on the number of samples in each dataset.

    Every sample takes the weight of the dataset it belongs to: the result
    repeats each dataset's weight ``len(dataset)`` times, in the iteration
    order of ``self.dataset_weights``.

    The list is built on first access and cached on the instance as
    ``_sample_weights``; later accesses return the same list object.

    Returns:
        list: one weight per sample, with length equal to the sum of
        ``len(dataset)`` over the keys of ``self.dataset_weights``.
    """
    if not hasattr(self, "_sample_weights"):
        sample_weights = []
        for dataset, dataset_weight in self.dataset_weights.items():
            sample_weights.extend([dataset_weight] * len(dataset))
        self._sample_weights = sample_weights
    return self._sample_weights

@property
def validation_indices(self) -> Sequence[int]:
    """
    Returns the indices of the validation set for each dataset in the multi-dataset.

    Each dataset's own ``validation_indices`` are shifted by the combined
    length of the datasets that precede it in ``self.datasets``, so the
    returned values index into the concatenation of all datasets.

    A dataset whose ``validation_indices`` access raises ``AttributeError``
    contributes no indices and triggers a ``UserWarning``; its length still
    advances the offset for the datasets that follow.

    The list is built on first access and cached on the instance as
    ``_validation_indices``.

    Returns:
        Sequence[int]: a list of plain Python ints; empty when no dataset
        provides validation indices.
    """
    # Imported locally because the module's import block is outside this block.
    import warnings

    if not hasattr(self, "_validation_indices"):
        validation_indices = []
        index_offset = 0
        for dataset in self.datasets:
            try:
                dataset_indices = dataset.validation_indices
            except AttributeError:
                # Actually emit the warning; constructing UserWarning alone
                # does nothing.
                warnings.warn(
                    f"Unable to get validation indices for dataset {dataset}\n skipping",
                    UserWarning,
                    stacklevel=2,
                )
            else:
                # Apply the offset per dataset, before it is advanced, rather
                # than adding the total length to every index at the end.
                validation_indices.extend(
                    int(i) + index_offset for i in dataset_indices
                )
            index_offset += len(dataset)
        self._validation_indices = validation_indices
    return self._validation_indices

def to(self, device: str):
for dataset in self.datasets:
Expand Down

0 comments on commit 68bbdff

Please sign in to comment.