Skip to content

Commit

Permalink
Add Sentinel2_CDL Datamodule (#1889)
Browse files Browse the repository at this point in the history
* cdlsentinel2

* update kwargs

* style

* arg type

* add cov

* kwargs

* update cdl data.py for intersection

* style

* create 2022 cdl for intersection

* test roi method

* style

* test_cdl year update

* intersection

* random_grid_cell_assignment

* add comments and line

* add description

* add doc

* Update SIZE variable in sentinel2/data.py and test stage in datamodules/cdlsentinel2.py

* merge val_aug and test_aug to aug

* rename cdlsentinel2 to sentinel2cdl

* fix isort

* No need to monkeypatch CDL

* Smaller backbone == faster tests

* Sort docs alphabetically

* Smaller Sentinel-2 test files

* Smaller CDL files, don't delete directory

* Sort imports alphabetically

* Fix doc names

* extra_args not needed

* center crop doesn't do anything

* blacken

* Revert "extra_args not needed"

This reverts commit 859f24e.

* Add underscore to filename

* Add plot method

* import Figure

* split 80-10-10

* style

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
yichiac and adamjstewart authored Mar 22, 2024
1 parent fe9ee15 commit 5a7b9e5
Show file tree
Hide file tree
Showing 70 changed files with 160 additions and 14 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ NAIP

.. autoclass:: NAIPChesapeakeDataModule

Sentinel
^^^^^^^^

.. autoclass:: Sentinel2CDLDataModule

Non-geospatial DataModules
--------------------------

Expand Down
18 changes: 18 additions & 0 deletions tests/conf/sentinel2_cdl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 13
num_classes: 134
num_filters: 1
ignore_index: 0
data:
class_path: Sentinel2CDLDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
cdl_paths: "tests/data/cdl"
sentinel2_paths: "tests/data/sentinel2"
Binary file removed tests/data/cdl/2020_30m_cdls.zip
Binary file not shown.
Binary file removed tests/data/cdl/2020_30m_cdls/2020_30m_cdls.tif
Binary file not shown.
Binary file removed tests/data/cdl/2020_30m_cdls/2020_30m_cdls.tif.ovr
Binary file not shown.
Binary file removed tests/data/cdl/2021_30m_cdls.zip
Binary file not shown.
Binary file removed tests/data/cdl/2021_30m_cdls/2021_30m_cdls.tif
Binary file not shown.
Binary file removed tests/data/cdl/2021_30m_cdls/2021_30m_cdls.tif.ovr
Binary file not shown.
Binary file added tests/data/cdl/2022_30m_cdls.zip
Binary file not shown.
Binary file added tests/data/cdl/2022_30m_cdls/2022_30m_cdls.tif
Binary file not shown.
Binary file not shown.
Binary file added tests/data/cdl/2023_30m_cdls.zip
Binary file not shown.
Binary file added tests/data/cdl/2023_30m_cdls/2023_30m_cdls.tif
Binary file not shown.
Binary file not shown.
11 changes: 5 additions & 6 deletions tests/data/cdl/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

import numpy as np
import rasterio
from rasterio import Affine

SIZE = 32
SIZE = 128

np.random.seed(0)
random.seed(0)
Expand All @@ -22,8 +23,8 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
profile["driver"] = "GTiff"
profile["dtype"] = dtype
profile["count"] = num_channels
profile["crs"] = "epsg:4326"
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
profile["crs"] = "epsg:32616"
profile["transform"] = Affine(30, 0.0, 399960.0, 0.0, -30, 4500000.0)
profile["height"] = SIZE
profile["width"] = SIZE
profile["compress"] = "lzw"
Expand All @@ -49,7 +50,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
src.write_colormap(1, cmap)


directories = ["2020_30m_cdls", "2021_30m_cdls"]
directories = ["2023_30m_cdls", "2022_30m_cdls"]
raster_extensions = [".tif", ".tif.ovr"]


Expand Down Expand Up @@ -77,5 +78,3 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
with open(filename, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f"{filename}: {md5}")

shutil.rmtree(dir)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/data/sentinel2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 36
SIZE = 128

np.random.seed(0)

Expand Down
14 changes: 7 additions & 7 deletions tests/datasets/test_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL:
monkeypatch.setattr(torchgeo.datasets.cdl, "download_url", download_url)

md5s = {
2021: "e929beb9c8e59fa1d7b7f82e64edaae1",
2020: "e95c2d40ce0c261ed6ee0bd00b49e4b6",
2023: "3fbd3eecf92b8ce1ae35060ada463c6d",
2022: "826c6fd639d9cdd94a44302fbc5b76c3",
}
monkeypatch.setattr(CDL, "md5s", md5s)
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
Expand All @@ -48,7 +48,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL:
transforms=transforms,
download=True,
checksum=True,
years=[2020, 2021],
years=[2023, 2022],
)

