Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test model training #292

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion libs/spandrel/spandrel/__helpers/model_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,9 @@ def __init__(

self._purpose: Literal["SR", "FaceSR", "Restoration"] = purpose

self._call_fn = call_fn or (lambda model, image: model(image))
self._call_fn: Callable[[T, Tensor], Tensor] = call_fn or (
lambda model, image: model(image)
)

@property
@override
Expand Down
5 changes: 5 additions & 0 deletions tests/test_ATD.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -55,6 +56,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(ATDArch(), ATD())


def test_101_ATD_light_SRx2_scratch(snapshot):
file = ModelFile.from_url(
"https://drive.google.com/file/d/1ZxK7gMJXgeyHgeOaKbzpXtoElDmRWKkU/view?usp=drive_link",
Expand Down
6 changes: 6 additions & 0 deletions tests/test_AdaCode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
skip_if_unchanged,
)

Expand Down Expand Up @@ -50,6 +51,11 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
# TODO: fix training
assert_training(AdaCodeArch(), AdaCode())


def test_AdaCode_SR_X2_model_g(snapshot):
file = ModelFile.from_url(
"https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_SR_X2_model_g.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_CRAFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -51,6 +52,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(CRAFTArch(), CRAFT())


def test_CRAFT_x2(snapshot):
file = ModelFile.from_url_zip(
"https://drive.google.com/file/d/13wAmc93BPeBUBQ24zUZOuUpdBFG2aAY5/view",
Expand Down
6 changes: 6 additions & 0 deletions tests/test_CodeFormer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand All @@ -31,6 +32,11 @@ def test_size_requirements():
assert_size_requirements(file.load_model(), max_size=512)


def test_train():
# TODO: fix training
assert_training(CodeFormerArch(), CodeFormer())


def test_CodeFormer(snapshot):
file = ModelFile.from_url(
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_Compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -39,6 +40,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(CompactArch(), Compact())


def test_Compact_realesr_general_x4v3(snapshot):
file = ModelFile.from_url(
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_DAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -48,6 +49,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(DATArch(), DAT(), batch_size=2)


def test_DAT_S_x4(snapshot):
file = ModelFile.from_url(
"https://github.com/OpenModelDB/model-hub/releases/download/dat/4x-DAT_S.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_DCTLSA.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -36,6 +37,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(DCTLSAArch(), DCTLSA())


def test_x4(snapshot):
file = ModelFile.from_url(
"https://github.com/zengkun301/DCTLSA/raw/main/pretrained/X4.pt",
Expand Down
6 changes: 6 additions & 0 deletions tests/test_DDColor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TestImage,
assert_image_inference,
assert_loads_correctly,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -63,6 +64,11 @@ def test_load():
)


def test_train():
# TODO: fix training
assert_training(DDColorArch(), DDColor())


def test_DDColor_paper_tiny(snapshot):
file = ModelFile.from_url(
"https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_paper_tiny.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_DITN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -45,6 +46,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(DITNArch(), DITN())


def test_DITN_Real_GAN_x4(snapshot):
file = ModelFile.from_url(
"https://drive.google.com/file/d/12y6WjNowBkJ982fMql_yj6zBpwPKuhV2/view?usp=drive_link",
Expand Down
6 changes: 6 additions & 0 deletions tests/test_DRCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -39,6 +40,11 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
# TODO: fix training
assert_training(DRCTArch(), DRCT())


def test_community_model(snapshot):
file = ModelFile.from_url(
"https://github.com/Phhofm/models/releases/download/4xRealWebPhoto_v4_drct-l/4xRealWebPhoto_v4_drct-l.pth",
Expand Down
5 changes: 5 additions & 0 deletions tests/test_DRUNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -35,6 +36,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(DRUNetArch(), DRUNet())


def test_drunet_color(snapshot):
file = ModelFile.from_url(
"https://github.com/cszn/KAIR/releases/download/v1.0/drunet_color.pth",
Expand Down
5 changes: 5 additions & 0 deletions tests/test_DnCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand All @@ -29,6 +30,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(DnCNNArch(), DnCNN())


def test_dncnn_color_blind(snapshot):
file = ModelFile.from_url(
"https://github.com/cszn/KAIR/releases/download/v1.0/dncnn_color_blind.pth",
Expand Down
5 changes: 5 additions & 0 deletions tests/test_ESRGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -41,6 +42,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(ESRGANArch(), ESRGAN())


def test_ESRGAN_community(snapshot):
file = ModelFile.from_url(
"https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/1x-Anti-Aliasing.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_FBCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -41,6 +42,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(FBCNNArch(), FBCNN())


def test_FBCNN_color(snapshot):
file = ModelFile.from_url(
"https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/fbcnn_color.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_FFTformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand All @@ -32,6 +33,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(FFTformerArch(), FFTformer())


def test_fftformer_GoPro(snapshot):
file = ModelFile.from_url(
"https://github.com/kkkls/FFTformer/releases/download/pretrain_model/fftformer_GoPro.pth",
Expand Down
5 changes: 5 additions & 0 deletions tests/test_FeMaSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -49,6 +50,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(FeMaSRArch(), FeMaSR())


def test_FeMaSR_1x(snapshot):
file = ModelFile.from_url(
"https://github.com/chaofengc/FeMaSR/releases/download/v0.1-pretrain_models/FeMaSR_HRP_model_g.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_GFPGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
ModelFile,
TestImage,
assert_image_inference,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand All @@ -19,6 +20,10 @@ def test_load():
)


def test_train():
assert_training(GFPGANArch(), GFPGAN())


def test_GFPGAN_1_2(snapshot):
file = ModelFile.from_url(
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_GRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -135,6 +136,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(GRLArch(), GRL())


# def test_GRL_dn_grl_tiny_c1(snapshot):
# file = ModelFile.from_url(
# "https://github.com/ofsoundof/GRL-Image-Restoration/releases/download/v1.0.0/dn_grl_tiny_c1.ckpt"
Expand Down
6 changes: 6 additions & 0 deletions tests/test_HAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand Down Expand Up @@ -57,6 +58,11 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
# TODO: fix training
assert_training(HATArch(), HAT())


def test_HAT_S_2x(snapshot):
file = ModelFile.from_url(
"https://github.com/OpenModelDB/model-hub/releases/download/hat/2x-HAT-S_SR.pth",
Expand Down
5 changes: 5 additions & 0 deletions tests/test_HVICIDNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
assert_image_inference,
assert_loads_correctly,
assert_size_requirements,
assert_training,
disallowed_props,
skip_if_unchanged,
)
Expand All @@ -31,6 +32,10 @@ def test_size_requirements():
assert_size_requirements(file.load_model())


def test_train():
assert_training(HVICIDNetArch(), HVICIDNet())


def test_LOLv1(snapshot):
file = ModelFile.from_url(
"https://drive.google.com/file/d/1KbRpPL9-TfBDmZoV6S05pTvMB0sm_OwO/view?usp=sharing",
Expand Down
Loading
Loading