diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 00c0fb998bd..ab8b4d5236d 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -135,6 +135,11 @@ So2Sat .. autoclass:: So2SatDataModule +South America Soybean +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: SouthAmericaSoybeanSentinel2DataModule + SpaceNet ^^^^^^^^ diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean.zip b/tests/data/south_america_soybean/SouthAmericaSoybean.zip index 5453b89fc25..5390e1914dc 100644 Binary files a/tests/data/south_america_soybean/SouthAmericaSoybean.zip and b/tests/data/south_america_soybean/SouthAmericaSoybean.zip differ diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index fbe7d7b23d1..4d679022788 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -11,7 +11,7 @@ from rasterio.crs import CRS from rasterio.transform import Affine -SIZE = 32 +SIZE = 36 np.random.seed(0) @@ -24,15 +24,8 @@ def create_file(path: str, dtype: str): "driver": "GTiff", "dtype": dtype, "count": 1, - "crs": CRS.from_epsg(4326), - "transform": Affine( - 0.0002499999999999943131, - 0.0, - -82.0005000000000024, - 0.0, - -0.0002499999999999943131, - 0.0005000000000000, - ), + "crs": CRS.from_epsg(32616), + "transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0), "height": SIZE, "width": SIZE, "compress": "lzw", diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index e829f818bd4..e49d17a1024 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -72,6 +72,7 @@ class TestSemanticSegmentationTask: "sen12ms_s1", "sen12ms_s2_all", "sen12ms_s2_reduced", + "south_america_soybean_s2", "spacenet1", "ssl4eo_l_benchmark_cdl", "ssl4eo_l_benchmark_nlcd", diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 6036fa699b2..34880e3bade 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -30,7 +30,7 @@ from .sen12ms import SEN12MSDataModule from .skippd import SKIPPDDataModule from .so2sat import So2SatDataModule -from .south_america_soybean import SouthAmericaSoybean +from .south_america_soybean import SouthAmericaSoybeanSentinel2DataModule from .spacenet import SpaceNet1DataModule from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule from .ssl4eo_benchmark import SSL4EOLBenchmarkDataModule diff --git a/torchgeo/datamodules/south_america_soybean.py b/torchgeo/datamodules/south_america_soybean.py index 9a191bd46ec..c274bf2dd79 100644 --- a/torchgeo/datamodules/south_america_soybean.py +++ b/torchgeo/datamodules/south_america_soybean.py @@ -8,15 +8,15 @@ import kornia.augmentation as K from matplotlib.figure import Figure -from ..datasets import SouthAmericaSoybean, BoundingBox, Sentinel2 +from ..datasets import BoundingBox, Sentinel2, SouthAmericaSoybean from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..transforms import AugmentationSequential from .geo import GeoDataModule class SouthAmericaSoybeanSentinel2DataModule(GeoDataModule): - """LightningDataModule implementation for SouthAmericaSoybean and Sentinel2 datasets. - + """LightningDataModule for SouthAmericaSoybean and Sentinel2 datasets. + Uses the train/val/test splits from the dataset. """ @@ -36,25 +36,26 @@ def __init__( length: Length of each training epoch. num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.SouthAmericaSoybean` (prefix keys with ``south_america_soybean_``) and + :class:`~torchgeo.datasets.SouthAmericaSoybean` + (prefix keys with ``south_america_soybean_``) and :class:`~torchgeo.datasets.Sentinel2` (prefix keys with ``sentinel2_``). """ - self.southamericasoybean_kwargs = {} + self.south_america_soybean_kwargs = {} self.sentinel2_kwargs = {} for key, val in kwargs.items(): if key.startswith("south_america_soybean_"): - self.southamericasoybean_kwargs[key[22:]] = val + self.south_america_soybean_kwargs[key[22:]] = val elif key.startswith("sentinel2_"): self.sentinel2_kwargs[key[10:]] = val super().__init__( SouthAmericaSoybean, - batch_size, - patch_size, - length, - num_workers, - **self.south_america_soybean_kwargs, + batch_size=batch_size, + patch_size=patch_size, + length=length, + num_workers=num_workers, + **kwargs, ) self.aug = AugmentationSequential( @@ -68,7 +69,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ self.sentinel2 = Sentinel2(**self.sentinel2_kwargs) - self.south_america_soybean = SouthAmericaSoybean(**self.eurocrops_kwargs) + self.south_america_soybean = SouthAmericaSoybean( + **self.south_america_soybean_kwargs + ) self.dataset = self.sentinel2 & self.south_america_soybean roi = self.dataset.bounds @@ -101,9 +104,9 @@ def plot(self, *args: Any, **kwargs: Any) -> Figure: Args: *args: Arguments passed to plot method. **kwargs: Keyword arguments passed to plot method. + Returns: A matplotlib Figure with the image, ground truth, and predictions. - .. versionadded:: 0.4 + .. versionadded:: 0.6 """ - return self.south_america_soybean.plot(*args, **kwargs)