Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Jan 28, 2025
1 parent 4b343f2 commit aff5aaa
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tests/datasets/test_ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_getitem(self, dataset: SSL4EOL) -> None:
assert isinstance(x['image'], torch.Tensor)
assert (
x['image'].size(0)
== dataset.seasons * dataset.metadata[dataset.split]['num_bands']
== dataset.seasons * len(dataset.metadata[dataset.split]['all_bands'])
)

def test_len(self, dataset: SSL4EOL) -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class Landsat7(Landsat):

filename_glob = 'LE07_*_{}.*'

default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7')
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')

wavelengths: ClassVar[dict[str, float]] = {
Expand Down
47 changes: 37 additions & 10 deletions torchgeo/datasets/ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,42 @@ class SSL4EOL(SSL4EO):
"""

class _Metadata(TypedDict):
num_bands: int
all_bands: list[str]
rgb_bands: list[int]

metadata: ClassVar[dict[str, _Metadata]] = {
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
'oli_tirs_toa': {'num_bands': 11, 'rgb_bands': [3, 2, 1]},
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
'tm_toa': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'],
'rgb_bands': [2, 1, 0],
},
'etm_toa': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B6', 'B7', 'B8'],
'rgb_bands': [2, 1, 0],
},
'etm_sr': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B7'],
'rgb_bands': [2, 1, 0],
},
'oli_tirs_toa': {
'all_bands': [
'B1',
'B2',
'B3',
'B4',
'B5',
'B6',
'B7',
'B8',
'B9',
'B10',
'B11',
],
'rgb_bands': [3, 2, 1],
},
'oli_sr': {
'all_bands': ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'],
'rgb_bands': [3, 2, 1],
},
}

url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
Expand Down Expand Up @@ -212,8 +239,8 @@ def __init__(
base = Landsat8

self.wavelengths = []
for band in range(1, self.metadata[split]['num_bands'] + 1):
self.wavelengths.append(base.wavelengths[f'B{band}'])
for band in self.metadata[split]['all_bands']:
self.wavelengths.append(base.wavelengths[band])

self.scenes = sorted(os.listdir(self.subdir))

Expand All @@ -236,7 +263,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
ts = []
wavelengths = []
for subdir in subdirs:
mint, maxt = disambiguate_timestamp(subdir[:-8], Landsat.date_format)
mint, maxt = disambiguate_timestamp(subdir[-8:], Landsat.date_format)
directory = os.path.join(root, subdir)
filename = os.path.join(directory, 'all_bands.tif')
with rasterio.open(filename) as f:
Expand Down Expand Up @@ -338,7 +365,7 @@ def plot(
fig, axes = plt.subplots(
ncols=self.seasons, squeeze=False, figsize=(4 * self.seasons, 4)
)
num_bands = self.metadata[self.split]['num_bands']
num_bands = len(self.metadata[self.split]['all_bands'])
rgb_bands = self.metadata[self.split]['rgb_bands']

for i in range(self.seasons):
Expand Down

0 comments on commit aff5aaa

Please sign in to comment.