Skip to content

Commit

Permalink
Change DEMs from mask to image (is_image=True)
Browse files Browse the repository at this point in the history
  • Loading branch information
David Meaux committed Dec 15, 2023
1 parent 464f01f commit 7f507e4
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 74 deletions.
4 changes: 2 additions & 2 deletions docs/api/geo_datasets.csv
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
Dataset,Type,Source,License,Size (px),Resolution (m)
`Aboveground Woody Biomass`_,Masks,"Landsat, LiDAR","CC-BY-4.0","40,000x40,000",30
`Aster Global DEM`_,Masks,Aster,"public domain","3,601x3,601",30
`Aster Global DEM`_,Digital Elevation Model,Aster,"public domain","3,601x3,601",30
`Canadian Building Footprints`_,Geometries,Bing Imagery,"ODbL-1.0",-,-
`Chesapeake Land Cover`_,"Imagery, Masks",NAIP,"CC-BY-4.0",-,1
`Global Mangrove Distribution`_,Masks,"Remote Sensing, In Situ Measurements","public domain",-,3
`Cropland Data Layer`_,Masks,Landsat,"public domain",-,30
`EDDMapS`_,Points,Citizen Scientists,-,-,-
`EnviroAtlas`_,"Imagery, Masks","NAIP, NLCD, OpenStreetMap","CC-BY-4.0",-,1
`Esri2020`_,Masks,Sentinel-2,"CC-BY-4.0",-,10
`EU-DEM`_,Masks,"Aster, SRTM, Russian Topomaps","CSCDA-ESA",-,25
`EU-DEM`_,Digital Elevation Model,"Aster, SRTM, Russian Topomaps","CSCDA-ESA",-,25
`GBIF`_,Points,Citizen Scientists,"CC0-1.0 OR CC-BY-4.0 OR CC-BY-NC-4.0",-,-
`GlobBiomass`_,Masks,Landsat,"CC-BY-4.0","45,000x45,000",100
`iNaturalist`_,Points,Citizen Scientists,-,-,-
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@
"\n",
"### `is_image`\n",
"\n",
"If your data only contains image files, as is the case with Sentinel-2, use `is_image = True`. If your data only contains segmentation masks, use `is_image = False` instead.\n",
"If your data only contains image files, as is the case with Sentinel-2, or a digital surface, such as a Digital Elevation Model, Digital Surface Model, Digital Terrain Model, or a raster of temperature values, use `is_image = True`. If your data only contains segmentation masks, such as land use or land cover classification, use `is_image = False` instead.\n",
"\n",
"### `separate_files`\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion tests/data/astergdem/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
# remove old data
if os.path.exists(path):
os.remove(path)
# Create mask file
# Create image file
create_file(path, dtype="int32", num_channels=1)
files_to_zip.append(path)

Expand Down
2 changes: 1 addition & 1 deletion tests/data/eudem/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
# remove old data
if os.path.exists(path):
os.remove(path)
# Create mask file
# Create image file
create_file(path, dtype="int32", num_channels=1)
files_to_zip.append(path)

Expand Down
9 changes: 1 addition & 8 deletions tests/datasets/test_astergdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_getitem(self, dataset: AsterGDEM) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
assert isinstance(x["image"], torch.Tensor)

def test_and(self, dataset: AsterGDEM) -> None:
ds = dataset & dataset
Expand All @@ -55,13 +55,6 @@ def test_plot(self, dataset: AsterGDEM) -> None:
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: AsterGDEM) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_invalid_query(self, dataset: AsterGDEM) -> None:
query = BoundingBox(100, 100, 100, 100, 0, 0)
with pytest.raises(
Expand Down
9 changes: 1 addition & 8 deletions tests/datasets/test_eudem.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_getitem(self, dataset: EUDEM) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
assert isinstance(x["image"], torch.Tensor)

def test_extracted_already(self, dataset: EUDEM) -> None:
assert isinstance(dataset.paths, str)
Expand Down Expand Up @@ -70,13 +70,6 @@ def test_plot(self, dataset: EUDEM) -> None:
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: EUDEM) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_invalid_query(self, dataset: EUDEM) -> None:
query = BoundingBox(100, 100, 100, 100, 0, 0)
with pytest.raises(
Expand Down
21 changes: 21 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, cast

import numpy as np
import pytest
import segmentation_models_pytorch as smp
import timm
Expand Down Expand Up @@ -248,3 +249,23 @@ def test_freeze_decoder(self, model_name: str) -> None:
for param in model.model.segmentation_head.parameters()
]
)

