Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DL4GAMAlps dataset #2508

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ GID-15

.. autoclass:: GID15

Glacier Mapping Alps
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: GlacierMappingAlps

HySpecNet-11k
^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB
`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM"
`GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB
`Glacier Mapping Alps`_,S,"Sentinel-2","CC-BY-4.0","2,251 or 11,440","2","256x256","10","MSI"
dcodrut marked this conversation as resolved.
Show resolved Hide resolved
`HySpecNet-11k`_,-,EnMAP,CC0-1.0,11k,-,128,30,HSI
`IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB
`Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB
Expand Down
143 changes: 143 additions & 0 deletions tests/data/glacier_mapping_alps/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

# define the patch size
PATCH_SIZE = 16

# create a random generator
rg = np.random.RandomState(42)


def create_dummy_sample(fp: str | Path) -> None:
# create the random S2 bands data; make the last two bands as binary masks
band_data = rg.randint(
low=0, high=10000, dtype=np.int16, size=(15, PATCH_SIZE, PATCH_SIZE)
)
band_data[-2:] = (band_data[-2:] > 5000).astype(np.int16)

data_dict = {
'band_data': {
'dims': ('band', 'y', 'x'),
'data': band_data,
'attrs': {
'long_name': [
'B1',
'B2',
'B3',
'B4',
'B5',
'B6',
'B7',
'B8',
'B8A',
'B9',
'B10',
'B11',
'B12',
'CLOUDLESS_MASK',
'FILL_MASK',
],
'_FillValue': -9999,
},
},
'mask_all_g_id': { # glaciers mask (with -1 for no-glacier and GLACIER_ID for glacier)
'dims': ('y', 'x'),
'data': rg.choice([-1, 8, 9, 30, 35], size=(PATCH_SIZE, PATCH_SIZE)).astype(
np.int32
),
'attrs': {'_FillValue': -1},
},
'mask_debris': {
'dims': ('y', 'x'),
'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) > 0.5).astype(np.int8),
'attrs': {'_FillValue': -1},
},
}

# add the additional variables
for v in [
'dem',
'slope',
'aspect',
'planform_curvature',
'profile_curvature',
'terrain_ruggedness_index',
'dhdt',
'v',
]:
data_dict[v] = {
'dims': ('y', 'x'),
'data': (rg.random((PATCH_SIZE, PATCH_SIZE)) * 100).astype(np.float32),
'attrs': {'_FillValue': -9999},
}

# create the xarray dataset and save it
nc = xr.Dataset.from_dict(data_dict)
nc.to_netcdf(fp)


def create_splits_df(fp: str | Path) -> pd.DataFrame:
# create a dataframe with the splits for the 4 glaciers
splits_df = pd.DataFrame(
{
'entry_id': ['g_0008', 'g_0009', 'g_0030', 'g_0035'],
'split_1': ['fold_train', 'fold_train', 'fold_valid', 'fold_test'],
'split_2': ['fold_train', 'fold_valid', 'fold_train', 'fold_test'],
'split_3': ['fold_train', 'fold_valid', 'fold_test', 'fold_train'],
'split_4': ['fold_test', 'fold_valid', 'fold_train', 'fold_train'],
'split_5': ['fold_test', 'fold_train', 'fold_train', 'fold_valid'],
}
)

splits_df.to_csv(fp_splits, index=False)
print(f'Splits dataframe saved to {fp_splits}')
return splits_df


if __name__ == '__main__':
# prepare the paths
fp_splits = Path('splits.csv')
fp_dir_ds_small = Path('dataset_small')
fp_dir_ds_large = Path('dataset_large')

# cleanup
fp_splits.unlink(missing_ok=True)
fp_dir_ds_small.with_suffix('.tar.gz').unlink(missing_ok=True)
fp_dir_ds_large.with_suffix('.tar.gz').unlink(missing_ok=True)
shutil.rmtree(fp_dir_ds_small, ignore_errors=True)
shutil.rmtree(fp_dir_ds_large, ignore_errors=True)

# create the splits dataframe
split_df = create_splits_df(fp_splits)

# create the two datasets versions (small and large) with 1 and 2 patches per glacier, respectively
for fp_dir, num_patches in zip([fp_dir_ds_small, fp_dir_ds_large], [1, 2]):
for glacier_id in split_df.entry_id:
for i in range(num_patches):
fp = fp_dir / glacier_id / f'{glacier_id}_patch_{i}.nc'
fp.parent.mkdir(parents=True, exist_ok=True)
create_dummy_sample(fp=fp)

# archive the datasets
for fp_dir in [fp_dir_ds_small, fp_dir_ds_large]:
shutil.make_archive(str(fp_dir), 'gztar', fp_dir)

