Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AgriFieldNet datamodule #1873

Merged
merged 6 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ torchgeo.datamodules
Geospatial DataModules
----------------------

AgriFieldNet
^^^^^^^^^^^^

.. autoclass:: AgriFieldNetDataModule

Chesapeake Land Cover
^^^^^^^^^^^^^^^^^^^^^

Expand Down
17 changes: 17 additions & 0 deletions tests/conf/agrifieldnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 12
num_classes: 14
num_filters: 1
ignore_index: 0
data:
class_path: AgriFieldNetDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
paths: "tests/data/agrifieldnet"
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TestSemanticSegmentationTask:
@pytest.mark.parametrize(
"name",
[
"agrifieldnet",
"chabud",
"chesapeake_cvpr_5",
"chesapeake_cvpr_7",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""TorchGeo datamodules."""

from .agrifieldnet import AgriFieldNetDataModule
from .bigearthnet import BigEarthNetDataModule
from .chabud import ChaBuDDataModule
from .chesapeake import ChesapeakeCVPRDataModule
Expand Down Expand Up @@ -43,6 +44,7 @@

__all__ = (
# GeoDataset
"AgriFieldNetDataModule",
"ChesapeakeCVPRDataModule",
"L7IrishDataModule",
"L8BiomeDataModule",
Expand Down
86 changes: 86 additions & 0 deletions torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""AgriFieldNet datamodule."""

from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample

from ..datasets import AgriFieldNet, random_bbox_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class AgriFieldNetDataModule(GeoDataModule):
"""LightningDataModule implementation for the AgriFieldNet dataset.

.. versionadded:: 0.6
"""

def __init__(
self,
batch_size: int = 1,
patch_size: Union[int, tuple[int, int]] = 256,
length: Optional[int] = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new AgriFieldNetDataModule instance.

Args:
batch_size: Size of each mini-batch.
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.AgriFieldNet`.
"""
super().__init__(
AgriFieldNet,
batch_size=batch_size,
patch_size=patch_size,
length=length,
num_workers=num_workers,
**kwargs,
)

self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=["image", "mask"],
extra_args={
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
},
)

def setup(self, stage: str) -> None:
"""Set up datasets.

Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
dataset = AgriFieldNet(**self.kwargs)
generator = torch.Generator().manual_seed(0)
(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_bbox_assignment(dataset, [0.6, 0.2, 0.2], generator)
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
)

if stage in ["fit"]:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ["fit", "validate"]:
self.val_sampler = GridGeoSampler(
self.val_dataset, self.patch_size, self.patch_size
)
if stage in ["test"]:
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)
Loading