Skip to content

Commit

Permalink
sentinel2nccm datamodule on new branch (#1950)
Browse files Browse the repository at this point in the history
* sentinel2nccm datamodule

* Fixed style errors

* added 2019 to sentinel2, removed 2022 from nccm

* fixed error

* Use matching split size

---------

Co-authored-by: shreya28 <“[email protected]”>
Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
3 people authored Mar 22, 2024
1 parent 5a7b9e5 commit bd48efe
Show file tree
Hide file tree
Showing 98 changed files with 198 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Sentinel
^^^^^^^^

.. autoclass:: Sentinel2CDLDataModule
.. autoclass:: Sentinel2NCCMDataModule

Non-geospatial DataModules
--------------------------
Expand Down
18 changes: 18 additions & 0 deletions tests/conf/sentinel2_nccm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 13
num_classes: 5
num_filters: 1
ignore_index: 4
data:
class_path: Sentinel2NCCMDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
nccm_paths: "tests/data/nccm"
sentinel2_paths: "tests/data/sentinel2"
Binary file modified tests/data/nccm/CDL2017_clip.tif
Binary file not shown.
Binary file modified tests/data/nccm/CDL2018_clip1.tif
Binary file not shown.
Binary file modified tests/data/nccm/CDL2019_clip.tif
Binary file not shown.
13 changes: 3 additions & 10 deletions tests/data/nccm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio.crs import CRS
from rasterio.transform import Affine

SIZE = 32
SIZE = 128

np.random.seed(0)
files = ["CDL2017_clip.tif", "CDL2018_clip1.tif", "CDL2019_clip.tif"]
Expand All @@ -23,15 +23,8 @@ def create_file(path: str, dtype: str):
"driver": "GTiff",
"dtype": dtype,
"count": 1,
"crs": CRS.from_epsg(4326),
"transform": Affine(
8.983152841195208e-05,
0.0,
115.483402043364,
0.0,
-8.983152841195208e-05,
53.531397320113605,
),
"crs": CRS.from_epsg(32616),
"transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0),
"height": SIZE,
"width": SIZE,
"compress": "lzw",
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
50 changes: 50 additions & 0 deletions tests/data/sentinel2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@
"T16TFM_20220412T162841_B12.jp2",
"T16TFM_20220412T162841_B8A.jp2",
"T16TFM_20220412T162841_TCI.jp2",
"T16TFM_20190412T162841_B01.jp2",
"T16TFM_20190412T162841_B02.jp2",
"T16TFM_20190412T162841_B03.jp2",
"T16TFM_20190412T162841_B04.jp2",
"T16TFM_20190412T162841_B05.jp2",
"T16TFM_20190412T162841_B06.jp2",
"T16TFM_20190412T162841_B07.jp2",
"T16TFM_20190412T162841_B08.jp2",
"T16TFM_20190412T162841_B09.jp2",
"T16TFM_20190412T162841_B10.jp2",
"T16TFM_20190412T162841_B11.jp2",
"T16TFM_20190412T162841_B12.jp2",
"T16TFM_20190412T162841_B8A.jp2",
"T16TFM_20190412T162841_TCI.jp2",
]
}
}
Expand All @@ -54,6 +68,13 @@
"T26EMU_20220414T110751_B08_10m.jp2",
"T26EMU_20220414T110751_TCI_10m.jp2",
"T26EMU_20220414T110751_WVP_10m.jp2",
"T26EMU_20190414T110751_AOT_10m.jp2",
"T26EMU_20190414T110751_B02_10m.jp2",
"T26EMU_20190414T110751_B03_10m.jp2",
"T26EMU_20190414T110751_B04_10m.jp2",
"T26EMU_20190414T110751_B08_10m.jp2",
"T26EMU_20190414T110751_TCI_10m.jp2",
"T26EMU_20190414T110751_WVP_10m.jp2",
],
"R20m": [
"T26EMU_20220414T110751_AOT_20m.jp2",
Expand All @@ -70,6 +91,20 @@
"T26EMU_20220414T110751_SCL_20m.jp2",
"T26EMU_20220414T110751_TCI_20m.jp2",
"T26EMU_20220414T110751_WVP_20m.jp2",
"T26EMU_20190414T110751_AOT_20m.jp2",
"T26EMU_20190414T110751_B01_20m.jp2",
"T26EMU_20190414T110751_B02_20m.jp2",
"T26EMU_20190414T110751_B03_20m.jp2",
"T26EMU_20190414T110751_B04_20m.jp2",
"T26EMU_20190414T110751_B05_20m.jp2",
"T26EMU_20190414T110751_B06_20m.jp2",
"T26EMU_20190414T110751_B07_20m.jp2",
"T26EMU_20190414T110751_B11_20m.jp2",
"T26EMU_20190414T110751_B12_20m.jp2",
"T26EMU_20190414T110751_B8A_20m.jp2",
"T26EMU_20190414T110751_SCL_20m.jp2",
"T26EMU_20190414T110751_TCI_20m.jp2",
"T26EMU_20190414T110751_WVP_20m.jp2",
],
"R60m": [
"T26EMU_20220414T110751_AOT_60m.jp2",
Expand All @@ -87,6 +122,21 @@
"T26EMU_20220414T110751_SCL_60m.jp2",
"T26EMU_20220414T110751_TCI_60m.jp2",
"T26EMU_20220414T110751_WVP_60m.jp2",
"T26EMU_20190414T110751_AOT_60m.jp2",
"T26EMU_20190414T110751_B01_60m.jp2",
"T26EMU_20190414T110751_B02_60m.jp2",
"T26EMU_20190414T110751_B03_60m.jp2",
"T26EMU_20190414T110751_B04_60m.jp2",
"T26EMU_20190414T110751_B05_60m.jp2",
"T26EMU_20190414T110751_B06_60m.jp2",
"T26EMU_20190414T110751_B07_60m.jp2",
"T26EMU_20190414T110751_B09_60m.jp2",
"T26EMU_20190414T110751_B11_60m.jp2",
"T26EMU_20190414T110751_B12_60m.jp2",
"T26EMU_20190414T110751_B8A_60m.jp2",
"T26EMU_20190414T110751_SCL_60m.jp2",
"T26EMU_20190414T110751_TCI_60m.jp2",
"T26EMU_20190414T110751_WVP_60m.jp2",
],
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def dataset(self) -> Sentinel2:
return Sentinel2(root, res=res, bands=bands, transforms=transforms)

