diff --git a/torchgeo/datamodules/substation.py b/torchgeo/datamodules/substation.py index 37c72d6e952..2ff369acd44 100644 --- a/torchgeo/datamodules/substation.py +++ b/torchgeo/datamodules/substation.py @@ -75,7 +75,15 @@ def __init__( self.image_resize = image_resize self.mask_resize = mask_resize self.num_of_timepoints = num_of_timepoints - + self.geo_transforms = geo_transforms if geo_transforms is not None else self._identity + self.color_transforms = color_transforms if color_transforms is not None else self._identity + self.image_resize = image_resize if image_resize is not None else self._identity + self.mask_resize = mask_resize if mask_resize is not None else self._identity + + def _identity(self, x: torch.Tensor) -> torch.Tensor: + """Identity function for default transformations.""" + return x + def setup(self, stage: str) -> None: """Set up datasets. @@ -98,7 +106,6 @@ def setup(self, stage: str) -> None: val_len = int(total_len * self.val_split_pct) test_len = int(total_len * self.test_split_pct) train_len = total_len - val_len - test_len - print(val_len, test_len, train_len) self.train_dataset, self.val_dataset, self.test_dataset = random_split( dataset, [train_len, val_len, test_len], generator @@ -123,12 +130,17 @@ def _apply_transforms(self, dataset: Subset[Any]) -> Subset[Any]: image, mask = sample['image'], sample['mask'] if self.geo_transforms: + print(self.geo_transforms) + if mask.shape.__len__() == 2: + mask = mask.unsqueeze(0) combined = torch.cat((image, mask), 0) combined = self.geo_transforms(combined) image, mask = torch.split(combined, [image.shape[0], mask.shape[0]], 0) + if mask.shape[0] == 1: + mask = mask.squeeze(0) if self.color_transforms: - num_timepoints = image.shape[0] // self.bands + num_timepoints = image.shape[0] // self.bands.__len__() for i in range(num_timepoints): start = i * len(self.bands) end = start + 3