Skip to content

Commit

Permalink
Add AgriFieldNet datamodule (#1873)
Browse files Browse the repository at this point in the history
* add agrifieldnet datamodule

* fix codecov

* extra_args not needed

* Bigger default batch size

* Revert "extra_args not needed"

This reverts commit f690d8b.

* Same split as everyone else

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
yichiac and adamjstewart authored Mar 22, 2024
1 parent bd48efe commit d030044
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ torchgeo.datamodules
Geospatial DataModules
----------------------

AgriFieldNet
^^^^^^^^^^^^

.. autoclass:: AgriFieldNetDataModule

Chesapeake Land Cover
^^^^^^^^^^^^^^^^^^^^^

Expand Down
17 changes: 17 additions & 0 deletions tests/conf/agrifieldnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 12
num_classes: 14
num_filters: 1
ignore_index: 0
data:
class_path: AgriFieldNetDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
paths: "tests/data/agrifieldnet"
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TestSemanticSegmentationTask:
@pytest.mark.parametrize(
"name",
[
"agrifieldnet",
"chabud",
"chesapeake_cvpr_5",
"chesapeake_cvpr_7",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""TorchGeo datamodules."""

from .agrifieldnet import AgriFieldNetDataModule
from .bigearthnet import BigEarthNetDataModule
from .chabud import ChaBuDDataModule
from .chesapeake import ChesapeakeCVPRDataModule
Expand Down Expand Up @@ -45,6 +46,7 @@

__all__ = (
# GeoDataset
"AgriFieldNetDataModule",
"ChesapeakeCVPRDataModule",
"L7IrishDataModule",
"L8BiomeDataModule",
Expand Down
86 changes: 86 additions & 0 deletions torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""AgriFieldNet datamodule."""

from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample

from ..datasets import AgriFieldNet, random_bbox_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class AgriFieldNetDataModule(GeoDataModule):
"""LightningDataModule implementation for the AgriFieldNet dataset.
.. versionadded:: 0.6
"""

def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 256,
length: Optional[int] = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new AgriFieldNetDataModule instance.
Args:
batch_size: Size of each mini-batch.
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.AgriFieldNet`.
"""
super().__init__(
AgriFieldNet,
batch_size=batch_size,
patch_size=patch_size,
length=length,
num_workers=num_workers,
**kwargs,
)

self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=["image", "mask"],
extra_args={
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
},
)

def setup(self, stage: str) -> None:
"""Set up datasets.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
dataset = AgriFieldNet(**self.kwargs)
generator = torch.Generator().manual_seed(0)
(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_bbox_assignment(dataset, [0.8, 0.1, 0.1], generator)
)

if stage in ["fit"]:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ["fit", "validate"]:
self.val_sampler = GridGeoSampler(
self.val_dataset, self.patch_size, self.patch_size
)
if stage in ["test"]:
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)

0 comments on commit d030044

Please sign in to comment.