diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 3f9539db034..341be4d4916 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -171,13 +171,16 @@ Sentinel .. autoclass:: Sentinel1 .. autoclass:: Sentinel2 +South Africa Crop Type +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: SouthAfricaCropType South America Soybean ^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: SouthAmericaSoybean - .. _Non-geospatial Datasets: Non-geospatial Datasets diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index a44a82c77ea..bdce565c0cc 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -25,4 +25,5 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,- `PRISMA`_,Imagery,PRISMA,-,512x512,5--30 `Sentinel`_,Imagery,Sentinel,"CC-BY-SA-3.0-IGO","10,000x10,000",10 +`South Africa Crop Type`_,"Imagery, Masks",Sentinel-2,"CC-BY-4.0","256x256",10 `South America Soybean`_,Masks,"Landsat, MODIS",-,-,30
diff --git a/tests/data/south_africa_crop_type/data.py b/tests/data/south_africa_crop_type/data.py
new file mode 100644
index 00000000000..dcb7a1d4b6d
--- /dev/null
+++ b/tests/data/south_africa_crop_type/data.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+import rasterio
+from rasterio.crs import CRS
+from rasterio.transform import Affine
+
+
+def generate_test_data() -> None:
+    """Create test data for the SouthAfricaCropType dataset.
+
+    Writes random single-band uint8 GeoTIFFs for one field ("12") below
+    ``south_africa_crop_type/train`` (relative to the current working
+    directory): Sentinel-1 (VH, VV) and Sentinel-2 imagery for two dates
+    each, plus one label raster with values in [0, 9).
+
+    The NumPy seed is fixed so repeated runs draw the same values.
+    """
+    paths = "south_africa_crop_type"
+    dtype = np.uint8
+    dtype_max = np.iinfo(dtype).max
+
+    SIZE = 256
+
+    np.random.seed(0)
+
+    s1_bands = ("VH", "VV")
+    s2_bands = (
+        "B01",
+        "B02",
+        "B03",
+        "B04",
+        "B05",
+        "B06",
+        "B07",
+        "B08",
+        "B8A",
+        "B09",
+        "B11",
+        "B12",
+    )
+
+    profile = {
+        "dtype": dtype,
+        "width": SIZE,
+        "height": SIZE,
+        "count": 1,
+        "crs": CRS.from_epsg(32634),
+        "transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0),
+    }
+
+    train_imagery_s1_dir = os.path.join(paths, "train", "imagery", "s1")
+    train_imagery_s2_dir = os.path.join(paths, "train", "imagery", "s2")
+    train_labels_dir = os.path.join(paths, "train", "labels")
+
+    os.makedirs(train_imagery_s1_dir, exist_ok=True)
+    os.makedirs(train_imagery_s2_dir, exist_ok=True)
+    os.makedirs(train_labels_dir, exist_ok=True)
+
+    train_field_ids = ["12"]
+
+    s1_timestamps = ["2017_04_01", "2017_07_28"]
+    s2_timestamps = ["2017_05_04", "2017_07_22"]
+
+    def write_raster(path: str, arr: np.ndarray) -> None:
+        with rasterio.open(path, "w", **profile) as src:
+            src.write(arr, 1)
+
+    for field_id in train_field_ids:
+        for date in s1_timestamps:
+            s1_dir = os.path.join(train_imagery_s1_dir, field_id, date)
+            os.makedirs(s1_dir, exist_ok=True)
+            for band in s1_bands:
+                train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype)
+                path = os.path.join(s1_dir, f"{field_id}_{date}_{band}_10m.tif")
+                write_raster(path, train_arr)
+        for date in s2_timestamps:
+            s2_dir = os.path.join(train_imagery_s2_dir, field_id, date)
+            os.makedirs(s2_dir, exist_ok=True)
+            for band in s2_bands:
+                train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype)
+                path = os.path.join(s2_dir, f"{field_id}_{date}_{band}_10m.tif")
+                write_raster(path, train_arr)
+        label_path = os.path.join(train_labels_dir, f"{field_id}.tif")
+        label_arr = np.random.randint(9, size=(SIZE, SIZE), dtype=dtype)
+        write_raster(label_path, label_arr)
+
+
+if __name__ == "__main__":
+    generate_test_data()
diff --git a/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VH_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VH_10m.tif new file mode 100644 index 00000000000..af3f4d93160 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VH_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VV_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VV_10m.tif new file mode 100644 index 00000000000..1dbfe387b2d Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_04_01/12_2017_04_01_VV_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VH_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VH_10m.tif new file mode 100644 index 00000000000..8440a17ef69 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VH_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VV_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VV_10m.tif new file mode 100644 index 00000000000..8fdb00a39eb Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s1/12/2017_07_28/12_2017_07_28_VV_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B01_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B01_10m.tif new file mode 100644 index 00000000000..a30a969216b Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B01_10m.tif differ diff --git
a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B02_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B02_10m.tif new file mode 100644 index 00000000000..a20c6c74117 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B02_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B03_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B03_10m.tif new file mode 100644 index 00000000000..f168dda07f2 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B03_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B04_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B04_10m.tif new file mode 100644 index 00000000000..b67a6a3949b Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B04_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B05_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B05_10m.tif new file mode 100644 index 00000000000..7d91dba5b1e Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B05_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B06_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B06_10m.tif new file mode 100644 index 00000000000..1b928649ff8 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B06_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B07_10m.tif 
b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B07_10m.tif new file mode 100644 index 00000000000..5461a29d150 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B07_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B08_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B08_10m.tif new file mode 100644 index 00000000000..693ba2275a2 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B08_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B09_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B09_10m.tif new file mode 100644 index 00000000000..b355cbaf76e Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B09_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B11_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B11_10m.tif new file mode 100644 index 00000000000..951b389bf58 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B11_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B12_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B12_10m.tif new file mode 100644 index 00000000000..75d5dc006a0 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B12_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B8A_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B8A_10m.tif new file mode 100644 index 
00000000000..d61d70f132a Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_05_04/12_2017_05_04_B8A_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B01_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B01_10m.tif new file mode 100644 index 00000000000..d8a65810b77 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B01_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B02_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B02_10m.tif new file mode 100644 index 00000000000..ba7a8be469c Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B02_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B03_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B03_10m.tif new file mode 100644 index 00000000000..942dfad8bcc Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B03_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B04_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B04_10m.tif new file mode 100644 index 00000000000..1b45397b5a5 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B04_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B05_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B05_10m.tif new file mode 100644 index 00000000000..aafe970de78 Binary files /dev/null and 
b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B05_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B06_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B06_10m.tif new file mode 100644 index 00000000000..ae17f0113a8 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B06_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B07_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B07_10m.tif new file mode 100644 index 00000000000..77fec63591b Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B07_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B08_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B08_10m.tif new file mode 100644 index 00000000000..63b3a18668c Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B08_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B09_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B09_10m.tif new file mode 100644 index 00000000000..8492de826c1 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B09_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B11_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B11_10m.tif new file mode 100644 index 00000000000..e0fb137c96b Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B11_10m.tif differ diff --git 
a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B12_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B12_10m.tif new file mode 100644 index 00000000000..f46dc21c612 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B12_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B8A_10m.tif b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B8A_10m.tif new file mode 100644 index 00000000000..0e9b1630714 Binary files /dev/null and b/tests/data/south_africa_crop_type/train/imagery/s2/12/2017_07_22/12_2017_07_22_B8A_10m.tif differ diff --git a/tests/data/south_africa_crop_type/train/labels/12.tif b/tests/data/south_africa_crop_type/train/labels/12.tif new file mode 100644 index 00000000000..82ee8fdf66b Binary files /dev/null and b/tests/data/south_africa_crop_type/train/labels/12.tif differ diff --git a/tests/datasets/test_south_africa_crop_type.py b/tests/datasets/test_south_africa_crop_type.py new file mode 100644 index 00000000000..7422395f0b5 --- /dev/null +++ b/tests/datasets/test_south_africa_crop_type.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from rasterio.crs import CRS
+
+from torchgeo.datasets import (
+    BoundingBox,
+    DatasetNotFoundError,
+    IntersectionDataset,
+    RGBBandsMissingError,
+    SouthAfricaCropType,
+    UnionDataset,
+)
+
+
+class TestSouthAfricaCropType:
+    @pytest.fixture
+    def dataset(self) -> SouthAfricaCropType:
+        path = os.path.join("tests", "data", "south_africa_crop_type")
+        transforms = nn.Identity()
+        return SouthAfricaCropType(paths=path, transforms=transforms)
+
+    def test_getitem(self, dataset: SouthAfricaCropType) -> None:
+        x = dataset[dataset.bounds]
+        assert isinstance(x, dict)
+        assert isinstance(x["crs"], CRS)
+        assert isinstance(x["image"], torch.Tensor)
+        assert isinstance(x["mask"], torch.Tensor)
+
+    def test_and(self, dataset: SouthAfricaCropType) -> None:
+        ds = dataset & dataset
+        assert isinstance(ds, IntersectionDataset)
+
+    def test_or(self, dataset: SouthAfricaCropType) -> None:
+        ds = dataset | dataset
+        assert isinstance(ds, UnionDataset)
+
+    def test_already_downloaded(self, dataset: SouthAfricaCropType) -> None:
+        SouthAfricaCropType(paths=dataset.paths)
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+            SouthAfricaCropType(str(tmp_path))
+
+    def test_plot(self, dataset: SouthAfricaCropType) -> None:
+        x = dataset[dataset.bounds]
+        dataset.plot(x, suptitle="Test")
+        plt.close()
+
+    def test_plot_prediction(self, dataset: SouthAfricaCropType) -> None:
+        x = dataset[dataset.bounds]
+        x["prediction"] = x["mask"].clone()
+        dataset.plot(x, suptitle="Prediction")
+        plt.close()
+
+    def test_invalid_query(self, dataset: SouthAfricaCropType) -> None:
+        query = BoundingBox(0, 0, 0, 0, 0, 0)
+        with pytest.raises(
+            IndexError, match="query: .* not found in index with bounds:"
+        ):
+            dataset[query]
+
+    def test_rgb_bands_absent_plot(self, dataset: SouthAfricaCropType) -> None:
+        ds = SouthAfricaCropType(dataset.paths, bands=["B01", "B02", "B05"])
+        x = ds[ds.bounds]
+        with pytest.raises(
+            RGBBandsMissingError, match="Dataset does not contain some of the RGB bands"
+        ):
+            ds.plot(x, suptitle="Test")
+        plt.close()
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 71eb3f168dc..5f3f974e2b2 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -99,6 +99,7 @@ from .sentinel import Sentinel, Sentinel1, Sentinel2 from .skippd import SKIPPD from .so2sat import So2Sat +from .south_africa_crop_type import SouthAfricaCropType from .south_america_soybean import SouthAmericaSoybean from .spacenet import ( SpaceNet, @@ -188,6 +189,7 @@ "Sentinel", "Sentinel1", "Sentinel2", + "SouthAfricaCropType", "SouthAmericaSoybean", # NonGeoDataset "ADVANCE",
diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py
new file mode 100644
index 00000000000..8165c239645
--- /dev/null
+++ b/torchgeo/datasets/south_africa_crop_type.py
@@ -0,0 +1,293 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""South Africa Crop Type Competition Dataset."""
+
+import os
+import re
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union, cast
+
+import matplotlib.pyplot as plt
+import torch
+from matplotlib.figure import Figure
+from rasterio.crs import CRS
+from torch import Tensor
+
+from .geo import RasterDataset
+from .utils import BoundingBox, RGBBandsMissingError
+
+
+class SouthAfricaCropType(RasterDataset):
+    """South Africa Crop Type Challenge dataset.
+
+    The `South Africa Crop Type Challenge
+    `__
+    dataset includes satellite imagery from Sentinel-1 and Sentinel-2 and labels for
+    crop type that were collected by aerial and vehicle survey from May 2017 to March
+    2018. Data was provided by the Western Cape Department of Agriculture and is
+    available via the Radiant Earth Foundation. For each field id the dataset contains
+    time series imagery and a single label mask. Since TorchGeo does not yet support
+    timeseries datasets, the first available imagery in July will be returned for each
+    field. Note that the dates for S1 and S2 imagery for a given field are not
+    guaranteed to be the same. Each pixel in the label contains an integer field number
+    and crop type class.
+
+    Dataset format:
+
+    * images are 2-band Sentinel-1 and 12-band Sentinel-2 data with a cloud mask
+    * masks are tiff images with unique values representing the class and field id.
+
+    Dataset classes:
+
+    0. No Data
+    1. Lucerne/Medics
+    2. Planted pastures (perennial)
+    3. Fallow
+    4. Wine grapes
+    5. Weeds
+    6. Small grain grazing
+    7. Wheat
+    8. Canola
+    9. Rooibos
+
+    If you use this dataset in your research, please cite the following dataset:
+
+    * Western Cape Department of Agriculture, Radiant Earth Foundation (2021)
+      "Crop Type Classification Dataset for Western Cape, South Africa",
+      Version 1.0, Radiant MLHub, https://doi.org/10.34911/rdnt.j0co8q
+
+    .. versionadded:: 0.6
+    """
+
+    filename_regex = r"""
+        ^(?P<field_id>[0-9]*)
+        _(?P<date>[0-9]{4}_[0-9]{2}_[0-9]{2})
+        _(?P<band>(B[0-9A-Z]{2} | VH | VV))
+        _10m"""
+    date_format = "%Y_%m_%d"
+    rgb_bands = ["B04", "B03", "B02"]
+    s1_bands = ["VH", "VV"]
+    s2_bands = [
+        "B01",
+        "B02",
+        "B03",
+        "B04",
+        "B05",
+        "B06",
+        "B07",
+        "B08",
+        "B8A",
+        "B09",
+        "B11",
+        "B12",
+    ]
+    all_bands: list[str] = s1_bands + s2_bands
+    cmap = {
+        0: (0, 0, 0, 255),
+        1: (255, 211, 0, 255),
+        2: (255, 37, 37, 255),
+        3: (0, 168, 226, 255),
+        4: (255, 158, 9, 255),
+        5: (37, 111, 0, 255),
+        6: (255, 255, 0, 255),
+        7: (222, 166, 9, 255),
+        8: (111, 166, 0, 255),
+        9: (0, 175, 73, 255),
+    }
+
+    def __init__(
+        self,
+        paths: Union[str, Iterable[str]] = "data",
+        crs: Optional[CRS] = None,
+        classes: list[int] = list(cmap.keys()),
+        bands: list[str] = all_bands,
+        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
+    ) -> None:
+        """Initialize a new South Africa Crop Type dataset instance.
+
+        Args:
+            paths: root directory where the dataset can be found; must be a single
+                string, because :meth:`__getitem__` rebuilds file paths from it
+            crs: coordinate reference system to be used
+            classes: crop type classes to be included; must contain the
+                background class 0
+            bands: the subset of bands to load
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+
+        Raises:
+            AssertionError: If *classes* contains an unknown class or omits class 0.
+            DatasetNotFoundError: If dataset is not found.
+        """
+        assert (
+            set(classes) <= self.cmap.keys()
+        ), f"Only the following classes are valid: {list(self.cmap.keys())}."
+        assert 0 in classes, "Classes must include the background class: 0"
+
+        self.paths = paths
+        self.classes = classes
+        self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
+        self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
+
+        super().__init__(paths=paths, crs=crs, bands=bands, transforms=transforms)
+
+        # NOTE(review): ordinal_map is populated here but nothing in this class
+        # applies it to the mask, while plot() indexes ordinal_cmap with the raw
+        # mask values -- confirm whether __getitem__ should remap the mask.
+        # Map chosen classes to ordinal numbers, all others mapped to background class
+        for v, k in enumerate(self.classes):
+            self.ordinal_map[k] = v
+            self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
+
+    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+        """Retrieve image and mask for the fields intersecting *query*.
+
+        For every field matched by the query, the first July acquisition encountered
+        for each sensor is loaded for the requested bands, plus the field's label.
+
+        Args:
+            query: bounding box used to look up files in the spatial index
+
+        Returns:
+            sample with ``crs``, ``bbox``, ``image`` and ``mask`` entries
+
+        Raises:
+            IndexError: If *query* does not intersect any indexed file.
+        """
+        assert isinstance(self.paths, str)
+
+        # Get all files matching the given query
+        hits = self.index.intersection(tuple(query), objects=True)
+        filepaths = cast(list[str], [hit.object for hit in hits])
+
+        if not filepaths:
+            raise IndexError(
+                f"query: {query} not found in index with bounds: {self.bounds}"
+            )
+
+        data_list: list[Tensor] = []
+        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
+
+        # Loop through matched filepaths and find all unique field ids
+        field_ids: list[str] = []
+        # Store date in July for s1 and s2 we want to use for each sample
+        imagery_dates: dict[str, dict[str, str]] = {}
+
+        for filepath in filepaths:
+            filename = os.path.basename(filepath)
+            match = filename_regex.match(filename)
+            if match:
+                field_id = match.group("field_id")
+                date = match.group("date")
+                band = match.group("band")
+                band_type = "s1" if band in self.s1_bands else "s2"
+                if field_id not in imagery_dates:
+                    field_ids.append(field_id)
+                    imagery_dates[field_id] = {"s1": "", "s2": ""}
+                if (
+                    date.split("_")[1] == "07"
+                    and not imagery_dates[field_id][band_type]
+                ):
+                    imagery_dates[field_id][band_type] = date
+
+        # Create Tensors for each band using stored dates
+        for band in self.bands:
+            band_type = "s1" if band in self.s1_bands else "s2"
+            band_filepaths = []
+            for field_id in field_ids:
+                # NOTE(review): date stays "" when no July acquisition was indexed
+                # for this sensor; the path below is then built from an empty date
+                # -- confirm whether that case should raise a clearer error.
+                date = imagery_dates[field_id][band_type]
+                filepath = os.path.join(
+                    self.paths,
+                    "train",
+                    "imagery",
+                    band_type,
+                    field_id,
+                    date,
+                    f"{field_id}_{date}_{band}_10m.tif",
+                )
+                band_filepaths.append(filepath)
+            data_list.append(self._merge_files(band_filepaths, query))
+        image = torch.cat(data_list)
+
+        # Add labels for each field
+        mask_filepaths: list[str] = []
+        for field_id in field_ids:
+            file_path = os.path.join(
+                self.paths, "train", "labels", f"{field_id}.tif"
+            )
+            mask_filepaths.append(file_path)
+
+        mask = self._merge_files(mask_filepaths, query)
+
+        sample = {
+            "crs": self.crs,
+            "bbox": query,
+            "image": image.float(),
+            "mask": mask.long(),
+        }
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def plot(
+        self,
+        sample: dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+
+        Raises:
+            RGBBandsMissingError: If *bands* does not include all RGB bands.
+        """
+        rgb_indices = []
+        for band in self.rgb_bands:
+            if band in self.bands:
+                rgb_indices.append(self.bands.index(band))
+            else:
+                raise RGBBandsMissingError()
+
+        image = sample["image"][rgb_indices].permute(1, 2, 0)
+        image = (image - image.min()) / (image.max() - image.min())
+
+        mask = sample["mask"].squeeze()
+        ncols = 2
+
+        showing_prediction = "prediction" in sample
+        if showing_prediction:
+            pred = sample["prediction"].squeeze()
+            ncols += 1
+
+        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4))
+        axs[0].imshow(image)
+        axs[0].axis("off")
+        axs[1].imshow(self.ordinal_cmap[mask], interpolation="none")
+        axs[1].axis("off")
+        if show_titles:
+            axs[0].set_title("Image")
+            axs[1].set_title("Mask")
+
+        if showing_prediction:
+            axs[2].imshow(pred)
+            axs[2].axis("off")
+            if show_titles:
+                axs[2].set_title("Prediction")
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+
+        return fig