From 84469004aacbdb1f1b5289848c49ca020f6e2f54 Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Thu, 16 Dec 2021 00:57:36 -0800 Subject: [PATCH 01/14] Add pixel sampling mode --- torchgeo/samplers/batch.py | 7 ++++++- torchgeo/samplers/constants.py | 6 ++++++ torchgeo/samplers/single.py | 7 ++++++- torchgeo/samplers/utils.py | 18 +++++++++++++++--- 4 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 torchgeo/samplers/constants.py diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 0dcd1a85aef..6d12ab6d0d1 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -12,6 +12,7 @@ from torchgeo.datasets.geo import GeoDataset from torchgeo.datasets.utils import BoundingBox +from torchgeo.samplers.constants import SIZE_IN_CRS_UNITS from .utils import _to_tuple, get_random_bounding_box @@ -73,6 +74,7 @@ def __init__( batch_size: int, length: int, roi: Optional[BoundingBox] = None, + sample_mode: int = SIZE_IN_CRS_UNITS, ) -> None: """Initialize a new Sampler instance. @@ -95,6 +97,7 @@ def __init__( self.size = _to_tuple(size) self.batch_size = batch_size self.length = length + self.sample_mode = sample_mode self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) def __iter__(self) -> Iterator[List[BoundingBox]]: @@ -112,7 +115,9 @@ def __iter__(self) -> Iterator[List[BoundingBox]]: batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.sample_mode + ) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py new file mode 100644 index 00000000000..eb0d79ad0fe --- /dev/null +++ b/torchgeo/samplers/constants.py @@ -0,0 +1,6 @@ +from torch.utils.data import Sampler + +Sampler.__module__ = "torch.utils.data" + +SIZE_IN_PIXELS = 0 +SIZE_IN_CRS_UNITS = 1 diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 1804d9a2d84..16adebe2d14 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -13,6 +13,7 @@ from torchgeo.datasets.geo import GeoDataset from torchgeo.datasets.utils import BoundingBox +from .constants import SIZE_IN_CRS_UNITS from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -75,6 +76,7 @@ def __init__( size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, + sample_mode: int = SIZE_IN_CRS_UNITS, ) -> None: """Initialize a new Sampler instance. @@ -95,6 +97,7 @@ def __init__( super().__init__(dataset, roi) self.size = _to_tuple(size) self.length = length + self.sample_mode = sample_mode self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) def __iter__(self) -> Iterator[BoundingBox]: @@ -109,7 +112,9 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.sample_mode + ) yield bounding_box diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index b8aecd85a11..635d234ab42 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -8,6 +8,8 @@ from torchgeo.datasets.utils import BoundingBox +from .constants import SIZE_IN_PIXELS, SIZE_IN_CRS_UNITS + def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: """Convert value to a tuple if it is not already a tuple. @@ -25,7 +27,10 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float + bounds: BoundingBox, + size: Union[Tuple[float, float], float], + res: float, + sample_mode: int, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -39,6 +44,7 @@ def get_random_bounding_box( Args: bounds: the larger bounding box to sample from size: the size of the bounding box to sample + sample_mode: whether to sample in pixel space or CRS unit space Returns: randomly sampled bounding box from the extent of the input @@ -47,11 +53,17 @@ def get_random_bounding_box( width = (bounds.maxx - bounds.minx - t_size[1]) // res minx = random.randrange(int(width)) * res + bounds.minx - maxx = minx + t_size[1] + if sample_mode == SIZE_IN_CRS_UNITS: + maxx = minx + t_size[1] + elif sample_mode == SIZE_IN_PIXELS: + maxx = minx + t_size[1] * res height = (bounds.maxy - bounds.miny - t_size[0]) // res miny = random.randrange(int(height)) * res + bounds.miny - maxy = miny + t_size[0] + if sample_mode == SIZE_IN_CRS_UNITS: + maxy = miny + t_size[0] + elif sample_mode == SIZE_IN_PIXELS: + maxy = miny + t_size[1] * res mint = bounds.mint maxt = bounds.maxt From 980ce750b15394dd4bf1b5b3dbbb9cd3c985cd6a Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Thu, 16 Dec 2021 01:45:48 -0800 Subject: [PATCH 02/14] Fix maxy indexing error Co-authored-by: Ashwin Nair --- torchgeo/samplers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 635d234ab42..9fe958725e3 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -63,7 +63,7 @@ def get_random_bounding_box( if sample_mode == SIZE_IN_CRS_UNITS: maxy = miny + t_size[0] elif sample_mode == SIZE_IN_PIXELS: - maxy = miny + t_size[1] * res + maxy = miny + t_size[0] * res mint = bounds.mint maxt = bounds.maxt From 463572ecdecfd977eec1c185c542d6d2df71be4c Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Tue, 8 Feb 2022 13:34:00 -0800 Subject: [PATCH 03/14] Add sample_mode docstrings, default to PIXELS --- torchgeo/samplers/batch.py | 5 +++-- torchgeo/samplers/constants.py | 4 ---- torchgeo/samplers/single.py | 5 +++-- torchgeo/samplers/utils.py | 9 ++++----- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 6d12ab6d0d1..a7bfad38e4a 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -12,7 +12,7 @@ from torchgeo.datasets.geo import GeoDataset from torchgeo.datasets.utils import BoundingBox -from torchgeo.samplers.constants import SIZE_IN_CRS_UNITS +from torchgeo.samplers.constants import SIZE_IN_CRS_UNITS, SIZE_IN_PIXELS from .utils import _to_tuple, get_random_bounding_box @@ -74,7 +74,7 @@ def __init__( batch_size: int, length: int, roi: Optional[BoundingBox] = None, - sample_mode: int = SIZE_IN_CRS_UNITS, + sample_mode: int = SIZE_IN_PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -92,6 +92,7 @@ def __init__( length: number of samples per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) + sample_mode: defines if `size` is in pixels or in CRS units. """ super().__init__(dataset, roi) self.size = _to_tuple(size) diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py index eb0d79ad0fe..2810e3b63a3 100644 --- a/torchgeo/samplers/constants.py +++ b/torchgeo/samplers/constants.py @@ -1,6 +1,2 @@ -from torch.utils.data import Sampler - -Sampler.__module__ = "torch.utils.data" - SIZE_IN_PIXELS = 0 SIZE_IN_CRS_UNITS = 1 diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 16adebe2d14..1c930c62511 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -13,7 +13,7 @@ from torchgeo.datasets.geo import GeoDataset from torchgeo.datasets.utils import BoundingBox -from .constants import SIZE_IN_CRS_UNITS +from .constants import SIZE_IN_CRS_UNITS, SIZE_IN_PIXELS from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -76,7 +76,7 @@ def __init__( size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, - sample_mode: int = SIZE_IN_CRS_UNITS, + sample_mode: int = SIZE_IN_PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -93,6 +93,7 @@ def __init__( length: number of random samples to draw per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) + sample_mode: defines if `size` is in pixels or in CRS units. """ super().__init__(dataset, roi) self.size = _to_tuple(size) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 635d234ab42..3251eefdb8f 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -44,26 +44,25 @@ def get_random_bounding_box( Args: bounds: the larger bounding box to sample from size: the size of the bounding box to sample - sample_mode: whether to sample in pixel space or CRS unit space + sample_mode: defines if `size` is in pixels or in CRS units. Returns: randomly sampled bounding box from the extent of the input """ t_size: Tuple[float, float] = _to_tuple(size) + if sample_mode == SIZE_IN_PIXELS: + t_size[0] *= res + t_size[1] *= res width = (bounds.maxx - bounds.minx - t_size[1]) // res minx = random.randrange(int(width)) * res + bounds.minx if sample_mode == SIZE_IN_CRS_UNITS: maxx = minx + t_size[1] - elif sample_mode == SIZE_IN_PIXELS: - maxx = minx + t_size[1] * res height = (bounds.maxy - bounds.miny - t_size[0]) // res miny = random.randrange(int(height)) * res + bounds.miny if sample_mode == SIZE_IN_CRS_UNITS: maxy = miny + t_size[0] - elif sample_mode == SIZE_IN_PIXELS: - maxy = miny + t_size[1] * res mint = bounds.mint maxt = bounds.maxt From 12953f7515b198babb9616beb207253db88b2011 Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Tue, 8 Feb 2022 13:42:21 -0800 Subject: [PATCH 04/14] Replace sample_mode with units --- torchgeo/samplers/batch.py | 8 ++++---- torchgeo/samplers/single.py | 8 ++++---- torchgeo/samplers/utils.py | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index a7bfad38e4a..37b83e7ffe8 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -74,7 +74,7 @@ def __init__( batch_size: int, length: int, roi: Optional[BoundingBox] = None, - sample_mode: int = SIZE_IN_PIXELS, + units: int = SIZE_IN_PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -92,13 +92,13 @@ def __init__( length: number of samples per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) - sample_mode: defines if `size` is in pixels or in CRS units. + units: defines if `size` is in pixels or in CRS units. """ super().__init__(dataset, roi) self.size = _to_tuple(size) self.batch_size = batch_size self.length = length - self.sample_mode = sample_mode + self.units = units self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) def __iter__(self) -> Iterator[List[BoundingBox]]: @@ -117,7 +117,7 @@ def __iter__(self) -> Iterator[List[BoundingBox]]: for _ in range(self.batch_size): bounding_box = get_random_bounding_box( - bounds, self.size, self.res, self.sample_mode + bounds, self.size, self.res, self.units ) batch.append(bounding_box) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 1c930c62511..97ad89e772b 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -76,7 +76,7 @@ def __init__( size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, - sample_mode: int = SIZE_IN_PIXELS, + units: int = SIZE_IN_PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -93,12 +93,12 @@ def __init__( length: number of random samples to draw per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) - sample_mode: defines if `size` is in pixels or in CRS units. + units: defines if `size` is in pixels or in CRS units. """ super().__init__(dataset, roi) self.size = _to_tuple(size) self.length = length - self.sample_mode = sample_mode + self.units = units self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) def __iter__(self) -> Iterator[BoundingBox]: @@ -114,7 +114,7 @@ def __iter__(self) -> Iterator[BoundingBox]: # Choose a random index within that tile bounding_box = get_random_bounding_box( - bounds, self.size, self.res, self.sample_mode + bounds, self.size, self.res, self.units ) yield bounding_box diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 3251eefdb8f..f2fa30c9383 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -30,7 +30,7 @@ def get_random_bounding_box( bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float, - sample_mode: int, + units: int, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -44,24 +44,24 @@ def get_random_bounding_box( Args: bounds: the larger bounding box to sample from size: the size of the bounding box to sample - sample_mode: defines if `size` is in pixels or in CRS units. + units: defines if `size` is in pixels or in CRS units. Returns: randomly sampled bounding box from the extent of the input """ t_size: Tuple[float, float] = _to_tuple(size) - if sample_mode == SIZE_IN_PIXELS: + if units == SIZE_IN_PIXELS: t_size[0] *= res t_size[1] *= res width = (bounds.maxx - bounds.minx - t_size[1]) // res minx = random.randrange(int(width)) * res + bounds.minx - if sample_mode == SIZE_IN_CRS_UNITS: + if units == SIZE_IN_CRS_UNITS: maxx = minx + t_size[1] height = (bounds.maxy - bounds.miny - t_size[0]) // res miny = random.randrange(int(height)) * res + bounds.miny - if sample_mode == SIZE_IN_CRS_UNITS: + if units == SIZE_IN_CRS_UNITS: maxy = miny + t_size[0] mint = bounds.mint From 2131c68cb0f9460be3dcb80a61aabf482ea3b2e5 Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Wed, 9 Feb 2022 20:30:50 -0800 Subject: [PATCH 05/14] Update to use enum --- torchgeo/samplers/batch.py | 4 ++-- torchgeo/samplers/constants.py | 10 ++++++++-- torchgeo/samplers/single.py | 4 ++-- torchgeo/samplers/utils.py | 15 +++++---------- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 37b83e7ffe8..8fa2b8c0ae9 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -12,7 +12,7 @@ from torchgeo.datasets.geo import GeoDataset from torchgeo.datasets.utils import BoundingBox -from torchgeo.samplers.constants import SIZE_IN_CRS_UNITS, SIZE_IN_PIXELS +from torchgeo.samplers.constants import Units from .utils import _to_tuple, get_random_bounding_box @@ -74,7 +74,7 @@ def __init__( batch_size: int, length: int, roi: Optional[BoundingBox] = None, - units: int = SIZE_IN_PIXELS, + units: int = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py index 2810e3b63a3..498eb60d372 100644 --- a/torchgeo/samplers/constants.py +++ b/torchgeo/samplers/constants.py @@ -1,2 +1,8 @@ -SIZE_IN_PIXELS = 0 -SIZE_IN_CRS_UNITS = 1 +from enum import Enum + + +class Units(Enum): + """Enumeration to define units of `size` used for GeoSampler""" + + PIXELS = 0 + CRS = 1 diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 97ad89e772b..aee32e90572 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -12,8 +12,8 @@ from torchgeo.datasets.geo import GeoDataset from torchgeo.datasets.utils import BoundingBox +from torchgeo.samplers.constants import Units -from .constants import SIZE_IN_CRS_UNITS, SIZE_IN_PIXELS from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -76,7 +76,7 @@ def __init__( size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, - units: int = SIZE_IN_PIXELS, + units: int = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index f2fa30c9383..016808bb09c 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -8,7 +8,7 @@ from torchgeo.datasets.utils import BoundingBox -from .constants import SIZE_IN_PIXELS, SIZE_IN_CRS_UNITS +from .constants import Units def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: @@ -27,10 +27,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, - size: Union[Tuple[float, float], float], - res: float, - units: int, + bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float, units: int ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -50,19 +47,17 @@ def get_random_bounding_box( randomly sampled bounding box from the extent of the input """ t_size: Tuple[float, float] = _to_tuple(size) - if units == SIZE_IN_PIXELS: + if units == Units.PIXELS: t_size[0] *= res t_size[1] *= res width = (bounds.maxx - bounds.minx - t_size[1]) // res minx = random.randrange(int(width)) * res + bounds.minx - if units == SIZE_IN_CRS_UNITS: - maxx = minx + t_size[1] + maxx = minx + t_size[1] height = (bounds.maxy - bounds.miny - t_size[0]) // res miny = random.randrange(int(height)) * res + bounds.miny - if units == SIZE_IN_CRS_UNITS: - maxy = miny + t_size[0] + maxy = miny + t_size[0] mint = bounds.mint maxt = bounds.maxt From 784a0ff4e8c927e4d6e10aaa904f293724554883 Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Wed, 9 Feb 2022 20:48:29 -0800 Subject: [PATCH 06/14] Fix mypy, tuple, and flake8 issues --- torchgeo/samplers/batch.py | 3 +-- torchgeo/samplers/single.py | 2 +- torchgeo/samplers/utils.py | 9 ++++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 1325dcd11e1..a3a055bc55f 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -14,7 +14,6 @@ from torchgeo.datasets.utils import BoundingBox from torchgeo.samplers.constants import Units -from ..datasets import BoundingBox, GeoDataset from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -75,7 +74,7 @@ def __init__( batch_size: int, length: int, roi: Optional[BoundingBox] = None, - units: int = Units.PIXELS, + units: Union[Units, int] = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index a0c2469db57..6815351a718 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -76,7 +76,7 @@ def __init__( size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, - units: int = Units.PIXELS, + units: Union[Units, int] = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 49b159695b2..d327e286f15 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -27,7 +27,10 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float, units: int + bounds: BoundingBox, + size: Union[Tuple[float, float], float], + res: float, + units: Union[Units, int], ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -48,8 +51,8 @@ def get_random_bounding_box( """ t_size: Tuple[float, float] = _to_tuple(size) if units == Units.PIXELS: - t_size[0] *= res - t_size[1] *= res + # We have to re-assign t_size because tuples are immutable + t_size = (t_size[0] * res, t_size[1] * res) width = (bounds.maxx - bounds.minx - t_size[1]) // res minx = random.randrange(int(width)) * res + bounds.minx From 0d9b457086790944b76cdea63ccd7ca18350dbdf Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Wed, 9 Feb 2022 20:55:00 -0800 Subject: [PATCH 07/14] Fix isort and pydocstyle problems --- torchgeo/samplers/constants.py | 7 ++++++- torchgeo/samplers/utils.py | 5 ++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py index 498eb60d372..c47400d3bcd 100644 --- a/torchgeo/samplers/constants.py +++ b/torchgeo/samplers/constants.py @@ -1,8 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Common sampler constants.""" + from enum import Enum class Units(Enum): - """Enumeration to define units of `size` used for GeoSampler""" + """Enumeration to define units of `size` used for GeoSampler.""" PIXELS = 0 CRS = 1 diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index d327e286f15..96a606aa250 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -6,9 +6,8 @@ import random from typing import Tuple, Union -from ..datasets import BoundingBox - -from .constants import Units +from torchgeo.datasets.utils import BoundingBox +from torchgeo.samplers.constants import Units def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: From ab0ebc9fd67a2099d43fe766def3d09790a41106 Mon Sep 17 00:00:00 2001 From: Ritwik Gupta Date: Tue, 15 Feb 2022 22:26:40 -0800 Subject: [PATCH 08/14] Update sampler docs to discuss unit sampling mode --- docs/api/samplers.rst | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/api/samplers.rst b/docs/api/samplers.rst index fa1ca65227e..29afd2e230e 100644 --- a/docs/api/samplers.rst +++ b/docs/api/samplers.rst @@ -3,6 +3,14 @@ torchgeo.samplers .. module:: torchgeo.samplers +Constants and Enums +-------- + +When sampling imagery, the Samplers below can sample tiles in one of two units: 1) the units specified by the CRS of the rasters, or 2) pixels. + +The Units enum within :module:`~torchgeo.samplers.constants` specifies these two sampling modes. By default, the Samplers use Units.PIXELS. + + Samplers -------- @@ -14,9 +22,10 @@ Samplers are used to index a dataset, retrieving a single query at a time. For : from torchgeo.datasets import Landsat from torchgeo.samplers import RandomGeoSampler + from torchgeo.samplers.constants import Units dataset = Landsat(...) - sampler = RandomGeoSampler(dataset, size=1000, length=100) + sampler = RandomGeoSampler(dataset, size=1000, length=100, units=Units.PIXELS) dataloader = DataLoader(dataset, sampler=sampler) @@ -41,9 +50,10 @@ When working with large tile-based datasets, randomly sampling patches from each from torchgeo.datasets import Landsat from torchgeo.samplers import RandomBatchGeoSampler + from torchgeo.samplers.constants import Units dataset = Landsat(...) - sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100) + sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100, units=Units.PIXELS) dataloader = DataLoader(dataset, batch_sampler=sampler) From a1b994cd94254b726ee5d1a3f21935faea52e2c4 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 22 Feb 2022 15:54:01 -0600 Subject: [PATCH 09/14] Various fixes --- docs/api/samplers.rst | 39 ++++++++++++++++++++++++----------- torchgeo/samplers/__init__.py | 3 +++ torchgeo/samplers/batch.py | 12 +++++------ torchgeo/samplers/single.py | 12 +++++------ torchgeo/samplers/utils.py | 6 +++--- 5 files changed, 43 insertions(+), 29 deletions(-) diff --git a/docs/api/samplers.rst b/docs/api/samplers.rst index 29afd2e230e..be7fbefb910 100644 --- a/docs/api/samplers.rst +++ b/docs/api/samplers.rst @@ -3,14 +3,6 @@ torchgeo.samplers .. module:: torchgeo.samplers -Constants and Enums --------- - -When sampling imagery, the Samplers below can sample tiles in one of two units: 1) the units specified by the CRS of the rasters, or 2) pixels. - -The Units enum within :module:`~torchgeo.samplers.constants` specifies these two sampling modes. By default, the Samplers use Units.PIXELS. - - Samplers -------- @@ -22,13 +14,14 @@ Samplers are used to index a dataset, retrieving a single query at a time. For : from torchgeo.datasets import Landsat from torchgeo.samplers import RandomGeoSampler - from torchgeo.samplers.constants import Units dataset = Landsat(...) - sampler = RandomGeoSampler(dataset, size=1000, length=100, units=Units.PIXELS) + sampler = RandomGeoSampler(dataset, size=256, length=10000) dataloader = DataLoader(dataset, sampler=sampler) +This data loader will return 256x256 px images, and has an epoch length of 10,000. + Random Geo Sampler ^^^^^^^^^^^^^^^^^^ @@ -50,13 +43,14 @@ When working with large tile-based datasets, randomly sampling patches from each from torchgeo.datasets import Landsat from torchgeo.samplers import RandomBatchGeoSampler - from torchgeo.samplers.constants import Units dataset = Landsat(...) - sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100, units=Units.PIXELS) + sampler = RandomBatchGeoSampler(dataset, size=256, batch_size=128, length=10000) dataloader = DataLoader(dataset, batch_sampler=sampler) +This data loader will return 256x256 px images, and has a batch size of 128 and an epoch length of 10,000. + Random Batch Geo Sampler ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -76,3 +70,24 @@ Batch Geo Sampler ^^^^^^^^^^^^^^^^^ .. autoclass:: BatchGeoSampler + +Units +----- + +By default, the ``size`` parameter specifies the size of the image in *pixel* units. If you would instead like to specify the size in *CRS* units, you can change the ``units`` parameter like so: + +.. code-block:: python + + from torch.utils.data import DataLoader + + from torchgeo.datasets import Landsat + from torchgeo.samplers import RandomGeoSampler, Units + + dataset = Landsat(...) + sampler = RandomGeoSampler(dataset, size=256 * 30, length=10000, units=Units.CRS) + dataloader = DataLoader(dataset, sampler=sampler) + + +Assuming that each pixel in the CRS is 30 m, this data loader will return 256x256 px images, and has an epoch length of 10,000. + +.. autoclass:: Units diff --git a/torchgeo/samplers/__init__.py b/torchgeo/samplers/__init__.py index da02f09d802..a6f63de1917 100644 --- a/torchgeo/samplers/__init__.py +++ b/torchgeo/samplers/__init__.py @@ -4,6 +4,7 @@ """TorchGeo samplers.""" from .batch import BatchGeoSampler, RandomBatchGeoSampler +from .constants import Units from .single import GeoSampler, GridGeoSampler, RandomGeoSampler __all__ = ( @@ -15,6 +16,8 @@ # Base classes "GeoSampler", "BatchGeoSampler", + # Constants + "Units", ) # https://stackoverflow.com/questions/40018681 diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index a3a055bc55f..0dbc1f9a305 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -10,10 +10,8 @@ from rtree.index import Index, Property from torch.utils.data import Sampler -from torchgeo.datasets.geo import GeoDataset -from torchgeo.datasets.utils import BoundingBox -from torchgeo.samplers.constants import Units - +from ..datasets import BoundingBox, GeoDataset +from .constants import Units from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -74,7 +72,7 @@ def __init__( batch_size: int, length: int, roi: Optional[BoundingBox] = None, - units: Union[Units, int] = Units.PIXELS, + units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -87,12 +85,12 @@ def __init__( Args: dataset: dataset to index from - size: dimensions of each :term:`patch` in units of CRS + size: dimensions of each :term:`patch` batch_size: number of samples per batch length: number of samples per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) - units: defines if `size` is in pixels or in CRS units. + units: defines if ``size`` is in pixel or CRS units """ super().__init__(dataset, roi) self.size = _to_tuple(size) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 6815351a718..34e3837d47c 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -10,10 +10,8 @@ from rtree.index import Index, Property from torch.utils.data import Sampler -from torchgeo.datasets.geo import GeoDataset -from torchgeo.datasets.utils import BoundingBox -from torchgeo.samplers.constants import Units - +from ..datasets import BoundingBox, GeoDataset +from .constants import Units from .utils import _to_tuple, get_random_bounding_box # https://github.com/pytorch/pytorch/issues/60979 @@ -76,7 +74,7 @@ def __init__( size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, - units: Union[Units, int] = Units.PIXELS, + units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -89,11 +87,11 @@ def __init__( Args: dataset: dataset to index from - size: dimensions of each :term:`patch` in units of CRS + size: dimensions of each :term:`patch` length: number of random samples to draw per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) - units: defines if `size` is in pixels or in CRS units. + units: defines if ``size`` is in pixel or CRS units """ super().__init__(dataset, roi) self.size = _to_tuple(size) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 96a606aa250..6f06d59f6f1 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -6,8 +6,8 @@ import random from typing import Tuple, Union -from torchgeo.datasets.utils import BoundingBox -from torchgeo.samplers.constants import Units +from ..datasets import BoundingBox +from .constants import Units def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: @@ -29,7 +29,7 @@ def get_random_bounding_box( bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float, - units: Union[Units, int], + units: Units, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. From 518bd1ea2990b5e8ae143b95f48cb2fb0177cf85 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 22 Feb 2022 16:17:03 -0600 Subject: [PATCH 10/14] Add units arg to GridGeoSampler --- torchgeo/samplers/batch.py | 12 ++++++++---- torchgeo/samplers/constants.py | 12 ++++++++---- torchgeo/samplers/single.py | 24 +++++++++++++++++++----- torchgeo/samplers/utils.py | 11 ++--------- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 0dbc1f9a305..e269a748db6 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -91,12 +91,18 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + + .. versionchanged:: 0.3 + Added ``units`` parameter, changed default to pixel units """ super().__init__(dataset, roi) self.size = _to_tuple(size) + + if units == Units.PIXELS: + self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.batch_size = batch_size self.length = length - self.units = units self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) @@ -121,9 +127,7 @@ def __iter__(self) -> Iterator[List[BoundingBox]]: batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box( - bounds, self.size, self.res, self.units - ) + bounding_box = get_random_bounding_box(bounds, self.size, self.res) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py index c47400d3bcd..4483a869861 100644 --- a/torchgeo/samplers/constants.py +++ b/torchgeo/samplers/constants.py @@ -3,11 +3,15 @@ """Common sampler constants.""" -from enum import Enum +from enum import Enum, auto class Units(Enum): - """Enumeration to define units of `size` used for GeoSampler.""" + """Enumeration defining units of ``size`` parameter. - PIXELS = 0 - CRS = 1 + Used by :class:`~torchgeo.sampler.GeoSampler` and + :class:`~torchgeo.sampler.BatchGeoSampler`. + """ + + PIXELS = auto() + CRS = auto() diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 34e3837d47c..781cb5c38b1 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -92,11 +92,17 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + + .. versionchanged:: 0.3 + Added ``units`` parameter, changed default to pixel units """ super().__init__(dataset, roi) self.size = _to_tuple(size) + + if units == Units.PIXELS: + self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.length = length - self.units = units self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) @@ -118,9 +124,7 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box( - bounds, self.size, self.res, self.units - ) + bounding_box = get_random_bounding_box(bounds, self.size, self.res) yield bounding_box @@ -154,6 +158,7 @@ def __init__( size: Union[Tuple[float, float], float], stride: Union[Tuple[float, float], float], roi: Optional[BoundingBox] = None, + units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -166,14 +171,23 @@ def __init__( Args: dataset: dataset to index from - size: dimensions of each :term:`patch` in units of CRS + size: dimensions of each :term:`patch` stride: distance to skip between each patch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) + units: defines if ``size`` and ``stride`` are in pixel or CRS units + + .. versionchanged:: 0.3 + Added ``units`` parameter, changed default to pixel units """ super().__init__(dataset, roi) self.size = _to_tuple(size) self.stride = _to_tuple(stride) + + if units == Units.PIXELS: + self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.stride = (self.stride[0] * self.res, self.stride[1] * self.res) + self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 6f06d59f6f1..983d604a4b6 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -26,10 +26,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, - size: Union[Tuple[float, float], float], - res: float, - units: Units, + bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -43,15 +40,11 @@ def get_random_bounding_box( Args: bounds: the larger bounding box to sample from size: the size of the bounding box to sample - units: defines if `size` is in pixels or in CRS units. Returns: randomly sampled bounding box from the extent of the input """ - t_size: Tuple[float, float] = _to_tuple(size) - if units == Units.PIXELS: - # We have to re-assign t_size because tuples are immutable - t_size = (t_size[0] * res, t_size[1] * res) + t_size = _to_tuple(size) width = (bounds.maxx - bounds.minx - t_size[1]) // res minx = random.randrange(int(width)) * res + bounds.minx From cb72d0a60e0e993ade70a34dc54db7c8e704906d Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 22 Feb 2022 16:24:10 -0600 Subject: [PATCH 11/14] Update benchmark script --- benchmark.py | 17 +++++++---------- torchgeo/samplers/constants.py | 4 ++-- torchgeo/samplers/utils.py | 1 - 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/benchmark.py b/benchmark.py index d0d499e5fd5..2b3465bc3a9 100755 --- a/benchmark.py +++ b/benchmark.py @@ -77,15 +77,16 @@ def set_up_parser() -> argparse.ArgumentParser: "--patch-size", default=224, type=int, - help="height/width of each patch", - metavar="SIZE", + help="height/width of each patch in pixels", + metavar="PIXELS", ) parser.add_argument( "-s", "--stride", default=112, type=int, - help="sampling stride for GridGeoSampler", + help="sampling stride for GridGeoSampler in pixels", + metavar="PIXELS", ) parser.add_argument( "-w", @@ -139,15 +140,11 @@ def main(args: argparse.Namespace) -> None: length = args.num_batches * args.batch_size num_batches = args.num_batches - # Convert from pixel coords to CRS coords - size = args.patch_size * cdl.res - stride = args.stride * cdl.res - samplers = [ - RandomGeoSampler(landsat, size=size, length=length), - GridGeoSampler(landsat, size=size, stride=stride), + RandomGeoSampler(landsat, size=args.patch_size, length=length), + GridGeoSampler(landsat, size=args.patch_size, stride=args.stride), RandomBatchGeoSampler( - landsat, size=size, batch_size=args.batch_size, length=length + landsat, size=args.patch_size, batch_size=args.batch_size, length=length ), ] diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py index 4483a869861..203a72481a3 100644 --- a/torchgeo/samplers/constants.py +++ b/torchgeo/samplers/constants.py @@ -9,8 +9,8 @@ class Units(Enum): """Enumeration defining units of ``size`` parameter. - Used by :class:`~torchgeo.sampler.GeoSampler` and - :class:`~torchgeo.sampler.BatchGeoSampler`. + Used by :class:`~torchgeo.samplers.GeoSampler` and + :class:`~torchgeo.samplers.BatchGeoSampler`. """ PIXELS = auto() diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 983d604a4b6..f8382626ee8 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -7,7 +7,6 @@ from typing import Tuple, Union from ..datasets import BoundingBox -from .constants import Units def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: From d0f25024699017323c2aa6b551a62577c53b23e0 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 22 Feb 2022 21:08:40 -0600 Subject: [PATCH 12/14] Add tests --- tests/samplers/test_batch.py | 74 +++++++++++++++++-------- tests/samplers/test_single.py | 101 ++++++++++++++++++++-------------- 2 files changed, 109 insertions(+), 66 deletions(-) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 5c114bb86d7..80edf43845e 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import math +from itertools import product from typing import Dict, Iterator, List import pytest @@ -10,7 +11,7 @@ from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples -from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler +from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units class CustomBatchGeoSampler(BatchGeoSampler): @@ -26,7 +27,7 @@ def __len__(self) -> int: class CustomGeoDataset(GeoDataset): - def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None: + def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: super().__init__() self._crs = crs self.res = res @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: class TestBatchGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + return CustomGeoDataset() + @pytest.fixture(scope="function") def sampler(self) -> CustomBatchGeoSampler: return CustomBatchGeoSampler() @@ -49,28 +54,45 @@ def test_len(self, sampler: CustomBatchGeoSampler) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: CustomBatchGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, + dataset: CustomGeoDataset, + sampler: CustomBatchGeoSampler, + num_workers: int, + ) -> None: dl = DataLoader( - ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, + batch_sampler=sampler, + num_workers=num_workers, + collate_fn=stack_samples, ) for _ in dl: continue - def test_abstract(self) -> None: - ds = CustomGeoDataset() + def test_abstract(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): - BatchGeoSampler(ds) # type: ignore[abstract] + BatchGeoSampler(dataset) # type: ignore[abstract] class TestRandomBatchGeoSampler: - @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) - def sampler(self, request: SubRequest) -> RandomBatchGeoSampler: + @pytest.fixture(scope="class") + def dataset(self): ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 20, 30, 40, 50)) - ds.index.insert(1, (0, 10, 20, 30, 40, 50)) - size = request.param - return RandomBatchGeoSampler(ds, size, batch_size=2, length=10) + ds.index.insert(0, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + return ds + + @pytest.fixture( + scope="function", + params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]), + ) + def sampler( + self, dataset: CustomGeoDataset, request: SubRequest + ) -> RandomBatchGeoSampler: + size, units = request.param + return RandomBatchGeoSampler( + dataset, size, batch_size=2, length=10, units=units + ) def test_iter(self, sampler: RandomBatchGeoSampler) -> None: for batch in sampler: @@ -88,18 +110,15 @@ def test_iter(self, sampler: RandomBatchGeoSampler) -> None: def test_len(self, sampler: RandomBatchGeoSampler) -> None: assert len(sampler) == sampler.length // sampler.batch_size - def test_roi(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (5, 15, 5, 15, 5, 15)) - roi = BoundingBox(0, 10, 0, 10, 0, 10) - sampler = RandomBatchGeoSampler(ds, 2, 2, 10, roi=roi) + def test_roi(self, dataset: CustomGeoDataset) -> None: + roi = BoundingBox(0, 50, 200, 250, 400, 450) + sampler = RandomBatchGeoSampler(dataset, 2, 2, 10, roi=roi) for batch in sampler: for query in batch: assert query in roi def test_small_area(self) -> None: - ds = CustomGeoDataset() + ds = CustomGeoDataset(res=1) ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (20, 21, 20, 21, 20, 21)) sampler = RandomBatchGeoSampler(ds, 2, 2, 10) @@ -108,10 +127,17 @@ def test_small_area(self) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, + dataset: CustomGeoDataset, + sampler: RandomBatchGeoSampler, + num_workers: int, + ) -> None: dl = DataLoader( - ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, + batch_sampler=sampler, + num_workers=num_workers, + collate_fn=stack_samples, ) for _ in dl: continue diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index aa13b8b56aa..9dd49f24c79 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import math +from itertools import product from typing import Dict, Iterator import pytest @@ -10,7 +11,7 @@ from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples -from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler +from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units class CustomGeoSampler(GeoSampler): @@ -26,7 +27,7 @@ def __len__(self) -> int: class CustomGeoDataset(GeoDataset): - def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None: + def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: super().__init__() self._crs = crs self.res = res @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]: class TestGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + return CustomGeoDataset() + @pytest.fixture(scope="function") def sampler(self) -> CustomGeoSampler: return CustomGeoSampler() @@ -46,30 +51,39 @@ def test_iter(self, sampler: CustomGeoSampler) -> None: def test_len(self, sampler: CustomGeoSampler) -> None: assert len(sampler) == 2 - def test_abstract(self) -> None: - ds = CustomGeoDataset() + def test_abstract(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): - GeoSampler(ds) # type: ignore[abstract] + GeoSampler(dataset) # type: ignore[abstract] @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: CustomGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: CustomGeoSampler, num_workers: int + ) -> None: dl = DataLoader( - ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples ) for _ in dl: continue class TestRandomGeoSampler: - @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) - def sampler(self, request: SubRequest) -> RandomGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 20, 30, 40, 50)) - ds.index.insert(1, (0, 10, 20, 30, 40, 50)) - size = request.param - return RandomGeoSampler(ds, size, length=10) + ds.index.insert(0, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + return ds + + @pytest.fixture( + scope="function", + params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]), + ) + def sampler( + self, dataset: CustomGeoDataset, request: SubRequest + ) -> RandomGeoSampler: + size, units = request.param + return RandomGeoSampler(dataset, size, length=10, units=units) def test_iter(self, sampler: RandomGeoSampler) -> None: for query in sampler: @@ -86,17 +100,14 @@ def test_iter(self, sampler: RandomGeoSampler) -> None: def test_len(self, sampler: RandomGeoSampler) -> None: assert len(sampler) == sampler.length - def test_roi(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (5, 15, 5, 15, 5, 15)) - roi = BoundingBox(0, 10, 0, 10, 0, 10) - sampler = RandomGeoSampler(ds, 2, 10, roi=roi) + def test_roi(self, dataset: CustomGeoDataset) -> None: + roi = BoundingBox(0, 50, 200, 250, 400, 450) + sampler = RandomGeoSampler(dataset, 2, 10, roi=roi) for query in sampler: assert query in roi def test_small_area(self) -> None: - ds = CustomGeoDataset() + ds = CustomGeoDataset(res=1) ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (20, 21, 20, 21, 20, 21)) sampler = RandomGeoSampler(ds, 2, 10) @@ -105,26 +116,34 @@ def test_small_area(self) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: RandomGeoSampler, num_workers: int + ) -> None: dl = DataLoader( - ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples ) for _ in dl: continue class TestGridGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + return ds + @pytest.fixture( scope="function", - params=[(8, 1), (6, 2), (4, 3), (2.5, 3), ((8, 6), (1, 2)), ((6, 4), (2, 3))], + params=product( + [(8, 1), (6, 2), (4, 3), (2.5, 3), ((8, 6), (1, 2)), ((6, 4), (2, 3))], + [Units.PIXELS, Units.CRS], + ), ) - def sampler(self, request: SubRequest) -> GridGeoSampler: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 20, 0, 10, 40, 50)) - ds.index.insert(1, (0, 20, 0, 10, 40, 50)) - size, stride = request.param - return GridGeoSampler(ds, size, stride) + def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSampler: + (size, stride), units = request.param + return GridGeoSampler(dataset, size, stride, units=units) def test_iter(self, sampler: GridGeoSampler) -> None: for query in sampler: @@ -139,17 +158,14 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = int((10 - sampler.size[0]) // sampler.stride[0]) + 1 - cols = int((20 - sampler.size[1]) // sampler.stride[1]) + 1 + rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1 + cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1 length = rows * cols * 2 assert len(sampler) == length - def test_roi(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (5, 15, 5, 15, 5, 15)) - roi = BoundingBox(0, 10, 0, 10, 0, 10) - sampler = GridGeoSampler(ds, 2, 1, roi=roi) + def test_roi(self, dataset: CustomGeoDataset) -> None: + roi = BoundingBox(0, 50, 200, 250, 400, 450) + sampler = GridGeoSampler(dataset, 2, 1, roi=roi) for query in sampler: assert query in roi @@ -163,10 +179,11 @@ def test_small_area(self) -> None: @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) - def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None: - ds = CustomGeoDataset() + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: GridGeoSampler, num_workers: int + ) -> None: dl = DataLoader( - ds, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples ) for _ in dl: continue From 447652a9865fe2a3f3bf2f957e2abc57ff7f2079 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 22 Feb 2022 21:16:59 -0600 Subject: [PATCH 13/14] Document enum values --- torchgeo/samplers/constants.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchgeo/samplers/constants.py b/torchgeo/samplers/constants.py index 203a72481a3..18e1f598ddb 100644 --- a/torchgeo/samplers/constants.py +++ b/torchgeo/samplers/constants.py @@ -13,5 +13,8 @@ class Units(Enum): :class:`~torchgeo.samplers.BatchGeoSampler`. """ + #: Units in number of pixels PIXELS = auto() + + #: Units of :term:`coordinate reference system (CRS)` CRS = auto() From 3104706358a1e84912f2eacf7ea20182f3eb3b8f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 22 Feb 2022 21:33:12 -0600 Subject: [PATCH 14/14] mypy fixes --- tests/samplers/test_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 80edf43845e..952f36dd15d 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -76,7 +76,7 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: class TestRandomBatchGeoSampler: @pytest.fixture(scope="class") - def dataset(self): + def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) ds.index.insert(1, (0, 100, 200, 300, 400, 500))