Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update seg and class metrics #2554

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
- pytest>=6.1.2
- scikit-image>=0.22.0
- torch>=2.6
- torchmetrics>=0.10
- torchmetrics>=1.1.1
- torchvision>=0.18
exclude: (build|data|dist|logo|logs|output)/
- repo: https://github.com/pre-commit/mirrors-prettier
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ dependencies = [
"timm>=0.4.12",
# torch 1.13+ required by torchvision
"torch>=1.13",
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
"torchmetrics>=0.10",
"torchmetrics>=1.1.1",
# torchvision 0.14+ required for torchvision.models.swin_v2_b
"torchvision>=0.14",
]
Expand Down
2 changes: 1 addition & 1 deletion requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0
shapely==1.8.0
timm==0.4.12
torch==1.13.0
torchmetrics==0.10.0
torchmetrics==1.1.1
torchvision==0.14.0

# datasets
Expand Down
16 changes: 4 additions & 12 deletions tests/datamodules/test_fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import os

import matplotlib.pyplot as plt
import pytest

from torchgeo.datamodules import FAIR1MDataModule
Expand All @@ -26,17 +25,10 @@ def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('validate')
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('test')
next(iter(datamodule.test_dataloader()))

def test_predict_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('predict')
next(iter(datamodule.predict_dataloader()))

def test_plot(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('validate')
batch = next(iter(datamodule.val_dataloader()))
sample = {
'image': batch['image'][0],
'boxes': batch['boxes'][0],
'label': batch['label'][0],
}
datamodule.plot(sample)
plt.close()
22 changes: 2 additions & 20 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
from _pytest.fixtures import SubRequest
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

Expand All @@ -31,12 +29,9 @@ def __init__(
self.res = 1

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()


class CustomGeoDataModule(GeoDataModule):
def __init__(self) -> None:
Expand Down Expand Up @@ -68,14 +63,11 @@ def __init__(
self.length = length

def __getitem__(self, index: int) -> dict[str, Tensor]:
return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}
return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)}

def __len__(self) -> int:
return self.length

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()


class CustomNonGeoDataModule(NonGeoDataModule):
def __init__(self) -> None:
Expand Down Expand Up @@ -133,11 +125,6 @@ def test_predict(self, datamodule: CustomGeoDataModule) -> None:
batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_plot(self, datamodule: CustomGeoDataModule) -> None:
datamodule.setup('validate')
datamodule.plot()
plt.close()

def test_no_datasets(self) -> None:
dm = CustomGeoDataModule()
msg = r'CustomGeoDataModule\.setup must define one of '
Expand Down Expand Up @@ -235,11 +222,6 @@ def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
batch = next(iter(datamodule.predict_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
datamodule.setup('validate')
datamodule.plot()
plt.close()

def test_no_datasets(self) -> None:
dm = CustomNonGeoDataModule()
msg = r'CustomNonGeoDataModule\.setup must define one of '
Expand Down
9 changes: 0 additions & 9 deletions tests/datamodules/test_usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

import os

import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import USAVarsDataModule
from torchgeo.datasets import unbind_samples


class TestUSAVarsDataModule:
Expand Down Expand Up @@ -41,10 +39,3 @@ def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.test_dataloader()) == 1
batch = next(iter(datamodule.test_dataloader()))
assert batch['image'].shape[0] == datamodule.batch_size

def test_plot(self, datamodule: USAVarsDataModule) -> None:
datamodule.setup('validate')
batch = next(iter(datamodule.val_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
37 changes: 37 additions & 0 deletions tests/datamodules/test_xview2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest

from torchgeo.datamodules import XView2DataModule


class TestXView2DataModule:
@pytest.fixture
def datamodule(self) -> XView2DataModule:
root = os.path.join('tests', 'data', 'xview2')
batch_size = 1
num_workers = 0
dm = XView2DataModule(
root=root, batch_size=batch_size, num_workers=num_workers, val_split_pct=0.5
)
dm.prepare_data()
return dm

def test_train_dataloader(self, datamodule: XView2DataModule) -> None:
datamodule.setup('fit')
next(iter(datamodule.train_dataloader()))

def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
datamodule.setup('validate')
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
datamodule.setup('test')
next(iter(datamodule.test_dataloader()))

def test_predict_dataloader(self, datamodule: XView2DataModule) -> None:
datamodule.setup('predict')
next(iter(datamodule.predict_dataloader()))
93 changes: 14 additions & 79 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import timm
import torch
import torch.nn as nn
import torchvision
from lightning.pytorch import Trainer
from pytest import MonkeyPatch
from torch.nn.modules import Module
Expand All @@ -19,7 +20,7 @@
EuroSATDataModule,
MisconfigurationException,
)
from torchgeo.datasets import BigEarthNet, EuroSAT, RGBBandsMissingError
from torchgeo.datasets import BigEarthNet, EuroSAT
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
Expand Down Expand Up @@ -55,12 +56,9 @@
return ClassificationTestModel(**kwargs)


def plot(*args: Any, **kwargs: Any) -> None:
return None


def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
raise RGBBandsMissingError()
def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
return state_dict


class TestClassificationTask:
Expand Down Expand Up @@ -103,13 +101,13 @@
'1',
]

main(['fit', *args])
main(['fit'] + args)

Check failure on line 104 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:104:14: RUF005 Consider `['fit', *args]` instead of concatenation
try:
main(['test', *args])
main(['test'] + args)

Check failure on line 106 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:106:18: RUF005 Consider `['test', *args]` instead of concatenation
except MisconfigurationException:
pass
try:
main(['predict', *args])
main(['predict'] + args)

Check failure on line 110 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:110:18: RUF005 Consider `['predict', *args]` instead of concatenation
except MisconfigurationException:
pass

Expand All @@ -119,11 +117,7 @@

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model( # type: ignore[attr-defined]
Expand All @@ -134,6 +128,7 @@
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)
return weights

def test_weight_file(self, checkpoint: str) -> None:
Expand Down Expand Up @@ -183,34 +178,6 @@
with pytest.raises(ValueError, match=match):
ClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(EuroSATDataModule, 'plot', plot)
datamodule = EuroSATDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
)
model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(EuroSATDataModule, 'plot', plot_missing_bands)
datamodule = EuroSATDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
)
model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictClassificationDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
Expand All @@ -237,7 +204,7 @@

class TestMultiLabelClassificationTask:
@pytest.mark.parametrize(
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2', 'treesatai']
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2']
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
Expand All @@ -259,13 +226,13 @@
'1',
]

main(['fit', *args])
main(['fit'] + args)

Check failure on line 229 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:229:14: RUF005 Consider `['fit', *args]` instead of concatenation
try:
main(['test', *args])
main(['test'] + args)

Check failure on line 231 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:231:18: RUF005 Consider `['test', *args]` instead of concatenation
except MisconfigurationException:
pass
try:
main(['predict', *args])
main(['predict'] + args)

Check failure on line 235 in tests/trainers/test_classification.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF005)

tests/trainers/test_classification.py:235:18: RUF005 Consider `['predict', *args]` instead of concatenation
except MisconfigurationException:
pass

Expand All @@ -274,38 +241,6 @@
with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
Expand Down
Loading
Loading