# compute checksums
for fp in [
fp_dir_ds_small.with_suffix('.tar.gz'),
fp_dir_ds_large.with_suffix('.tar.gz'),
fp_splits,
]:
with open(fp, 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'md5 for {fp}: {md5}')
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 5 additions & 0 deletions tests/data/glacier_mapping_alps/splits.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
entry_id,split_1,split_2,split_3,split_4,split_5
g_0008,fold_train,fold_train,fold_train,fold_test,fold_test
g_0009,fold_train,fold_valid,fold_valid,fold_valid,fold_train
g_0030,fold_valid,fold_train,fold_test,fold_train,fold_train
g_0035,fold_test,fold_test,fold_train,fold_train,fold_valid
117 changes: 117 additions & 0 deletions tests/datasets/test_glacier_mapping_alps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, GlacierMappingAlps


class TestGlacierMappingAlps:
@pytest.fixture(
params=zip(
['train', 'val', 'test'],
[1, 3, 5],
['small', 'small', 'large'],
[
GlacierMappingAlps.rgb_bands,
GlacierMappingAlps.rgb_nir_swir_bands,
GlacierMappingAlps.all_bands,
],
[None, ['dem'], GlacierMappingAlps.extra_features_all],
)
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> GlacierMappingAlps:
r_url = Path('tests', 'data', 'glacier_mapping_alps')
download_metadata = {
'dataset_small': {
'url': str(r_url / 'dataset_small.tar.gz'),
'checksum': '35f85360b943caa8661d9fb573b0f0b5',
},
'dataset_large': {
'url': str(r_url / 'dataset_large.tar.gz'),
'checksum': '636be5be35b8bd1e7771e9010503e4bc',
},
'splits_csv': {
'url': str(r_url / 'splits.csv'),
'checksum': '973367465c8ab322d0cf544a345b02f5',
},
}

monkeypatch.setattr(GlacierMappingAlps, 'download_metadata', download_metadata)
root = tmp_path
split, cv_iter, version, bands, extra_features = request.param
transforms = nn.Identity()
return GlacierMappingAlps(
root,
split,
cv_iter,
version,
bands,
extra_features,
transforms,
download=True,
checksum=True,
)

def test_getitem(self, dataset: GlacierMappingAlps) -> None:
x = dataset[0]
assert isinstance(x, dict)

var_names = ['image', 'mask_glacier', 'mask_debris', 'mask_clouds_and_shadows']
if dataset.extra_features:
var_names += list(dataset.extra_features)
for v in var_names:
assert v in x
assert isinstance(x[v], torch.Tensor)

# check if all variables have the same spatial dimensions as the image
assert x['image'].shape[-2:] == x[v].shape[-2:]

# check the first dimension of the image tensor
assert x['image'].shape[0] == len(dataset.bands)

def test_len(self, dataset: GlacierMappingAlps) -> None:
num_glaciers_per_fold = 2 if dataset.split == 'train' else 1
num_patches_per_glacier = 1 if dataset.version == 'small' else 2
assert len(dataset) == num_glaciers_per_fold * num_patches_per_glacier

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
GlacierMappingAlps(tmp_path)

def test_already_downloaded_and_extracted(
self, dataset: GlacierMappingAlps
) -> None:
GlacierMappingAlps(root=dataset.root, download=False, version=dataset.version)

def test_already_downloaded_but_not_yet_extracted(self, tmp_path: Path) -> None:
fp_archive = Path(
'tests', 'data', 'glacier_mapping_alps', 'dataset_small.tar.gz'
)
shutil.copyfile(fp_archive, Path(str(tmp_path), fp_archive.name))
fp_splits = Path('tests', 'data', 'glacier_mapping_alps', 'splits.csv')
shutil.copyfile(fp_splits, Path(str(tmp_path), fp_splits.name))
GlacierMappingAlps(root=str(tmp_path), download=False)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
GlacierMappingAlps(split='foo')

def test_plot(self, dataset: GlacierMappingAlps) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

sample = dataset[0]
sample['prediction'] = torch.clone(sample['mask_glacier'])
dataset.plot(sample, suptitle='Test with prediction')
plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
)
from .geonrw import GeoNRW
from .gid15 import GID15
from .glacier_mapping_alps import GlacierMappingAlps
from .globbiomass import GlobBiomass
from .hyspecnet import HySpecNet11k
from .idtrees import IDTReeS
Expand Down Expand Up @@ -215,6 +216,7 @@
'ForestDamage',
'GeoDataset',
'GeoNRW',
'GlacierMappingAlps',
'GlobBiomass',
'HySpecNet11k',
'IDTReeS',
Expand Down
Loading
Loading