def test_getitem(self, dataset: CDL) -> None:
Expand All @@ -60,7 +60,7 @@ def test_getitem(self, dataset: CDL) -> None:
def test_classes(self) -> None:
root = os.path.join("tests", "data", "cdl")
classes = list(CDL.cmap.keys())[:5]
ds = CDL(root, years=[2021], classes=classes)
ds = CDL(root, years=[2023], classes=classes)
sample = ds[ds.bounds]
mask = sample["mask"]
assert mask.max() < len(classes)
Expand All @@ -75,19 +75,19 @@ def test_or(self, dataset: CDL) -> None:

def test_full_year(self, dataset: CDL) -> None:
bbox = dataset.bounds
time = datetime(2021, 6, 1).timestamp()
time = datetime(2023, 6, 1).timestamp()
query = BoundingBox(bbox.minx, bbox.maxx, bbox.miny, bbox.maxy, time, time)
next(dataset.index.intersection(tuple(query)))

def test_already_extracted(self, dataset: CDL) -> None:
CDL(dataset.paths, years=[2020, 2021])
CDL(dataset.paths, years=[2023, 2022])

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")
root = str(tmp_path)
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
CDL(root, years=[2020, 2021])
CDL(root, years=[2023, 2022])

def test_invalid_year(self, tmp_path: Path) -> None:
with pytest.raises(
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s1",
"sen12ms_s2_all",
"sen12ms_s2_reduced",
"sentinel2_cdl",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .resisc45 import RESISC45DataModule
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule
from .sentinel2_cdl import Sentinel2CDLDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
Expand All @@ -47,6 +48,7 @@
"L7IrishDataModule",
"L8BiomeDataModule",
"NAIPChesapeakeDataModule",
"Sentinel2CDLDataModule",
# NonGeoDataset
"BigEarthNetDataModule",
"ChaBuDDataModule",
Expand Down
121 changes: 121 additions & 0 deletions torchgeo/datamodules/sentinel2_cdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Sentinel-2 and CDL datamodule."""

from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample
from matplotlib.figure import Figure

from ..datasets import CDL, Sentinel2, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class Sentinel2CDLDataModule(GeoDataModule):
    """LightningDataModule pairing Sentinel-2 imagery with CDL masks.

    The two datasets are intersected in :meth:`setup` and the result is split
    into train/val/test subsets by random grid-cell assignment.

    .. versionadded:: 0.6
    """

    def __init__(
        self,
        batch_size: int = 64,
        patch_size: Union[int, tuple[int, int]] = 64,
        length: Optional[int] = None,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new Sentinel2CDLDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.CDL` (prefix keys with ``cdl_``) and
                :class:`~torchgeo.datasets.Sentinel2`
                (prefix keys with ``sentinel2_``). Keys carrying neither
                prefix are not forwarded to either dataset.
        """
        cdl_prefix = "cdl_"
        sentinel2_prefix = "sentinel2_"

        # Route each keyword argument to its dataset by prefix, stripping the
        # prefix so the remainder matches the dataset's own parameter name.
        # The prefixes cannot both match one key, so two independent passes
        # give the same result as a single if/elif pass.
        self.cdl_kwargs = {
            name[len(cdl_prefix) :]: value
            for name, value in kwargs.items()
            if name.startswith(cdl_prefix)
        }
        self.sentinel2_kwargs = {
            name[len(sentinel2_prefix) :]: value
            for name, value in kwargs.items()
            if name.startswith(sentinel2_prefix)
        }

        super().__init__(
            CDL, batch_size, patch_size, length, num_workers, **self.cdl_kwargs
        )

        # Training pipeline: normalize, then random resized crop and flips.
        # Masks are resampled with nearest-neighbour interpolation.
        mask_resampling = {
            DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
        }
        train_transforms = [
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
            K.RandomVerticalFlip(p=0.5),
            K.RandomHorizontalFlip(p=0.5),
        ]
        self.train_aug = AugmentationSequential(
            *train_transforms,
            data_keys=["image", "mask"],
            extra_args=mask_resampling,
        )

        # Normalization-only pipeline (no random transforms).
        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
        )

    def setup(self, stage: str) -> None:
        """Set up datasets and samplers.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
        self.cdl = CDL(**self.cdl_kwargs)
        self.dataset = self.sentinel2 & self.cdl

        # Fixed seed so the grid-cell assignment is the same on every call.
        generator = torch.Generator().manual_seed(0)
        splits = random_grid_cell_assignment(
            self.dataset, [0.8, 0.10, 0.10], grid_size=8, generator=generator
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

        if stage == "fit":
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ("fit", "validate"):
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage == "test":
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        """Run CDL plot method.

        Delegates to the ``plot`` method of the CDL dataset created in
        :meth:`setup`.

        Args:
            *args: Arguments passed to plot method.
            **kwargs: Keyword arguments passed to plot method.

        Returns:
            A matplotlib Figure with the image, ground truth, and predictions.
        """
        figure = self.cdl.plot(*args, **kwargs)
        return figure

0 comments on commit 5a7b9e5

Please sign in to comment.