Skip to content

Commit

Permalink
added identity for init values
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuld committed Jan 19, 2025
1 parent 9a050bd commit 8c918a8
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions torchgeo/datamodules/substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,15 @@ def __init__(
self.image_resize = image_resize
self.mask_resize = mask_resize
self.num_of_timepoints = num_of_timepoints

self.geo_transforms = geo_transforms if geo_transforms is not None else self._identity
self.color_transforms = color_transforms if color_transforms is not None else self._identity
self.image_resize = image_resize if image_resize is not None else self._identity
self.mask_resize = mask_resize if mask_resize is not None else self._identity

def _identity(self, x: torch.Tensor) -> torch.Tensor:
"""Identity function for default transformations."""
return x

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand All @@ -98,7 +106,6 @@ def setup(self, stage: str) -> None:
val_len = int(total_len * self.val_split_pct)
test_len = int(total_len * self.test_split_pct)
train_len = total_len - val_len - test_len
print(val_len, test_len, train_len)

self.train_dataset, self.val_dataset, self.test_dataset = random_split(
dataset, [train_len, val_len, test_len], generator
Expand All @@ -123,12 +130,17 @@ def _apply_transforms(self, dataset: Subset[Any]) -> Subset[Any]:
image, mask = sample['image'], sample['mask']

if self.geo_transforms:
print(self.geo_transforms)
if mask.shape.__len__() == 2:
mask = mask.unsqueeze(0)
combined = torch.cat((image, mask), 0)
combined = self.geo_transforms(combined)
image, mask = torch.split(combined, [image.shape[0], mask.shape[0]], 0)
if mask.shape[0] == 1:
mask = mask.squeeze(0)

if self.color_transforms:
num_timepoints = image.shape[0] // self.bands
num_timepoints = image.shape[0] // self.bands.__len__()
for i in range(num_timepoints):
start = i * len(self.bands)
end = start + 3
Expand Down

0 comments on commit 8c918a8

Please sign in to comment.