-
Notifications
You must be signed in to change notification settings - Fork 385
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add agrifieldnet datamodule * fix codecov * extra_args not needed * Bigger default batch size * Revert "extra_args not needed" This reverts commit f690d8b. * Same split as everyone else --------- Co-authored-by: Adam J. Stewart <[email protected]>
- Loading branch information
1 parent
bd48efe
commit d030044
Showing
5 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
model: | ||
class_path: SemanticSegmentationTask | ||
init_args: | ||
loss: "ce" | ||
model: "unet" | ||
backbone: "resnet18" | ||
in_channels: 12 | ||
num_classes: 14 | ||
num_filters: 1 | ||
ignore_index: 0 | ||
data: | ||
class_path: AgriFieldNetDataModule | ||
init_args: | ||
batch_size: 2 | ||
patch_size: 16 | ||
dict_kwargs: | ||
paths: "tests/data/agrifieldnet" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""AgriFieldNet datamodule.""" | ||
|
||
from typing import Any, Optional, Union | ||
|
||
import kornia.augmentation as K | ||
import torch | ||
from kornia.constants import DataKey, Resample | ||
|
||
from ..datasets import AgriFieldNet, random_bbox_assignment | ||
from ..samplers import GridGeoSampler, RandomBatchGeoSampler | ||
from ..samplers.utils import _to_tuple | ||
from ..transforms import AugmentationSequential | ||
from .geo import GeoDataModule | ||
|
||
|
||
class AgriFieldNetDataModule(GeoDataModule): | ||
"""LightningDataModule implementation for the AgriFieldNet dataset. | ||
.. versionadded:: 0.6 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
batch_size: int = 64, | ||
patch_size: Union[int, tuple[int, int]] = 256, | ||
length: Optional[int] = None, | ||
num_workers: int = 0, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Initialize a new AgriFieldNetDataModule instance. | ||
Args: | ||
batch_size: Size of each mini-batch. | ||
patch_size: Size of each patch, either ``size`` or ``(height, width)``. | ||
length: Length of each training epoch. | ||
num_workers: Number of workers for parallel data loading. | ||
**kwargs: Additional keyword arguments passed to | ||
:class:`~torchgeo.datasets.AgriFieldNet`. | ||
""" | ||
super().__init__( | ||
AgriFieldNet, | ||
batch_size=batch_size, | ||
patch_size=patch_size, | ||
length=length, | ||
num_workers=num_workers, | ||
**kwargs, | ||
) | ||
|
||
self.train_aug = AugmentationSequential( | ||
K.Normalize(mean=self.mean, std=self.std), | ||
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), | ||
K.RandomVerticalFlip(p=0.5), | ||
K.RandomHorizontalFlip(p=0.5), | ||
data_keys=["image", "mask"], | ||
extra_args={ | ||
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} | ||
}, | ||
) | ||
|
||
def setup(self, stage: str) -> None: | ||
"""Set up datasets. | ||
Args: | ||
stage: Either 'fit', 'validate', 'test', or 'predict'. | ||
""" | ||
dataset = AgriFieldNet(**self.kwargs) | ||
generator = torch.Generator().manual_seed(0) | ||
(self.train_dataset, self.val_dataset, self.test_dataset) = ( | ||
random_bbox_assignment(dataset, [0.8, 0.1, 0.1], generator) | ||
) | ||
|
||
if stage in ["fit"]: | ||
self.train_batch_sampler = RandomBatchGeoSampler( | ||
self.train_dataset, self.patch_size, self.batch_size, self.length | ||
) | ||
if stage in ["fit", "validate"]: | ||
self.val_sampler = GridGeoSampler( | ||
self.val_dataset, self.patch_size, self.patch_size | ||
) | ||
if stage in ["test"]: | ||
self.test_sampler = GridGeoSampler( | ||
self.test_dataset, self.patch_size, self.patch_size | ||
) |