Skip to content

Commit

Permalink
Updated datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
cookie-kyu committed Feb 16, 2024
1 parent 524ada2 commit e8c94b9
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 25 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ So2Sat

.. autoclass:: So2SatDataModule

South America Soybean
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SouthAmericaSoybeanSentinel2DataModule

SpaceNet
^^^^^^^^

Expand Down
Binary file modified tests/data/south_america_soybean/SouthAmericaSoybean.zip
Binary file not shown.
13 changes: 3 additions & 10 deletions tests/data/south_america_soybean/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio.crs import CRS
from rasterio.transform import Affine

SIZE = 32
SIZE = 36


np.random.seed(0)
Expand All @@ -24,15 +24,8 @@ def create_file(path: str, dtype: str):
"driver": "GTiff",
"dtype": dtype,
"count": 1,
"crs": CRS.from_epsg(4326),
"transform": Affine(
0.0002499999999999943131,
0.0,
-82.0005000000000024,
0.0,
-0.0002499999999999943131,
0.0005000000000000,
),
"crs": CRS.from_epsg(32616),
"transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0),
"height": SIZE,
"width": SIZE,
"compress": "lzw",
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s1",
"sen12ms_s2_all",
"sen12ms_s2_reduced",
"south_america_soybean_s2",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .sen12ms import SEN12MSDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .south_america_soybean import SouthAmericaSoybean
from .south_america_soybean import SouthAmericaSoybeanSentinel2DataModule
from .spacenet import SpaceNet1DataModule
from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule
from .ssl4eo_benchmark import SSL4EOLBenchmarkDataModule
Expand Down
31 changes: 17 additions & 14 deletions torchgeo/datamodules/south_america_soybean.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import kornia.augmentation as K
from matplotlib.figure import Figure

from ..datasets import SouthAmericaSoybean, BoundingBox, Sentinel2
from ..datasets import BoundingBox, Sentinel2, SouthAmericaSoybean
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class SouthAmericaSoybeanSentinel2DataModule(GeoDataModule):
"""LightningDataModule implementation for SouthAmericaSoybean and Sentinel2 datasets.
"""LightningDataModule for SouthAmericaSoybean and Sentinel2 datasets.
Uses the train/val/test splits from the dataset.
"""

Expand All @@ -36,25 +36,26 @@ def __init__(
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.SouthAmericaSoybean` (prefix keys with ``south_america_soybean_``) and
:class:`~torchgeo.datasets.SouthAmericaSoybean`
(prefix keys with ``south_america_soybean_``) and
:class:`~torchgeo.datasets.Sentinel2`
(prefix keys with ``sentinel2_``).
"""
self.southamericasoybean_kwargs = {}
self.south_america_soybean_kwargs = {}
self.sentinel2_kwargs = {}
for key, val in kwargs.items():
if key.startswith("south_america_soybean_"):
self.southamericasoybean_kwargs[key[22:]] = val
self.south_america_soybean_kwargs[key[22:]] = val
elif key.startswith("sentinel2_"):
self.sentinel2_kwargs[key[10:]] = val

super().__init__(
SouthAmericaSoybean,
batch_size,
patch_size,
length,
num_workers,
**self.south_america_soybean_kwargs,
batch_size=batch_size,
patch_size=patch_size,
length=length,
num_workers=num_workers,
**kwargs,
)

self.aug = AugmentationSequential(
Expand All @@ -68,7 +69,9 @@ def setup(self, stage: str) -> None:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
self.south_america_soybean = SouthAmericaSoybean(**self.eurocrops_kwargs)
self.south_america_soybean = SouthAmericaSoybean(
**self.south_america_soybean_kwargs
)
self.dataset = self.sentinel2 & self.south_america_soybean

roi = self.dataset.bounds
Expand Down Expand Up @@ -101,9 +104,9 @@ def plot(self, *args: Any, **kwargs: Any) -> Figure:
Args:
*args: Arguments passed to plot method.
**kwargs: Keyword arguments passed to plot method.
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
.. versionadded:: 0.4
.. versionadded:: 0.6
"""

return self.south_america_soybean.plot(*args, **kwargs)

0 comments on commit e8c94b9

Please sign in to comment.