Add DOFA model (#1903)
* adding DOFA-Net model

* Convert git submodule to single file

* Get rid of __main__ code blocks

* Add WeightsEnum

* Add documentation

* Add CSV file

* Update link to transform implementation

* Add unit tests

* Test model forward function

* Complete test coverage

* Add type hints

* Solve most type issues

* Solve remaining type issues

* Undo sorting

* Simplified docs

* Fix loading of real weights

* OFAViT -> DOFA

* Remove redundant helper function

* Add units for wavelengths

* wave_list -> wavelengths

* wvs -> wavelengths

* img_feat -> x

* Sorting

* inter_dim is not used

* Rename embedding layer

* Simpler name for position embedding

* wv_planes -> dynamic_embed_dim

* make weight init a hidden method

* Simpler model init

* Use permalink

* Document GEO-Bench performance

* Simpler test

* More columns

* Update correct column numbers

* Modified datasets

* Add large weights

* Update __all__

---------

Co-authored-by: Adam J. Stewart <[email protected]>
xiong-zhitong and adamjstewart authored Mar 25, 2024
1 parent d030044 commit 2849944
Showing 5 changed files with 731 additions and 0 deletions.
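
The new DOFA backbone takes, alongside the image batch, a list of per-channel wavelength values, which lets a single set of weights handle imagery from different sensors (see the test parametrization below for sensor-specific examples). A minimal usage sketch, mirroring the constructor arguments and forward call exercised in tests/models/test_dofa.py in this commit:

```python
# A minimal sketch, assuming the constructor arguments and forward signature
# used in tests/models/test_dofa.py; the two 5.405 values mirror the
# Sentinel-1 entry in the test parametrization and are illustrative.
import torch

from torchgeo.models import DOFA

wavelengths = [5.405, 5.405]
model = DOFA(embed_dim=384, depth=12, num_heads=6, num_classes=10, global_pool=True)
batch = torch.randn(2, len(wavelengths), 224, 224)
logits = model(batch, wavelengths)
assert logits.shape == (2, 10)
```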
3 changes: 3 additions & 0 deletions docs/api/agnostic_pretrained_weights.csv
@@ -0,0 +1,3 @@
Weight,Source,Citation,License,m-bigearthnet,m-forestnet,m-brick-kiln,m-pv4ger,m-so2sat,m-eurosat,m-pv4ger-seg,m-nz-cattle,m-NeonTree,m-cashew-plant,m-SA-crop,m-chesapeake
DOFABase16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,63.8,45.3,94.7,96.9,52.1,92.2,94.7,81.6,58.6,48.3,31.3,65.4
DOFALarge16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,64.4,47.4,95.1,97.3,59.3,93.8,95.0,81.7,59.1,53.8,32.1,66.3
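
The two rows list the DOFA_MAE checkpoints for the base and large models together with their GEO-Bench scores. A hedged sketch of loading one of them, following the weights argument that the new factory functions accept (exercised by test_dofa_download later in this commit):

```python
# A minimal sketch, assuming the weights-loading pattern exercised by
# test_dofa_download in tests/models/test_dofa.py; downloading the checkpoint
# requires network access.
from torchgeo.models import DOFABase16_Weights, dofa_base_patch16_224

weights = DOFABase16_Weights.DOFA_MAE
model = dofa_base_patch16_224(weights=weights)
```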
23 changes: 23 additions & 0 deletions docs/api/models.rst
@@ -10,6 +10,17 @@ Change Star
.. autoclass:: ChangeStarFarSeg
.. autoclass:: ChangeMixin

DOFA
^^^^

.. autoclass:: DOFA
.. autofunction:: dofa_small_patch16_224
.. autofunction:: dofa_base_patch16_224
.. autofunction:: dofa_large_patch16_224
.. autofunction:: dofa_huge_patch16_224
.. autoclass:: DOFABase16_Weights
.. autoclass:: DOFALarge16_Weights

FarSeg
^^^^^^

@@ -63,6 +74,18 @@ Utility Functions
Pretrained Weights
^^^^^^^^^^^^^^^^^^

Sensor-Agnostic
---------------

These weights can be used with imagery from any satellite/sensor.

.. csv-table::
   :widths: 45 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
   :header-rows: 1
   :align: center
   :file: agnostic_pretrained_weights.csv


NAIP
----

158 changes: 158 additions & 0 deletions tests/models/test_dofa.py
@@ -0,0 +1,158 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import (
    DOFA,
    DOFABase16_Weights,
    DOFALarge16_Weights,
    dofa_base_patch16_224,
    dofa_huge_patch16_224,
    dofa_large_patch16_224,
    dofa_small_patch16_224,
)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
    state_dict: dict[str, Any] = torch.load(url)
    return state_dict


class TestDOFA:
    @pytest.mark.parametrize(
        "wavelengths",
        [
            # Gaofen
            [0.443, 0.565, 0.763, 0.765, 0.910],
            # NAIP
            [0.640, 0.560, 0.480],
            [0.480, 0.560, 0.640, 0.810],
            # Sentinel-1
            [5.405],
            [5.405, 5.405],
            # Sentinel-2
            [
                0.443,
                0.490,
                0.560,
                0.665,
                0.705,
                0.740,
                0.783,
                0.842,
                0.865,
                0.945,
                1.375,
                1.610,
                2.190,
            ],
        ],
    )
    def test_dofa(self, wavelengths: list[float]) -> None:
        batch_size = 2
        num_channels = len(wavelengths)
        num_classes = 10
        global_pool = num_channels % 2 == 0
        model = DOFA(
            embed_dim=384,
            depth=12,
            num_heads=6,
            num_classes=num_classes,
            global_pool=global_pool,
        )
        batch = torch.randn([batch_size, num_channels, 224, 224])
        out = model(batch, wavelengths)
        assert out.shape == torch.Size([batch_size, num_classes])


class TestDOFASmall16:
    def test_dofa(self) -> None:
        dofa_small_patch16_224()


class TestDOFABase16:
    @pytest.fixture(params=[*DOFABase16_Weights])
    def weights(self, request: SubRequest) -> WeightsEnum:
        return request.param

    @pytest.fixture
    def mocked_weights(
        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
    ) -> WeightsEnum:
        path = tmp_path / f"{weights}.pth"
        model = dofa_base_patch16_224()
        torch.save(model.state_dict(), path)
        try:
            monkeypatch.setattr(weights.value, "url", str(path))
        except AttributeError:
            monkeypatch.setattr(weights, "url", str(path))
        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
        return weights

    def test_dofa(self) -> None:
        dofa_base_patch16_224()

    def test_dofa_weights(self, mocked_weights: WeightsEnum) -> None:
        dofa_base_patch16_224(weights=mocked_weights)

    def test_transforms(self, mocked_weights: WeightsEnum) -> None:
        c = 4
        sample = {
            "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
        }
        mocked_weights.transforms(sample)

    @pytest.mark.slow
    def test_dofa_download(self, weights: WeightsEnum) -> None:
        dofa_base_patch16_224(weights=weights)


class TestDOFALarge16:
    @pytest.fixture(params=[*DOFALarge16_Weights])
    def weights(self, request: SubRequest) -> WeightsEnum:
        return request.param

    @pytest.fixture
    def mocked_weights(
        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
    ) -> WeightsEnum:
        path = tmp_path / f"{weights}.pth"
        model = dofa_large_patch16_224()
        torch.save(model.state_dict(), path)
        try:
            monkeypatch.setattr(weights.value, "url", str(path))
        except AttributeError:
            monkeypatch.setattr(weights, "url", str(path))
        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
        return weights

    def test_dofa(self) -> None:
        dofa_large_patch16_224()

    def test_dofa_weights(self, mocked_weights: WeightsEnum) -> None:
        dofa_large_patch16_224(weights=mocked_weights)

    def test_transforms(self, mocked_weights: WeightsEnum) -> None:
        c = 4
        sample = {
            "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
        }
        mocked_weights.transforms(sample)

    @pytest.mark.slow
    def test_dofa_download(self, weights: WeightsEnum) -> None:
        dofa_large_patch16_224(weights=weights)


class TestDOFAHuge16:
    def test_dofa(self) -> None:
        dofa_huge_patch16_224()
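
The mocked_weights fixtures above save a freshly initialized state dict to a temporary path and point the enum's url at it, so weight loading is exercised without a network download. For the bundled preprocessing, the sample-dict call pattern from test_transforms looks like this in ordinary use (a sketch, not part of the commit):

```python
# A minimal sketch, assuming the sample-dict interface exercised by
# test_transforms above; DOFA_MAE and the tensor values are illustrative.
import torch

from torchgeo.models import DOFABase16_Weights

weights = DOFABase16_Weights.DOFA_MAE
sample = {"image": torch.rand(4, 224, 224)}
weights.transforms(sample)  # the tests only check that this call succeeds
```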
16 changes: 16 additions & 0 deletions torchgeo/models/__init__.py
@@ -5,6 +5,15 @@

from .api import get_model, get_model_weights, get_weight, list_models
from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
from .dofa import (
    DOFA,
    DOFABase16_Weights,
    DOFALarge16_Weights,
    dofa_base_patch16_224,
    dofa_huge_patch16_224,
    dofa_large_patch16_224,
    dofa_small_patch16_224,
)
from .farseg import FarSeg
from .fcn import FCN
from .fcsiam import FCSiamConc, FCSiamDiff
@@ -18,6 +27,11 @@
"ChangeMixin",
"ChangeStar",
"ChangeStarFarSeg",
"DOFA",
"dofa_small_patch16_224",
"dofa_base_patch16_224",
"dofa_large_patch16_224",
"dofa_huge_patch16_224",
"FarSeg",
"FCN",
"FCSiamConc",
@@ -28,6 +42,8 @@
"swin_v2_b",
"vit_small_patch16_224",
# weights
"DOFABase16_Weights",
"DOFALarge16_Weights",
"ResNet50_Weights",
"ResNet18_Weights",
"Swin_V2_B_Weights",
torchgeo/models/dofa.py (large diff not rendered)