@pytest.mark.parametrize(
"class_weights", [torch.tensor([1, 2, 3]), np.array([1, 2, 3]), [1, 2, 3]]
)
def test_classweights_valid(
self, class_weights: Any, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["class_weights"] = class_weights
sst = SemanticSegmentationTask(**model_kwargs)
assert isinstance(sst.loss.weight, torch.Tensor)
assert torch.equal(sst.loss.weight, torch.tensor([1.0, 2.0, 3.0]))
assert sst.loss.weight.dtype == torch.float32

@pytest.mark.parametrize("class_weights", [[], None])
def test_classweights_empty(
self, class_weights: Any, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["class_weights"] = class_weights
sst = SemanticSegmentationTask(**model_kwargs)
assert sst.loss.weight is None
43 changes: 18 additions & 25 deletions torchgeo/datasets/astergdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS

Expand Down Expand Up @@ -36,7 +37,8 @@ class AsterGDEM(RasterDataset):
.. versionadded:: 0.3
"""

is_image = False
is_image = True
all_bands = ["elevation"]
filename_glob = "ASTGTMV003_*_dem*"
filename_regex = r"""
(?P<name>[ASTGTMV003]{10})
Expand Down Expand Up @@ -74,8 +76,11 @@ def __init__(
self.paths = paths

self._verify()
bands = self.all_bands

super().__init__(paths, crs, res, transforms=transforms, cache=cache)
super().__init__(
paths, crs, res, bands=bands, transforms=transforms, cache=cache
)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
Expand All @@ -101,29 +106,17 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
"""
mask = sample["mask"].squeeze()
ncols = 1

showing_predictions = "prediction" in sample
if showing_predictions:
prediction = sample["prediction"].squeeze()
ncols = 2

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4))

if showing_predictions:
axs[0].imshow(mask)
axs[0].axis("off")
axs[1].imshow(prediction)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Mask")
axs[1].set_title("Prediction")
else:
axs.imshow(mask)
axs.axis("off")
if show_titles:
axs.set_title("Mask")
image = sample["image"][0]

image = torch.clamp(image, min=0, max=1)

fig, ax = plt.subplots(1, 1, figsize=(4, 4))

ax.imshow(image)
ax.axis("off")

if show_titles:
ax.set_title("Elevation")

if suptitle is not None:
plt.suptitle(suptitle)
Expand Down
43 changes: 18 additions & 25 deletions torchgeo/datasets/eudem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS

Expand Down Expand Up @@ -46,7 +47,8 @@ class EUDEM(RasterDataset):
.. versionadded:: 0.3
"""

is_image = False
is_image = True
all_bands = ["elevation"]
filename_glob = "eu_dem_v11_*.TIF"
zipfile_glob = "eu_dem_v11_*[A-Z0-9].zip"
filename_regex = "(?P<name>[eudem_v11]{10})_(?P<id>[A-Z0-9]{6})"
Expand Down Expand Up @@ -114,8 +116,11 @@ def __init__(
self.checksum = checksum

self._verify()
bands = self.all_bands

super().__init__(paths, crs, res, transforms=transforms, cache=cache)
super().__init__(
paths, crs, res, bands=bands, transforms=transforms, cache=cache
)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
Expand Down Expand Up @@ -152,29 +157,17 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
"""
mask = sample["mask"].squeeze()
ncols = 1

showing_predictions = "prediction" in sample
if showing_predictions:
pred = sample["prediction"].squeeze()
ncols = 2

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4))

if showing_predictions:
axs[0].imshow(mask)
axs[0].axis("off")
axs[1].imshow(pred)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Mask")
axs[1].set_title("Prediction")
else:
axs.imshow(mask)
axs.axis("off")
if show_titles:
axs.set_title("Mask")
image = sample["image"][0]

image = torch.clamp(image, min=0, max=1)

fig, ax = plt.subplots(1, 1, figsize=(4, 4))

ax.imshow(image)
ax.axis("off")

if show_titles:
ax.set_title("Elevation")

if suptitle is not None:
plt.suptitle(suptitle)
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
(e.g. Landsat and CDL)
* Combine datasets for multiple image sources for multimodal learning or data fusion
(e.g. Landsat and Sentinel)
* Combine image and digital a digital surface (e.g., elevation, temperature,
pressure) and sample from both simultaneously (e.g. Sentinel-2 and an Aster
Global DEM tile)
These combinations require that all queries are present in *both* datasets,
and can be combined using an :class:`IntersectionDataset`:
Expand Down Expand Up @@ -342,7 +346,8 @@ class RasterDataset(GeoDataset):
#: ``start`` and ``stop`` groups.
date_format = "%Y%m%d"

#: True if dataset contains imagery, False if dataset contains mask
#: True if dataset contains imagery or a digital surface, False if dataset contains
#: a mask, that is classified or categorical data
is_image = True

#: True if data is stored in a separate file for each band, else False.
Expand Down
6 changes: 4 additions & 2 deletions torchgeo/datasets/vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ def plot(
# Add masks
if show_feats in {"masks", "both"} and "masks" in sample:
mask = masks[i]
contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
# _ type: ignore[no-untyped-call]
contours = find_contours(mask, 0.5)
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
Expand Down Expand Up @@ -525,7 +526,8 @@ def plot(
# Add masks
if show_pred_masks:
mask = prediction_masks[i]
contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
# _ type: ignore[no-untyped-call]
contours = find_contours(mask, 0.5)
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
Expand Down

0 comments on commit 7f507e4

Please sign in to comment.