def test_separate_files(self, dataset: Sentinel2) -> None:
assert dataset.index.count(dataset.index.bounds) == 2
assert dataset.index.count(dataset.index.bounds) == 4

def test_getitem(self, dataset: Sentinel2) -> None:
x = dataset[dataset.bounds]
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s2_all",
"sen12ms_s2_reduced",
"sentinel2_cdl",
"sentinel2_nccm",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule
from .sentinel2_cdl import Sentinel2CDLDataModule
from .sentinel2_nccm import Sentinel2NCCMDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
Expand All @@ -49,6 +50,7 @@
"L8BiomeDataModule",
"NAIPChesapeakeDataModule",
"Sentinel2CDLDataModule",
"Sentinel2NCCMDataModule",
# NonGeoDataset
"BigEarthNetDataModule",
"ChaBuDDataModule",
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/sentinel2_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def setup(self, stage: str) -> None:

(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_grid_cell_assignment(
self.dataset, [0.8, 0.10, 0.10], grid_size=8, generator=generator
self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator
)
)
if stage in ["fit"]:
Expand Down
121 changes: 121 additions & 0 deletions torchgeo/datamodules/sentinel2_nccm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Sentinel-2 and NCCM datamodule."""

from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample
from matplotlib.figure import Figure

from ..datasets import NCCM, Sentinel2, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class Sentinel2NCCMDataModule(GeoDataModule):
"""LightningDataModule implementation for the Sentinel-2 and NCCM dataset.
.. versionadded:: 0.6
"""

def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 64,
length: Optional[int] = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new Sentinel2NCCMDataModule instance.
Args:
batch_size: Size of each mini-batch.
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.NCCM` (prefix keys with ``nccm_``) and
:class:`~torchgeo.datasets.Sentinel2`
(prefix keys with ``sentinel2_``).
"""
# Define prefix for NCCM and Sentinel-2 arguments
nccm_signature = "nccm_"
sentinel2_signature = "sentinel2_"
self.nccm_kwargs = {}
self.sentinel2_kwargs = {}

for key, val in kwargs.items():
# Check if the current key starts with the NCCM prefix
if key.startswith(nccm_signature):
# If so, extract the key-value pair to the NCCM dictionary
self.nccm_kwargs[key[len(nccm_signature) :]] = val
# Check if the current key starts with the Sentinel-2 prefix
elif key.startswith(sentinel2_signature):
# If so, extract the key-value pair to the Sentinel-2 dictionary
self.sentinel2_kwargs[key[len(sentinel2_signature) :]] = val

super().__init__(
NCCM, batch_size, patch_size, length, num_workers, **self.nccm_kwargs
)

self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=["image", "mask"],
extra_args={
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
},
)

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
)

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
self.nccm = NCCM(**self.nccm_kwargs)
self.dataset = self.sentinel2 & self.nccm

generator = torch.Generator().manual_seed(0)

(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_grid_cell_assignment(
self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator
)
)
if stage in ["fit"]:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ["fit", "validate"]:
self.val_sampler = GridGeoSampler(
self.val_dataset, self.patch_size, self.patch_size
)
if stage in ["test"]:
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)

def plot(self, *args: Any, **kwargs: Any) -> Figure:
"""Run NCCM plot method.
Args:
*args: Arguments passed to plot method.
**kwargs: Keyword arguments passed to plot method.
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
"""
return self.nccm.plot(*args, **kwargs)

0 comments on commit bd48efe

Please sign in to comment.