Adding DOFA-Net model

* Convert git submodule to single file
* Get rid of __main__ code blocks
* Add WeightsEnum
* Add documentation
* Add CSV file
* Update link to transform implementation
* Add unit tests
* Test model forward function
* Complete test coverage
* Add type hints
* Solve most type issues
* Solve remaining type issues
* Undo sorting
* Simplified docs
* Fix loading of real weights
* OFAViT -> DOFA
* Remove redundant helper function
* Add units for wavelengths
* wave_list -> wavelengths
* wvs -> wavelengths
* img_feat -> x
* Sorting
* inter_dim is not used
* Rename embedding layer
* Simpler name for position embedding
* wv_planes -> dynamic_embed_dim
* Make weight init a hidden method
* Simpler model init
* Use permalink
* Document GEO-Bench performance
* Simpler test
* More columns
* Update correct column numbers
* Modified datasets
* Add large weights
* Update __all__

Co-authored-by: Adam J. Stewart <[email protected]>
1 parent d030044 · commit 2849944 · 5 changed files with 731 additions and 0 deletions
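The resulting public API (the DOFA class, the dofa_*_patch16_224 builders, and the weight enums) is exercised in the unit tests further down. As a quick orientation, here is a minimal usage sketch, not part of the commit: it assumes the DOFA_MAE checkpoint is downloadable (drop the weights argument for a randomly initialized model) and reuses the NAIP wavelengths, in micrometers, from the tests.

import torch
from torchgeo.models import DOFABase16_Weights, dofa_base_patch16_224

# Build the base model; passing weights= fetches the DOFA_MAE checkpoint.
model = dofa_base_patch16_224(weights=DOFABase16_Weights.DOFA_MAE)
model.eval()

# Four-band (R, G, B, NIR) input with the NAIP wavelengths used in the tests.
wavelengths = [0.480, 0.560, 0.640, 0.810]
batch = torch.randn(2, len(wavelengths), 224, 224)

with torch.no_grad():
    out = model(batch, wavelengths)
print(out.shape)  # logits of shape (batch_size, num_classes)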
New file (+3 lines): GEO-Bench benchmark results for the DOFA pretrained weights (CSV).
Weight,Source,Citation,License,m-bigearthnet,m-forestnet,m-brick-kiln,m-pv4ger,m-so2sat,m-eurosat,m-pv4ger-seg,m-nz-cattle,m-NeonTree,m-cashew-plant,m-SA-crop,m-chesapeake
DOFABase16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,63.8,45.3,94.7,96.9,52.1,92.2,94.7,81.6,58.6,48.3,31.3,65.4
DOFALarge16_Weights.DOFA_MAE,`link <https://github.com/zhu-xlab/DOFA>`__,`link <https://arxiv.org/abs/2403.15356>`__,CC-BY-4.0,64.4,47.4,95.1,97.3,59.3,93.8,95.0,81.7,59.1,53.8,32.1,66.3
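To compare the two weight variants across the twelve GEO-Bench tasks, the table can be summarized with pandas. A minimal sketch, assuming the CSV above has been saved locally under a hypothetical file name (pandas is not part of this commit):

import pandas as pd

# Hypothetical local copy of the benchmark CSV shown above.
df = pd.read_csv("dofa_geobench_results.csv")

# Columns 0-3 are metadata (Weight, Source, Citation, License); the remaining
# twelve columns are the GEO-Bench tasks.
tasks = df.columns[4:]
df["mean"] = df[tasks].mean(axis=1)
print(df[["Weight", "mean"]])

# Per-task improvement of the large model (second row) over the base model.
print((df.loc[1, tasks] - df.loc[0, tasks]).sort_values(ascending=False))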
New file (+158 lines): unit tests for the DOFA model and its pretrained weight enums.
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from pathlib import Path
from typing import Any

import pytest
import torch
import torchvision
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum

from torchgeo.models import (
    DOFA,
    DOFABase16_Weights,
    DOFALarge16_Weights,
    dofa_base_patch16_224,
    dofa_huge_patch16_224,
    dofa_large_patch16_224,
    dofa_small_patch16_224,
)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
    state_dict: dict[str, Any] = torch.load(url)
    return state_dict


class TestDOFA:
    @pytest.mark.parametrize(
        "wavelengths",
        [
            # Gaofen
            [0.443, 0.565, 0.763, 0.765, 0.910],
            # NAIP
            [0.640, 0.560, 0.480],
            [0.480, 0.560, 0.640, 0.810],
            # Sentinel-1
            [5.405],
            [5.405, 5.405],
            # Sentinel-2
            [
                0.443,
                0.490,
                0.560,
                0.665,
                0.705,
                0.740,
                0.783,
                0.842,
                0.865,
                0.945,
                1.375,
                1.610,
                2.190,
            ],
        ],
    )
    def test_dofa(self, wavelengths: list[float]) -> None:
        batch_size = 2
        num_channels = len(wavelengths)
        num_classes = 10
        global_pool = num_channels % 2 == 0
        model = DOFA(
            embed_dim=384,
            depth=12,
            num_heads=6,
            num_classes=num_classes,
            global_pool=global_pool,
        )
        batch = torch.randn([batch_size, num_channels, 224, 224])
        out = model(batch, wavelengths)
        assert out.shape == torch.Size([batch_size, num_classes])


class TestDOFASmall16:
    def test_dofa(self) -> None:
        dofa_small_patch16_224()


class TestDOFABase16:
    @pytest.fixture(params=[*DOFABase16_Weights])
    def weights(self, request: SubRequest) -> WeightsEnum:
        return request.param

    @pytest.fixture
    def mocked_weights(
        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
    ) -> WeightsEnum:
        path = tmp_path / f"{weights}.pth"
        model = dofa_base_patch16_224()
        torch.save(model.state_dict(), path)
        try:
            monkeypatch.setattr(weights.value, "url", str(path))
        except AttributeError:
            monkeypatch.setattr(weights, "url", str(path))
        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
        return weights

    def test_dofa(self) -> None:
        dofa_base_patch16_224()

    def test_dofa_weights(self, mocked_weights: WeightsEnum) -> None:
        dofa_base_patch16_224(weights=mocked_weights)

    def test_transforms(self, mocked_weights: WeightsEnum) -> None:
        c = 4
        sample = {
            "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
        }
        mocked_weights.transforms(sample)

    @pytest.mark.slow
    def test_dofa_download(self, weights: WeightsEnum) -> None:
        dofa_base_patch16_224(weights=weights)


class TestDOFALarge16:
    @pytest.fixture(params=[*DOFALarge16_Weights])
    def weights(self, request: SubRequest) -> WeightsEnum:
        return request.param

    @pytest.fixture
    def mocked_weights(
        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
    ) -> WeightsEnum:
        path = tmp_path / f"{weights}.pth"
        model = dofa_large_patch16_224()
        torch.save(model.state_dict(), path)
        try:
            monkeypatch.setattr(weights.value, "url", str(path))
        except AttributeError:
            monkeypatch.setattr(weights, "url", str(path))
        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
        return weights

    def test_dofa(self) -> None:
        dofa_large_patch16_224()

    def test_dofa_weights(self, mocked_weights: WeightsEnum) -> None:
        dofa_large_patch16_224(weights=mocked_weights)

    def test_transforms(self, mocked_weights: WeightsEnum) -> None:
        c = 4
        sample = {
            "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
        }
        mocked_weights.transforms(sample)

    @pytest.mark.slow
    def test_dofa_download(self, weights: WeightsEnum) -> None:
        dofa_large_patch16_224(weights=weights)


class TestDOFAHuge16:
    def test_dofa(self) -> None:
        dofa_huge_patch16_224()
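The test_transforms cases above also show how the bundled preprocessing is meant to be used outside the test suite: the transform attached to a weights enum member is called directly on a sample dict with an "image" key. A minimal sketch of that usage (no checkpoint download is needed, since only the transform is touched):

import torch
from torchgeo.models import DOFABase16_Weights

weights = DOFABase16_Weights.DOFA_MAE

# A fake 4-band, 224x224 image, mirroring the sample built in test_transforms.
c = 4
sample = {"image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)}

# Applies the preprocessing pipeline bundled with the weights; the exact ops
# (e.g. normalization) are defined by the enum's metadata.
transformed = weights.transforms(sample)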