Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add South Africa Crop Type Competition Dataset #1840

Merged
merged 40 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e8c1b8b
initial commit
Feb 1, 2024
abdb677
Merge branch 'main' into add-south-africa-crop-type
Feb 2, 2024
0df3cec
add init description
Feb 2, 2024
eb33e95
correct scanning of files in init
Feb 2, 2024
7447d0d
added random data generation
Feb 9, 2024
4973565
remove test directory
Feb 9, 2024
3590b8e
Update metadata
Feb 9, 2024
b123a57
fix formatting
Feb 9, 2024
e8ed0ac
add s1 data and update regex
Feb 15, 2024
eaef967
fix dataset and add tests
Feb 16, 2024
033ccdc
Merge branch 'main' into add-south-africa-crop-type
GeorgeHuber Feb 16, 2024
dd1a49a
formatting
Feb 16, 2024
d2108cc
mypy
Feb 16, 2024
49872f8
add timeseries test and take images from july
Feb 19, 2024
c21ae9a
Merge branch 'main' into add-south-africa-crop-type
GeorgeHuber Feb 19, 2024
27cf49f
fix bug in merging files across dates
Feb 19, 2024
7fa5145
mypy
Feb 20, 2024
990c013
Merge branch 'main' into add-south-africa-crop-type
GeorgeHuber Feb 22, 2024
f281cce
Bump lightly from 1.4.25 to 1.5.0 in /requirements (#1894)
dependabot[bot] Feb 20, 2024
c79d600
Additional Satlas pretrained models (#1884)
piperwolters Feb 21, 2024
e19fa65
Update tests/data/south_africa_crop_type/data.py
GeorgeHuber Feb 28, 2024
a38f9f4
Update tests/data/south_africa_crop_type/data.py
GeorgeHuber Feb 28, 2024
d8d9bb2
Update torchgeo/datasets/south_africa_crop_type.py
GeorgeHuber Feb 28, 2024
1553389
Update torchgeo/datasets/south_africa_crop_type.py
GeorgeHuber Feb 28, 2024
9ae8924
Update torchgeo/datasets/south_africa_crop_type.py
GeorgeHuber Feb 28, 2024
afea568
Merge branch 'add-south-africa-crop-type' of https://github.com/Georg…
Feb 28, 2024
3c940ee
Update torchgeo/datasets/south_africa_crop_type.py
GeorgeHuber Feb 28, 2024
cb3f03c
refactor gettitem and bug fixes
Mar 1, 2024
6525a74
Merge branch 'add-south-africa-crop-type' of https://github.com/Georg…
Mar 1, 2024
c59b0b4
typo
Mar 1, 2024
a560598
data.py style
Mar 1, 2024
dd17b57
Merge branch 'main' into add-south-africa-crop-type
GeorgeHuber Mar 1, 2024
e03b663
Update torchgeo/datasets/south_africa_crop_type.py
GeorgeHuber Mar 5, 2024
a568af3
Update torchgeo/datasets/south_africa_crop_type.py
GeorgeHuber Mar 5, 2024
dec3948
Merge branch 'microsoft:main' into add-south-africa-crop-type
GeorgeHuber Mar 5, 2024
96ce979
added comments for nonstandard functionality
Mar 5, 2024
78d9122
Merge branch 'main' into add-south-africa-crop-type
GeorgeHuber Mar 7, 2024
fcce45b
Update torchgeo/datasets/south_africa_crop_type.py
GeorgeHuber Mar 8, 2024
f33e5c5
small change
Mar 8, 2024
c999e89
add verbose documentation
Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,16 @@ Sentinel
.. autoclass:: Sentinel1
.. autoclass:: Sentinel2

South Africa Crop Type
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SouthAfricaCropType

South America Soybean
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SouthAmericaSoybean


.. _Non-geospatial Datasets:

Non-geospatial Datasets
Expand Down
1 change: 1 addition & 0 deletions docs/api/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,-
`PRISMA`_,Imagery,PRISMA,-,512x512,5--30
`Sentinel`_,Imagery,Sentinel,"CC-BY-SA-3.0-IGO","10,000x10,000",10
`South Africa Crop Type`_,"Imagery, Masks",Sentinel-2,"CC-BY-4.0","256x256",10
`South America Soybean`_,Masks,"Landsat, MODIS",-,-,30
95 changes: 95 additions & 0 deletions tests/data/south_africa_crop_type/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_test_data() -> None:
    """Create the on-disk test data tree for the SouthAfricaCropType dataset.

    Writes single-band GeoTIFFs filled with random ``uint8`` values below
    ``south_africa_crop_type/train`` (relative to the current working
    directory):

    * ``imagery/s1/<field_id>/<date>/<field_id>_<date>_<band>_10m.tif``
    * ``imagery/s2/<field_id>/<date>/<field_id>_<date>_<band>_10m.tif``
    * ``labels/<field_id>.tif``

    The function takes no arguments and returns nothing; no archive is
    created and no checksum is computed. The NumPy global random state is
    seeded with 0, so repeated runs draw the same values.
    """
    root = "south_africa_crop_type"
    dtype = np.uint8
    dtype_max = np.iinfo(dtype).max

    SIZE = 256

    np.random.seed(0)

    s1_bands = ("VH", "VV")
    s2_bands = (
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B8A",
        "B09",
        "B11",
        "B12",
    )

    # Shared raster profile: every file written below is a one-band
    # SIZE x SIZE image with the same CRS and affine transform.
    profile = {
        "dtype": dtype,
        "width": SIZE,
        "height": SIZE,
        "count": 1,
        "crs": CRS.from_epsg(32634),
        "transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0),
    }

    train_imagery_s1_dir = os.path.join(root, "train", "imagery", "s1")
    train_imagery_s2_dir = os.path.join(root, "train", "imagery", "s2")
    train_labels_dir = os.path.join(root, "train", "labels")

    for directory in (train_imagery_s1_dir, train_imagery_s2_dir, train_labels_dir):
        os.makedirs(directory, exist_ok=True)

    train_field_ids = ["12"]

    s1_timestamps = ["2017_04_01", "2017_07_28"]
    s2_timestamps = ["2017_05_04", "2017_07_22"]

    def write_raster(path: str, arr: np.ndarray) -> None:
        """Write ``arr`` as band 1 of a new raster at ``path``."""
        with rasterio.open(path, "w", **profile) as dst:
            dst.write(arr, 1)

    # Sensors are processed S1 first, then S2, and the label last, so the
    # sequence of random draws (and therefore the file contents) is fixed.
    sensors = (
        (train_imagery_s1_dir, s1_timestamps, s1_bands),
        (train_imagery_s2_dir, s2_timestamps, s2_bands),
    )

    for field_id in train_field_ids:
        for imagery_dir, timestamps, bands in sensors:
            for date in timestamps:
                date_dir = os.path.join(imagery_dir, field_id, date)
                os.makedirs(date_dir, exist_ok=True)
                for band in bands:
                    # randint's upper bound is exclusive: values are in
                    # [0, dtype_max).
                    train_arr = np.random.randint(
                        dtype_max, size=(SIZE, SIZE), dtype=dtype
                    )
                    path = os.path.join(
                        date_dir, f"{field_id}_{date}_{band}_10m.tif"
                    )
                    write_raster(path, train_arr)
        label_path = os.path.join(train_labels_dir, f"{field_id}.tif")
        # Label values are drawn from [0, 9).
        label_arr = np.random.randint(9, size=(SIZE, SIZE), dtype=dtype)
        write_raster(label_path, label_arr)


# Generate the test data when this file is executed as a script.
if __name__ == "__main__":
    generate_test_data()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
77 changes: 77 additions & 0 deletions tests/datasets/test_south_africa_crop_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
RGBBandsMissingError,
SouthAfricaCropType,
UnionDataset,
)


class TestSouthAfricaCropType:
    """Tests for the SouthAfricaCropType dataset run against tests/data."""

    @pytest.fixture
    def dataset(self) -> SouthAfricaCropType:
        """Return a dataset over the test data directory with an identity transform."""
        return SouthAfricaCropType(
            paths=os.path.join("tests", "data", "south_africa_crop_type"),
            transforms=nn.Identity(),
        )

    def test_getitem(self, dataset: SouthAfricaCropType) -> None:
        """Indexing by the full bounds yields a dict with crs, image and mask."""
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        for key in ("image", "mask"):
            assert isinstance(sample[key], torch.Tensor)

    def test_and(self, dataset: SouthAfricaCropType) -> None:
        """The ``&`` operator produces an IntersectionDataset."""
        assert isinstance(dataset & dataset, IntersectionDataset)

    def test_or(self, dataset: SouthAfricaCropType) -> None:
        """The ``|`` operator produces a UnionDataset."""
        assert isinstance(dataset | dataset, UnionDataset)

    def test_already_downloaded(self, dataset: SouthAfricaCropType) -> None:
        """Constructing a second dataset over the same paths succeeds."""
        SouthAfricaCropType(paths=dataset.paths)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty directory raises DatasetNotFoundError."""
        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            SouthAfricaCropType(str(tmp_path))

    def test_plot(self, dataset: SouthAfricaCropType) -> None:
        """Plotting a sample with a suptitle runs without error."""
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle="Test")
        plt.close()

    def test_plot_prediction(self, dataset: SouthAfricaCropType) -> None:
        """Plotting also works when the sample carries a ``prediction`` entry."""
        sample = dataset[dataset.bounds]
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")
        plt.close()

    def test_invalid_query(self, dataset: SouthAfricaCropType) -> None:
        """An all-zero bounding box is rejected with an IndexError."""
        empty_box = BoundingBox(0, 0, 0, 0, 0, 0)
        expected = "query: .* not found in index with bounds:"
        with pytest.raises(IndexError, match=expected):
            dataset[empty_box]

    def test_rgb_bands_absent_plot(self, dataset: SouthAfricaCropType) -> None:
        """A band selection of B01, B02, B05 leads to RGBBandsMissingError."""
        expected = "Dataset does not contain some of the RGB bands"
        with pytest.raises(RGBBandsMissingError, match=expected):
            reduced = SouthAfricaCropType(dataset.paths, bands=["B01", "B02", "B05"])
            sample = reduced[reduced.bounds]
            reduced.plot(sample, suptitle="Test")
            plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
from .sentinel import Sentinel, Sentinel1, Sentinel2
from .skippd import SKIPPD
from .so2sat import So2Sat
from .south_africa_crop_type import SouthAfricaCropType
from .south_america_soybean import SouthAmericaSoybean
from .spacenet import (
SpaceNet,
Expand Down Expand Up @@ -188,6 +189,7 @@
"Sentinel",
"Sentinel1",
"Sentinel2",
"SouthAfricaCropType",
"SouthAmericaSoybean",
# NonGeoDataset
"ADVANCE",
Expand Down
Loading
Loading