-
Notifications
You must be signed in to change notification settings - Fork 385
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add South Africa Crop Type Competition Dataset (#1840)
* initial commit * add init description * correct scanning of files in init * added random data generation * remove test directory * Update metadata * fix formatting * add s1 data and update regex * fix dataset and add tests * formatting * mypy * add timeseries test and take images from july * fix bug in merging files across dates * mypy * Bump lightly from 1.4.25 to 1.5.0 in /requirements (#1894) * Additional Satlas pretrained models (#1884) * Documentation, satellite-specific transform and weights for additional Satlas single-image rgb&multispectral Swin-v2 models. Tests pass. * Address 3 of comments * Address comments, fix readmydocs and isort, mypy still unhappy * update * Add bands to meta dicts * Add comment about Satlas S2 RGB using TCI product * linting --------- Co-authored-by: Piper Wolters <[email protected]> Co-authored-by: Piper Wolters <[email protected]> * Update tests/data/south_africa_crop_type/data.py Co-authored-by: Adam J. Stewart <[email protected]> * Update tests/data/south_africa_crop_type/data.py Co-authored-by: Adam J. Stewart <[email protected]> * Update torchgeo/datasets/south_africa_crop_type.py Co-authored-by: Adam J. Stewart <[email protected]> * Update torchgeo/datasets/south_africa_crop_type.py Co-authored-by: Adam J. Stewart <[email protected]> * Update torchgeo/datasets/south_africa_crop_type.py Co-authored-by: Adam J. Stewart <[email protected]> * Update torchgeo/datasets/south_africa_crop_type.py Co-authored-by: Adam J. Stewart <[email protected]> * refactor gettitem and bug fixes * typo * data.py style * Update torchgeo/datasets/south_africa_crop_type.py Co-authored-by: Adam J. Stewart <[email protected]> * Update torchgeo/datasets/south_africa_crop_type.py Co-authored-by: Adam J. Stewart <[email protected]> * added comments for nonstandard functionality * Update torchgeo/datasets/south_africa_crop_type.py Co-authored-by: Adam J. 
Stewart <[email protected]> * small change * add verbose documentation --------- Co-authored-by: georgehuber <“[email protected]”> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Piper Wolters <[email protected]> Co-authored-by: Piper Wolters <[email protected]> Co-authored-by: Piper Wolters <[email protected]> Co-authored-by: Adam J. Stewart <[email protected]>
- Loading branch information
1 parent
f6cd092
commit 1f6e974
Showing
35 changed files
with
457 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
|
||
import numpy as np | ||
import rasterio | ||
from rasterio.crs import CRS | ||
from rasterio.transform import Affine | ||
|
||
|
||
def generate_test_data() -> None:
    """Create the test data tree for the SouthAfricaCropType dataset.

    Writes random single-band ``uint8`` rasters below the relative directory
    ``south_africa_crop_type`` (resolved against the current working
    directory), laid out as::

        train/imagery/s1/<field_id>/<date>/<field_id>_<date>_<band>_10m.tif
        train/imagery/s2/<field_id>/<date>/<field_id>_<date>_<band>_10m.tif
        train/labels/<field_id>.tif

    The NumPy global random generator is seeded with 0 before any values are
    drawn. No archive is created and nothing is returned.
    """
    paths = "south_africa_crop_type"
    dtype = np.uint8
    dtype_max = np.iinfo(dtype).max

    SIZE = 256

    np.random.seed(0)

    s1_bands = ("VH", "VV")
    s2_bands = (
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B8A",
        "B09",
        "B11",
        "B12",
    )

    # Shared raster profile: every file is a single SIZE x SIZE band.
    profile = {
        "dtype": dtype,
        "width": SIZE,
        "height": SIZE,
        "count": 1,
        "crs": CRS.from_epsg(32634),
        "transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0),
    }

    train_imagery_s1_dir = os.path.join(paths, "train", "imagery", "s1")
    train_imagery_s2_dir = os.path.join(paths, "train", "imagery", "s2")
    train_labels_dir = os.path.join(paths, "train", "labels")

    os.makedirs(train_imagery_s1_dir, exist_ok=True)
    os.makedirs(train_imagery_s2_dir, exist_ok=True)
    os.makedirs(train_labels_dir, exist_ok=True)

    train_field_ids = ["12"]

    s1_timestamps = ["2017_04_01", "2017_07_28"]
    s2_timestamps = ["2017_05_04", "2017_07_22"]

    def write_raster(path: str, arr: np.ndarray) -> None:
        """Write ``arr`` as band 1 of a new raster at ``path``."""
        with rasterio.open(path, "w", **profile) as dst:
            dst.write(arr, 1)

    # (imagery directory, acquisition dates, band names) per sensor. The order
    # (s1 first, then s2) fixes the sequence of random draws and therefore the
    # pixel values that end up in each file.
    sensors = (
        (train_imagery_s1_dir, s1_timestamps, s1_bands),
        (train_imagery_s2_dir, s2_timestamps, s2_bands),
    )

    for field_id in train_field_ids:
        for imagery_dir, timestamps, bands in sensors:
            for date in timestamps:
                date_dir = os.path.join(imagery_dir, field_id, date)
                os.makedirs(date_dir, exist_ok=True)
                for band in bands:
                    # randint's upper bound is exclusive: values in [0, dtype_max).
                    train_arr = np.random.randint(
                        dtype_max, size=(SIZE, SIZE), dtype=dtype
                    )
                    path = os.path.join(date_dir, f"{field_id}_{date}_{band}_10m.tif")
                    write_raster(path, train_arr)

        # One label raster per field, with integer values in [0, 9).
        label_path = os.path.join(train_labels_dir, f"{field_id}.tif")
        label_arr = np.random.randint(9, size=(SIZE, SIZE), dtype=dtype)
        write_raster(label_path, label_arr)


if __name__ == "__main__":
    generate_test_data()
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VH_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VV_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VH_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VV_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B01_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B02_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B03_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B04_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B05_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B06_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B07_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B08_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B09_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B11_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B12_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B8A_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B01_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B02_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B03_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B04_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B05_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B06_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B07_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B08_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B09_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B11_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B12_10m.tif
Binary file not shown.
Binary file added
BIN
+64.4 KB
tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B8A_10m.tif
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from rasterio.crs import CRS | ||
|
||
from torchgeo.datasets import ( | ||
BoundingBox, | ||
DatasetNotFoundError, | ||
IntersectionDataset, | ||
RGBBandsMissingError, | ||
SouthAfricaCropType, | ||
UnionDataset, | ||
) | ||
|
||
|
||
class TestSouthAfricaCropType:
    """Tests for the SouthAfricaCropType dataset run against local test data."""

    @pytest.fixture
    def dataset(self) -> SouthAfricaCropType:
        """Build a dataset over ``tests/data/south_africa_crop_type``.

        The path is relative, so it is resolved against the current working
        directory. An identity module is passed as the transform.
        """
        path = os.path.join("tests", "data", "south_africa_crop_type")
        transforms = nn.Identity()
        return SouthAfricaCropType(paths=path, transforms=transforms)

    def test_getitem(self, dataset: SouthAfricaCropType) -> None:
        """Indexing with the full bounds yields a dict with crs, image and mask."""
        x = dataset[dataset.bounds]
        assert isinstance(x, dict)
        assert isinstance(x["crs"], CRS)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)

    def test_and(self, dataset: SouthAfricaCropType) -> None:
        """``&`` on two datasets produces an IntersectionDataset."""
        ds = dataset & dataset
        assert isinstance(ds, IntersectionDataset)

    def test_or(self, dataset: SouthAfricaCropType) -> None:
        """``|`` on two datasets produces a UnionDataset."""
        ds = dataset | dataset
        assert isinstance(ds, UnionDataset)

    def test_already_downloaded(self, dataset: SouthAfricaCropType) -> None:
        """Constructing a second dataset over the same paths does not raise."""
        SouthAfricaCropType(paths=dataset.paths)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty directory raises DatasetNotFoundError."""
        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            SouthAfricaCropType(str(tmp_path))

    def test_plot(self, dataset: SouthAfricaCropType) -> None:
        """Plotting a sample with a suptitle runs without error."""
        x = dataset[dataset.bounds]
        dataset.plot(x, suptitle="Test")
        plt.close()

    def test_plot_prediction(self, dataset: SouthAfricaCropType) -> None:
        """Plotting a sample that also carries a ``prediction`` key runs."""
        x = dataset[dataset.bounds]
        # Reuse the mask as a stand-in prediction.
        x["prediction"] = x["mask"].clone()
        dataset.plot(x, suptitle="Prediction")
        plt.close()

    def test_invalid_query(self, dataset: SouthAfricaCropType) -> None:
        """An all-zero bounding box raises IndexError."""
        query = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[query]

    def test_rgb_bands_absent_plot(self, dataset: SouthAfricaCropType) -> None:
        """Using bands B01, B02 and B05 only raises RGBBandsMissingError."""
        with pytest.raises(
            RGBBandsMissingError, match="Dataset does not contain some of the RGB bands"
        ):
            ds = SouthAfricaCropType(dataset.paths, bands=["B01", "B02", "B05"])
            x = ds[ds.bounds]
            ds.plot(x, suptitle="Test")
        # Kept outside the ``raises`` block: placed after the raising call
        # inside the block, this cleanup could never execute.
        plt.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.