From f883dd9a4ab17eb433b377abd92b6af6c790d603 Mon Sep 17 00:00:00 2001 From: Yi-Chia Chang Date: Thu, 8 Feb 2024 22:37:35 -0600 Subject: [PATCH 1/6] add agrifieldnet datamodule --- tests/conf/agrifieldnet.yaml | 28 +++++++++ torchgeo/datamodules/__init__.py | 2 + torchgeo/datamodules/agrifieldnet.py | 86 ++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+) create mode 100644 tests/conf/agrifieldnet.yaml create mode 100644 torchgeo/datamodules/agrifieldnet.py diff --git a/tests/conf/agrifieldnet.yaml b/tests/conf/agrifieldnet.yaml new file mode 100644 index 00000000000..603b42d6273 --- /dev/null +++ b/tests/conf/agrifieldnet.yaml @@ -0,0 +1,28 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: "ce" + model: "unet" + backbone: "resnet18" + in_channels: 12 + num_classes: 14 + num_filters: 1 + ignore_index: 0 +data: + class_path: AgriFieldNetDataModule + init_args: + batch_size: 2 + patch_size: 16 + dict_kwargs: + paths: "tests/data/agrifieldnet" + + +data: + class_path: L7IrishDataModule + init_args: + batch_size: 1 + patch_size: 32 + length: 5 + dict_kwargs: + paths: "tests/data/l7irish" + download: true diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 1f05f31b3b4..00e6a23dd08 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -3,6 +3,7 @@ """TorchGeo datamodules.""" +from .agrifieldnet import AgriFieldNetDataModule from .bigearthnet import BigEarthNetDataModule from .chabud import ChaBuDDataModule from .chesapeake import ChesapeakeCVPRDataModule @@ -43,6 +44,7 @@ __all__ = ( # GeoDataset + "AgriFieldNetDataModule", "ChesapeakeCVPRDataModule", "L7IrishDataModule", "L8BiomeDataModule", diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py new file mode 100644 index 00000000000..5e6da6aafe0 --- /dev/null +++ b/torchgeo/datamodules/agrifieldnet.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""AgriFieldNet datamodule.""" + +from typing import Any, Optional, Union + +import kornia.augmentation as K +import torch +from kornia.constants import DataKey, Resample + +from ..datasets import AgriFieldNet, random_bbox_assignment +from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from .geo import GeoDataModule + + +class AgriFieldNetDataModule(GeoDataModule): + """LightningDataModule implementation for the AgriFieldNet dataset. + + .. versionadded:: 0.6 + """ + + def __init__( + self, + batch_size: int = 1, + patch_size: Union[int, tuple[int, int]] = 256, + length: Optional[int] = None, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new AgriFieldNetDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.AgriFieldNet`. + """ + super().__init__( + AgriFieldNet, + batch_size=batch_size, + patch_size=patch_size, + length=length, + num_workers=num_workers, + **kwargs, + ) + + self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), + K.RandomVerticalFlip(p=0.5), + K.RandomHorizontalFlip(p=0.5), + data_keys=["image", "mask"], + extra_args={ + DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + }, + ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + dataset = AgriFieldNet(**self.kwargs) + generator = torch.Generator().manual_seed(0) + (self.train_dataset, self.val_dataset, self.test_dataset) = ( + random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator) + ) + + if stage in ["fit"]: + self.train_batch_sampler = RandomBatchGeoSampler( + self.train_dataset, self.patch_size, self.batch_size, self.length + ) + if stage in ["fit", "validate"]: + self.val_sampler = GridGeoSampler( + self.val_dataset, self.patch_size, self.patch_size + ) + if stage in ["test"]: + self.test_sampler = GridGeoSampler( + self.test_dataset, self.patch_size, self.patch_size + ) From dc0cb8983d6d9b7c2dd5bac8d6769275af2ba614 Mon Sep 17 00:00:00 2001 From: Yi-Chia Chang Date: Mon, 12 Feb 2024 23:54:44 -0500 Subject: [PATCH 2/6] fix codecov --- docs/api/datamodules.rst | 5 +++++ tests/conf/agrifieldnet.yaml | 11 ----------- tests/trainers/test_segmentation.py | 1 + 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 00c0fb998bd..4cc2ff0bd99 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -6,6 +6,11 @@ torchgeo.datamodules Geospatial DataModules ---------------------- +AgriFieldNet +^^^^^^^^^^^^ + +.. autoclass:: AgriFieldNetDataModule + Chesapeake Land Cover ^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/conf/agrifieldnet.yaml b/tests/conf/agrifieldnet.yaml index 603b42d6273..42f2550910a 100644 --- a/tests/conf/agrifieldnet.yaml +++ b/tests/conf/agrifieldnet.yaml @@ -15,14 +15,3 @@ data: patch_size: 16 dict_kwargs: paths: "tests/data/agrifieldnet" - - -data: - class_path: L7IrishDataModule - init_args: - batch_size: 1 - patch_size: 32 - length: 5 - dict_kwargs: - paths: "tests/data/l7irish" - download: true diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index e829f818bd4..6831d94a649 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -55,6 +55,7 @@ class TestSemanticSegmentationTask: @pytest.mark.parametrize( "name", [ + "agrifieldnet", "chabud", "chesapeake_cvpr_5", "chesapeake_cvpr_7", From f690d8b1f86de9e2ba8fdda413c9d8264b169980 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 15 Mar 2024 13:29:04 +0100 Subject: [PATCH 3/6] extra_args not needed --- torchgeo/datamodules/agrifieldnet.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index 5e6da6aafe0..2fd45fb2618 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -7,7 +7,6 @@ import kornia.augmentation as K import torch -from kornia.constants import DataKey, Resample from ..datasets import AgriFieldNet, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler @@ -55,9 +54,6 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=["image", "mask"], - extra_args={ - DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} - }, ) def setup(self, stage: str) -> None: From 7cfc26eb647594399f93373a7dbdf3174be4781f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 15 Mar 2024 13:31:44 +0100 Subject: [PATCH 4/6] Bigger default batch size --- torchgeo/datamodules/agrifieldnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index 2fd45fb2618..787a558ecf4 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -23,7 +23,7 @@ class AgriFieldNetDataModule(GeoDataModule): def __init__( self, - batch_size: int = 1, + batch_size: int = 64, patch_size: Union[int, tuple[int, int]] = 256, length: Optional[int] = None, num_workers: int = 0, From 5c53b4e11f7413c94cbc860cc0836befe87782e6 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 15 Mar 2024 13:39:46 +0100 Subject: [PATCH 5/6] Revert "extra_args not needed" This reverts commit f690d8b1f86de9e2ba8fdda413c9d8264b169980. --- torchgeo/datamodules/agrifieldnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index 787a558ecf4..d1e8db8560b 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -7,6 +7,7 @@ import kornia.augmentation as K import torch +from kornia.constants import DataKey, Resample from ..datasets import AgriFieldNet, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler @@ -54,6 +55,9 @@ def __init__( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), data_keys=["image", "mask"], + extra_args={ + DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + }, ) def setup(self, stage: str) -> None: From d39fb88732428e045b23a1a1b4f5b56cd231e076 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 22 Mar 2024 21:56:20 +0100 Subject: [PATCH 6/6] Same split as everyone else --- torchgeo/datamodules/agrifieldnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index d1e8db8560b..f7b27d25207 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -69,7 +69,7 @@ def setup(self, stage: str) -> None: dataset = AgriFieldNet(**self.kwargs) generator = torch.Generator().manual_seed(0) (self.train_dataset, self.val_dataset, self.test_dataset) = ( - random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator) + random_bbox_assignment(dataset, [0.8, 0.1, 0.1], generator) ) if stage in ["fit"]: