Skip to content

Commit

Permalink
Revert "Merge branch 'vers_working_branch' into geosampler_prechipping"
Browse files Browse the repository at this point in the history
This reverts commit c7f3e4c, reversing
changes made to d8cb4b2.
  • Loading branch information
sfalkena committed Sep 23, 2024
1 parent c7f3e4c commit c51a63b
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 83 deletions.
15 changes: 0 additions & 15 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -145,20 +144,6 @@ def test_weighted_sampling(self) -> None:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
for bbox in sampler:
sample1 = bbox
break

sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
for bbox in sampler:
sample2 = bbox
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
35 changes: 0 additions & 35 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import geopandas as gpd
import pytest
import torch
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from rasterio.crs import CRS
Expand Down Expand Up @@ -223,21 +222,6 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
generator = torch.manual_seed(0)
sampler = RandomGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample1 = bbox
break

sampler = RandomGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample2 = bbox
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -387,25 +371,6 @@ def test_point_data(self) -> None:
for _ in sampler:
continue

def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
generator = torch.manual_seed(2)
sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler1:
sample1 = bbox
print(sample1)
break

generator = torch.manual_seed(2)
sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler2:
sample2 = bbox
print(sample2)
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
6 changes: 1 addition & 5 deletions torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,7 @@ def setup(self, stage: str) -> None:

if stage in ['fit']:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset,
self.patch_size,
self.batch_size,
self.length,
generator=generator,
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ['fit', 'validate']:
self.val_sampler = GridGeoSampler(
Expand Down
7 changes: 1 addition & 6 deletions torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand Down Expand Up @@ -98,11 +97,9 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: random number generator
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.generator = generator

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
Expand Down Expand Up @@ -147,9 +144,7 @@ def __iter__(self) -> Iterator[list[BoundingBox]]:
# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
batch.append(bounding_box)

yield batch
Expand Down
18 changes: 3 additions & 15 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import abc
from collections.abc import Callable, Iterable, Iterator
from functools import partial

import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -211,7 +210,6 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -238,16 +236,13 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: The random generator used for sampling.
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)

self.generator = generator
self.length = 0
self.hits = []
areas = []
Expand Down Expand Up @@ -309,7 +304,7 @@ def get_chips(self) -> GeoDataFrame:
bounds = BoundingBox(*hit.bounds)

# Choose a random index within that tile
bbox = get_random_bounding_box(bounds, self.size, self.res, self.generator)
bbox = get_random_bounding_box(bounds, self.size, self.res)
minx, maxx, miny, maxy, mint, maxt = tuple(bbox)
chip = {
'geometry': box(minx, miny, maxx, maxy),
Expand Down Expand Up @@ -452,11 +447,7 @@ class PreChippedGeoSampler(GeoSampler):
"""

def __init__(
self,
dataset: GeoDataset,
roi: BoundingBox | None = None,
shuffle: bool = False,
generator: torch.Generator | None = None,
self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -467,12 +458,9 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
shuffle: if True, reshuffle data at every epoch
generator: The random number generator used in combination with shuffle.
"""
super().__init__(dataset, roi)
self.shuffle = shuffle
self.generator = generator

self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
Expand All @@ -489,7 +477,7 @@ def get_chips(self) -> GeoDataFrame:
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
generator = partial(torch.randperm, generator=self.generator)
generator = torch.randperm

print('generating samples... ')
chips = []
Expand Down
10 changes: 3 additions & 7 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]:


def get_random_bounding_box(
bounds: BoundingBox,
size: tuple[float, float] | float,
res: float,
generator: torch.Generator | None = None,
bounds: BoundingBox, size: tuple[float, float] | float, res: float
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.
Expand All @@ -53,7 +50,6 @@ def get_random_bounding_box(
bounds: the larger bounding box to sample from
size: the size of the bounding box to sample
res: the resolution of the image
generator: random number generator
Returns:
randomly sampled bounding box from the extent of the input
Expand All @@ -68,8 +64,8 @@ def get_random_bounding_box(
miny = bounds.miny

# Use an integer multiple of res to avoid resampling
minx += int(torch.rand(1, generator=generator).item() * width) * res
miny += int(torch.rand(1, generator=generator).item() * height) * res
minx += int(torch.rand(1).item() * width) * res
miny += int(torch.rand(1).item() * height) * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]
Expand Down

0 comments on commit c51a63b

Please sign in to comment.