Skip to content

Commit

Permalink
Add South Africa Crop Type Competition Dataset (#1840)
Browse files Browse the repository at this point in the history
* initial commit

* add init description

* correct scanning of files in init

* added random data generation

* remove test directory

* Update metadata

* fix formatting

* add s1 data and update regex

* fix dataset and add tests

* formatting

* mypy

* add timeseries test and take images from july

* fix bug in merging files across dates

* mypy

* Bump lightly from 1.4.25 to 1.5.0 in /requirements (#1894)

* Additional Satlas pretrained models (#1884)

* Documentation, satellite-specific transform and weights for additional Satlas single-image rgb&multispectral Swin-v2 models. Tests pass.

* Address 3 of comments

* Address comments, fix readmydocs and isort, mypy still unhappy

* update

* Add bands to meta dicts

* Add comment about Satlas S2 RGB using TCI product

* linting

---------

Co-authored-by: Piper Wolters <[email protected]>
Co-authored-by: Piper Wolters <[email protected]>

* Update tests/data/south_africa_crop_type/data.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update tests/data/south_africa_crop_type/data.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update torchgeo/datasets/south_africa_crop_type.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update torchgeo/datasets/south_africa_crop_type.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update torchgeo/datasets/south_africa_crop_type.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update torchgeo/datasets/south_africa_crop_type.py

Co-authored-by: Adam J. Stewart <[email protected]>

* refactor gettitem and bug fixes

* typo

* data.py style

* Update torchgeo/datasets/south_africa_crop_type.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Update torchgeo/datasets/south_africa_crop_type.py

Co-authored-by: Adam J. Stewart <[email protected]>

* added comments for nonstandard functionality

* Update torchgeo/datasets/south_africa_crop_type.py

Co-authored-by: Adam J. Stewart <[email protected]>

* small change

* add verbose documentation

---------

Co-authored-by: georgehuber <“[email protected]”>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Piper Wolters <[email protected]>
Co-authored-by: Piper Wolters <[email protected]>
Co-authored-by: Piper Wolters <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
7 people authored Mar 15, 2024
1 parent f6cd092 commit 1f6e974
Show file tree
Hide file tree
Showing 35 changed files with 457 additions and 1 deletion.
5 changes: 4 additions & 1 deletion docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,16 @@ Sentinel
.. autoclass:: Sentinel1
.. autoclass:: Sentinel2

South Africa Crop Type
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SouthAfricaCropType

South America Soybean
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SouthAmericaSoybean


.. _Non-geospatial Datasets:

Non-geospatial Datasets
Expand Down
1 change: 1 addition & 0 deletions docs/api/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,-
`PRISMA`_,Imagery,PRISMA,-,512x512,5--30
`Sentinel`_,Imagery,Sentinel,"CC-BY-SA-3.0-IGO","10,000x10,000",10
`South Africa Crop Type`_,"Imagery, Masks",Sentinel-2,"CC-BY-4.0","256x256",10
`South America Soybean`_,Masks,"Landsat, MODIS",-,-,30
95 changes: 95 additions & 0 deletions tests/data/south_africa_crop_type/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_test_data() -> str:
"""Create test data archive for SouthAfricaCropType dataset.
Args:
paths: path to store test data
n_samples: number of samples.
Returns:
md5 hash of created archive
"""
paths = "south_africa_crop_type"
dtype = np.uint8
dtype_max = np.iinfo(dtype).max

SIZE = 256

np.random.seed(0)

s1_bands = ("VH", "VV")
s2_bands = (
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B11",
"B12",
)

profile = {
"dtype": dtype,
"width": SIZE,
"height": SIZE,
"count": 1,
"crs": CRS.from_epsg(32634),
"transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0),
}

train_imagery_s1_dir = os.path.join(paths, "train", "imagery", "s1")
train_imagery_s2_dir = os.path.join(paths, "train", "imagery", "s2")
train_labels_dir = os.path.join(paths, "train", "labels")

os.makedirs(train_imagery_s1_dir, exist_ok=True)
os.makedirs(train_imagery_s2_dir, exist_ok=True)
os.makedirs(train_labels_dir, exist_ok=True)

train_field_ids = ["12"]

s1_timestamps = ["2017_04_01", "2017_07_28"]
s2_timestamps = ["2017_05_04", "2017_07_22"]

def write_raster(path: str, arr: np.array) -> None:
with rasterio.open(path, "w", **profile) as src:
src.write(arr, 1)

for field_id in train_field_ids:
for date in s1_timestamps:
s1_dir = os.path.join(train_imagery_s1_dir, field_id, date)
os.makedirs(s1_dir, exist_ok=True)
for band in s1_bands:
train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype)
path = os.path.join(s1_dir, f"{field_id}_{date}_{band}_10m.tif")
write_raster(path, train_arr)
for date in s2_timestamps:
s2_dir = os.path.join(train_imagery_s2_dir, field_id, date)
os.makedirs(s2_dir, exist_ok=True)
for band in s2_bands:
train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype)
path = os.path.join(s2_dir, f"{field_id}_{date}_{band}_10m.tif")
write_raster(path, train_arr)
label_path = os.path.join(train_labels_dir, f"{field_id}.tif")
label_arr = np.random.randint(9, size=(SIZE, SIZE), dtype=dtype)
write_raster(label_path, label_arr)


if __name__ == "__main__":
generate_test_data()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
77 changes: 77 additions & 0 deletions tests/datasets/test_south_africa_crop_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
RGBBandsMissingError,
SouthAfricaCropType,
UnionDataset,
)


class TestSouthAfricaCropType:
@pytest.fixture
def dataset(self) -> SouthAfricaCropType:
path = os.path.join("tests", "data", "south_africa_crop_type")
transforms = nn.Identity()
return SouthAfricaCropType(paths=path, transforms=transforms)

def test_getitem(self, dataset: SouthAfricaCropType) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)

def test_and(self, dataset: SouthAfricaCropType) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: SouthAfricaCropType) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: SouthAfricaCropType) -> None:
SouthAfricaCropType(paths=dataset.paths)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SouthAfricaCropType(str(tmp_path))

def test_plot(self, dataset: SouthAfricaCropType) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: SouthAfricaCropType) -> None:
x = dataset[dataset.bounds]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_invalid_query(self, dataset: SouthAfricaCropType) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]

def test_rgb_bands_absent_plot(self, dataset: SouthAfricaCropType) -> None:
with pytest.raises(
RGBBandsMissingError, match="Dataset does not contain some of the RGB bands"
):
ds = SouthAfricaCropType(dataset.paths, bands=["B01", "B02", "B05"])
x = ds[ds.bounds]
ds.plot(x, suptitle="Test")
plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
from .sentinel import Sentinel, Sentinel1, Sentinel2
from .skippd import SKIPPD
from .so2sat import So2Sat
from .south_africa_crop_type import SouthAfricaCropType
from .south_america_soybean import SouthAmericaSoybean
from .spacenet import (
SpaceNet,
Expand Down Expand Up @@ -188,6 +189,7 @@
"Sentinel",
"Sentinel1",
"Sentinel2",
"SouthAfricaCropType",
"SouthAmericaSoybean",
# NonGeoDataset
"ADVANCE",
Expand Down
Loading

0 comments on commit 1f6e974

Please sign in to comment.