diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4461cd65..10d19a10 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -53,6 +53,7 @@ jobs: python-version: '3.9' cache: 'pip' - run: pip install .[typecheck] + - run: pip install .[test] - run: pyright src/ - run: pyright tests/ - run: pyright scripts/ diff --git a/pyproject.toml b/pyproject.toml index 4b04a413..83db3748 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dynamic = ["version"] build = ["setuptools>=46.4.0", "wheel", "build", "twine"] lint = ["ruff==0.1.4"] typecheck = ["pyright==1.1.335"] -test = ["pytest==7.4.0", "syrupy==4.6.0"] +test = ["pytest==7.4.0", "syrupy==4.6.0", "opencv-python==4.8.1.78"] # [tool.setuptools.dynamic] # version = { attr = "spandrel.VERSION" } @@ -73,4 +73,4 @@ ignore = [ "**/tests/**/*" = ["N802"] [tool.pytest.ini_options] -filterwarnings = ["ignore::DeprecationWarning"] +filterwarnings = ["ignore::DeprecationWarning", "ignore::UserWarning"] diff --git a/src/spandrel/architectures/ESRGAN/arch/RRDB.py b/src/spandrel/architectures/ESRGAN/arch/RRDB.py index ffbc9187..daf70662 100644 --- a/src/spandrel/architectures/ESRGAN/arch/RRDB.py +++ b/src/spandrel/architectures/ESRGAN/arch/RRDB.py @@ -45,6 +45,7 @@ def __init__( super().__init__() self.shuffle_factor = shuffle_factor + self.scale = scale upsample_block = { "upconv": B.upconv_block, diff --git a/src/spandrel/architectures/SwinIR/__init__.py b/src/spandrel/architectures/SwinIR/__init__.py index e9556e25..9f7df070 100644 --- a/src/spandrel/architectures/SwinIR/__init__.py +++ b/src/spandrel/architectures/SwinIR/__init__.py @@ -178,6 +178,7 @@ def load(state_dict: StateDict) -> SRModelDescriptor[SwinIR]: num_feat=num_feat, num_in_ch=num_in_ch, num_out_ch=num_out_ch, + start_unshuffle=start_unshuffle, ) head_length = len(depths) # type: ignore diff --git a/src/spandrel/architectures/SwinIR/arch/SwinIR.py b/src/spandrel/architectures/SwinIR/arch/SwinIR.py index 2c1ecd92..de156308 100644 --- a/src/spandrel/architectures/SwinIR/arch/SwinIR.py +++ b/src/spandrel/architectures/SwinIR/arch/SwinIR.py @@ -835,9 +835,11 @@ def __init__( img_range=1.0, upsampler="", resi_connection="1conv", + start_unshuffle=1, **kwargs, ): super().__init__() + self.start_unshuffle = start_unshuffle num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 diff --git a/tests/images/inputs/16x16.png b/tests/images/inputs/16x16.png new file mode 100644 index 00000000..743a418a Binary files /dev/null and b/tests/images/inputs/16x16.png differ diff --git a/tests/images/inputs/32x32.png b/tests/images/inputs/32x32.png new file mode 100644 index 00000000..a8150183 Binary files /dev/null and b/tests/images/inputs/32x32.png differ diff --git a/tests/images/inputs/64x64.png b/tests/images/inputs/64x64.png new file mode 100644 index 00000000..c99fc293 Binary files /dev/null and b/tests/images/inputs/64x64.png differ diff --git a/tests/images/outputs/16x16/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png b/tests/images/outputs/16x16/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png new file mode 100644 index 00000000..c765cf86 Binary files /dev/null and b/tests/images/outputs/16x16/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png differ diff --git a/tests/images/outputs/16x16/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png b/tests/images/outputs/16x16/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png new file mode 100644 index 00000000..b9935486 Binary files /dev/null and b/tests/images/outputs/16x16/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png differ diff --git a/tests/images/outputs/16x16/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png b/tests/images/outputs/16x16/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png new file mode 100644 index 00000000..94d707f7 Binary files /dev/null and b/tests/images/outputs/16x16/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png differ diff --git a/tests/images/outputs/16x16/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png b/tests/images/outputs/16x16/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png new file mode 100644 index 00000000..77f9fca6 Binary files /dev/null and b/tests/images/outputs/16x16/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png differ diff --git a/tests/images/outputs/16x16/1x-Anti-Aliasing.png b/tests/images/outputs/16x16/1x-Anti-Aliasing.png new file mode 100644 index 00000000..d45a41a4 Binary files /dev/null and b/tests/images/outputs/16x16/1x-Anti-Aliasing.png differ diff --git a/tests/images/outputs/16x16/2x-AniScale.png b/tests/images/outputs/16x16/2x-AniScale.png new file mode 100644 index 00000000..b95b23b4 Binary files /dev/null and b/tests/images/outputs/16x16/2x-AniScale.png differ diff --git a/tests/images/outputs/16x16/2xHFA2kAVCOmniSR.png b/tests/images/outputs/16x16/2xHFA2kAVCOmniSR.png new file mode 100644 index 00000000..2dec4de8 Binary files /dev/null and b/tests/images/outputs/16x16/2xHFA2kAVCOmniSR.png differ diff --git a/tests/images/outputs/16x16/4x-ardo.png b/tests/images/outputs/16x16/4x-ardo.png new file mode 100644 index 00000000..a85befb2 Binary files /dev/null and b/tests/images/outputs/16x16/4x-ardo.png differ diff --git a/tests/images/outputs/16x16/4xLexicaHAT.png b/tests/images/outputs/16x16/4xLexicaHAT.png new file mode 100644 index 00000000..371aea10 Binary files /dev/null and b/tests/images/outputs/16x16/4xLexicaHAT.png differ diff --git a/tests/images/outputs/16x16/BSRGAN.png b/tests/images/outputs/16x16/BSRGAN.png new file mode 100644 index 00000000..de2f3906 Binary files /dev/null and b/tests/images/outputs/16x16/BSRGAN.png differ diff --git a/tests/images/outputs/16x16/BSRGANx2.png b/tests/images/outputs/16x16/BSRGANx2.png new file mode 100644 index 00000000..45652023 Binary files /dev/null and b/tests/images/outputs/16x16/BSRGANx2.png differ diff --git a/tests/images/outputs/16x16/RealESRGAN_x2plus.png b/tests/images/outputs/16x16/RealESRGAN_x2plus.png new file mode 100644 index 00000000..d67c6122 Binary files /dev/null and b/tests/images/outputs/16x16/RealESRGAN_x2plus.png differ diff --git a/tests/images/outputs/16x16/RealESRGAN_x4plus.png b/tests/images/outputs/16x16/RealESRGAN_x4plus.png new file mode 100644 index 00000000..7612c826 Binary files /dev/null and b/tests/images/outputs/16x16/RealESRGAN_x4plus.png differ diff --git a/tests/images/outputs/16x16/RealESRGAN_x4plus_anime_6B.png b/tests/images/outputs/16x16/RealESRGAN_x4plus_anime_6B.png new file mode 100644 index 00000000..58500138 Binary files /dev/null and b/tests/images/outputs/16x16/RealESRGAN_x4plus_anime_6B.png differ diff --git a/tests/images/outputs/16x16/RealESRNet_x4plus.png b/tests/images/outputs/16x16/RealESRNet_x4plus.png new file mode 100644 index 00000000..37c40e0b Binary files /dev/null and b/tests/images/outputs/16x16/RealESRNet_x4plus.png differ diff --git a/tests/images/outputs/16x16/RealSR_DPED.png b/tests/images/outputs/16x16/RealSR_DPED.png new file mode 100644 index 00000000..f4e0cd16 Binary files /dev/null and b/tests/images/outputs/16x16/RealSR_DPED.png differ diff --git a/tests/images/outputs/16x16/RealSR_JPEG.png b/tests/images/outputs/16x16/RealSR_JPEG.png new file mode 100644 index 00000000..eff14fc3 Binary files /dev/null and b/tests/images/outputs/16x16/RealSR_JPEG.png differ diff --git a/tests/images/outputs/16x16/Swin2SR_ClassicalSR_X4_64.png b/tests/images/outputs/16x16/Swin2SR_ClassicalSR_X4_64.png new file mode 100644 index 00000000..2304a893 Binary files /dev/null and b/tests/images/outputs/16x16/Swin2SR_ClassicalSR_X4_64.png differ diff --git a/tests/images/outputs/16x16/realesr-general-x4v3.png b/tests/images/outputs/16x16/realesr-general-x4v3.png new file mode 100644 index 00000000..ec47d119 Binary files /dev/null and b/tests/images/outputs/16x16/realesr-general-x4v3.png differ diff --git a/tests/images/outputs/16x16/swift_srgan_2x.png b/tests/images/outputs/16x16/swift_srgan_2x.png new file mode 100644 index 00000000..c1ec29b6 Binary files /dev/null and b/tests/images/outputs/16x16/swift_srgan_2x.png differ diff --git a/tests/images/outputs/16x16/swift_srgan_4x.png b/tests/images/outputs/16x16/swift_srgan_4x.png new file mode 100644 index 00000000..b3173e85 Binary files /dev/null and b/tests/images/outputs/16x16/swift_srgan_4x.png differ diff --git a/tests/images/outputs/32x32/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png b/tests/images/outputs/32x32/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png new file mode 100644 index 00000000..6d327946 Binary files /dev/null and b/tests/images/outputs/32x32/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png differ diff --git a/tests/images/outputs/32x32/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png b/tests/images/outputs/32x32/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png new file mode 100644 index 00000000..587b821c Binary files /dev/null and b/tests/images/outputs/32x32/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png differ diff --git a/tests/images/outputs/32x32/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png b/tests/images/outputs/32x32/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png new file mode 100644 index 00000000..eee00e49 Binary files /dev/null and b/tests/images/outputs/32x32/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png differ diff --git a/tests/images/outputs/32x32/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png b/tests/images/outputs/32x32/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png new file mode 100644 index 00000000..31707453 Binary files /dev/null and b/tests/images/outputs/32x32/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png differ diff --git a/tests/images/outputs/32x32/1x-Anti-Aliasing.png b/tests/images/outputs/32x32/1x-Anti-Aliasing.png new file mode 100644 index 00000000..56b3e449 Binary files /dev/null and b/tests/images/outputs/32x32/1x-Anti-Aliasing.png differ diff --git a/tests/images/outputs/32x32/2x-AniScale.png b/tests/images/outputs/32x32/2x-AniScale.png new file mode 100644 index 00000000..bb58536d Binary files /dev/null and b/tests/images/outputs/32x32/2x-AniScale.png differ diff --git a/tests/images/outputs/32x32/2xHFA2kAVCOmniSR.png b/tests/images/outputs/32x32/2xHFA2kAVCOmniSR.png new file mode 100644 index 00000000..7abe69b1 Binary files /dev/null and b/tests/images/outputs/32x32/2xHFA2kAVCOmniSR.png differ diff --git a/tests/images/outputs/32x32/4x-ardo.png b/tests/images/outputs/32x32/4x-ardo.png new file mode 100644 index 00000000..4e65cd8b Binary files /dev/null and b/tests/images/outputs/32x32/4x-ardo.png differ diff --git a/tests/images/outputs/32x32/4xLexicaHAT.png b/tests/images/outputs/32x32/4xLexicaHAT.png new file mode 100644 index 00000000..3611b8c3 Binary files /dev/null and b/tests/images/outputs/32x32/4xLexicaHAT.png differ diff --git a/tests/images/outputs/32x32/BSRGAN.png b/tests/images/outputs/32x32/BSRGAN.png new file mode 100644 index 00000000..8d15996e Binary files /dev/null and b/tests/images/outputs/32x32/BSRGAN.png differ diff --git a/tests/images/outputs/32x32/BSRGANx2.png b/tests/images/outputs/32x32/BSRGANx2.png new file mode 100644 index 00000000..27ae4e38 Binary files /dev/null and b/tests/images/outputs/32x32/BSRGANx2.png differ diff --git a/tests/images/outputs/32x32/RealESRGAN_x2plus.png b/tests/images/outputs/32x32/RealESRGAN_x2plus.png new file mode 100644 index 00000000..fdd00463 Binary files /dev/null and b/tests/images/outputs/32x32/RealESRGAN_x2plus.png differ diff --git a/tests/images/outputs/32x32/RealESRGAN_x4plus.png b/tests/images/outputs/32x32/RealESRGAN_x4plus.png new file mode 100644 index 00000000..00e26d6e Binary files /dev/null and b/tests/images/outputs/32x32/RealESRGAN_x4plus.png differ diff --git a/tests/images/outputs/32x32/RealESRGAN_x4plus_anime_6B.png b/tests/images/outputs/32x32/RealESRGAN_x4plus_anime_6B.png new file mode 100644 index 00000000..91a0ae16 Binary files /dev/null and b/tests/images/outputs/32x32/RealESRGAN_x4plus_anime_6B.png differ diff --git a/tests/images/outputs/32x32/RealESRNet_x4plus.png b/tests/images/outputs/32x32/RealESRNet_x4plus.png new file mode 100644 index 00000000..2eda221a Binary files /dev/null and b/tests/images/outputs/32x32/RealESRNet_x4plus.png differ diff --git a/tests/images/outputs/32x32/RealSR_DPED.png b/tests/images/outputs/32x32/RealSR_DPED.png new file mode 100644 index 00000000..244d6404 Binary files /dev/null and b/tests/images/outputs/32x32/RealSR_DPED.png differ diff --git a/tests/images/outputs/32x32/RealSR_JPEG.png b/tests/images/outputs/32x32/RealSR_JPEG.png new file mode 100644 index 00000000..19572ff7 Binary files /dev/null and b/tests/images/outputs/32x32/RealSR_JPEG.png differ diff --git a/tests/images/outputs/32x32/Swin2SR_ClassicalSR_X4_64.png b/tests/images/outputs/32x32/Swin2SR_ClassicalSR_X4_64.png new file mode 100644 index 00000000..a3cb4f4b Binary files /dev/null and b/tests/images/outputs/32x32/Swin2SR_ClassicalSR_X4_64.png differ diff --git a/tests/images/outputs/32x32/realesr-general-x4v3.png b/tests/images/outputs/32x32/realesr-general-x4v3.png new file mode 100644 index 00000000..167e4660 Binary files /dev/null and b/tests/images/outputs/32x32/realesr-general-x4v3.png differ diff --git a/tests/images/outputs/32x32/swift_srgan_2x.png b/tests/images/outputs/32x32/swift_srgan_2x.png new file mode 100644 index 00000000..e0c27c17 Binary files /dev/null and b/tests/images/outputs/32x32/swift_srgan_2x.png differ diff --git a/tests/images/outputs/32x32/swift_srgan_4x.png b/tests/images/outputs/32x32/swift_srgan_4x.png new file mode 100644 index 00000000..9fd15edd Binary files /dev/null and b/tests/images/outputs/32x32/swift_srgan_4x.png differ diff --git a/tests/images/outputs/64x64/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png b/tests/images/outputs/64x64/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png new file mode 100644 index 00000000..ccc8b545 Binary files /dev/null and b/tests/images/outputs/64x64/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.png differ diff --git a/tests/images/outputs/64x64/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png b/tests/images/outputs/64x64/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png new file mode 100644 index 00000000..c1536748 Binary files /dev/null and b/tests/images/outputs/64x64/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.png differ diff --git a/tests/images/outputs/64x64/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png b/tests/images/outputs/64x64/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png new file mode 100644 index 00000000..c9123ec9 Binary files /dev/null and b/tests/images/outputs/64x64/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.png differ diff --git a/tests/images/outputs/64x64/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png b/tests/images/outputs/64x64/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png new file mode 100644 index 00000000..8dc216ec Binary files /dev/null and b/tests/images/outputs/64x64/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.png differ diff --git a/tests/images/outputs/64x64/1x-Anti-Aliasing.png b/tests/images/outputs/64x64/1x-Anti-Aliasing.png new file mode 100644 index 00000000..e47453f4 Binary files /dev/null and b/tests/images/outputs/64x64/1x-Anti-Aliasing.png differ diff --git a/tests/images/outputs/64x64/2x-AniScale.png b/tests/images/outputs/64x64/2x-AniScale.png new file mode 100644 index 00000000..2d72357b Binary files /dev/null and b/tests/images/outputs/64x64/2x-AniScale.png differ diff --git a/tests/images/outputs/64x64/2xHFA2kAVCOmniSR.png b/tests/images/outputs/64x64/2xHFA2kAVCOmniSR.png new file mode 100644 index 00000000..2f58190d Binary files /dev/null and b/tests/images/outputs/64x64/2xHFA2kAVCOmniSR.png differ diff --git a/tests/images/outputs/64x64/4x-ardo.png b/tests/images/outputs/64x64/4x-ardo.png new file mode 100644 index 00000000..f095eeff Binary files /dev/null and b/tests/images/outputs/64x64/4x-ardo.png differ diff --git a/tests/images/outputs/64x64/4xLexicaHAT.png b/tests/images/outputs/64x64/4xLexicaHAT.png new file mode 100644 index 00000000..2cddd2b1 Binary files /dev/null and b/tests/images/outputs/64x64/4xLexicaHAT.png differ diff --git a/tests/images/outputs/64x64/BSRGAN.png b/tests/images/outputs/64x64/BSRGAN.png new file mode 100644 index 00000000..d616a98c Binary files /dev/null and b/tests/images/outputs/64x64/BSRGAN.png differ diff --git a/tests/images/outputs/64x64/BSRGANx2.png b/tests/images/outputs/64x64/BSRGANx2.png new file mode 100644 index 00000000..5ecca5b4 Binary files /dev/null and b/tests/images/outputs/64x64/BSRGANx2.png differ diff --git a/tests/images/outputs/64x64/RealESRGAN_x2plus.png b/tests/images/outputs/64x64/RealESRGAN_x2plus.png new file mode 100644 index 00000000..5cfb8724 Binary files /dev/null and b/tests/images/outputs/64x64/RealESRGAN_x2plus.png differ diff --git a/tests/images/outputs/64x64/RealESRGAN_x4plus.png b/tests/images/outputs/64x64/RealESRGAN_x4plus.png new file mode 100644 index 00000000..e3dfc40e Binary files /dev/null and b/tests/images/outputs/64x64/RealESRGAN_x4plus.png differ diff --git a/tests/images/outputs/64x64/RealESRGAN_x4plus_anime_6B.png b/tests/images/outputs/64x64/RealESRGAN_x4plus_anime_6B.png new file mode 100644 index 00000000..f04bb848 Binary files /dev/null and b/tests/images/outputs/64x64/RealESRGAN_x4plus_anime_6B.png differ diff --git a/tests/images/outputs/64x64/RealESRNet_x4plus.png b/tests/images/outputs/64x64/RealESRNet_x4plus.png new file mode 100644 index 00000000..db5860ed Binary files /dev/null and b/tests/images/outputs/64x64/RealESRNet_x4plus.png differ diff --git a/tests/images/outputs/64x64/RealSR_DPED.png b/tests/images/outputs/64x64/RealSR_DPED.png new file mode 100644 index 00000000..50e55876 Binary files /dev/null and b/tests/images/outputs/64x64/RealSR_DPED.png differ diff --git a/tests/images/outputs/64x64/RealSR_JPEG.png b/tests/images/outputs/64x64/RealSR_JPEG.png new file mode 100644 index 00000000..5f2d60e8 Binary files /dev/null and b/tests/images/outputs/64x64/RealSR_JPEG.png differ diff --git a/tests/images/outputs/64x64/Swin2SR_ClassicalSR_X4_64.png b/tests/images/outputs/64x64/Swin2SR_ClassicalSR_X4_64.png new file mode 100644 index 00000000..c433785c Binary files /dev/null and b/tests/images/outputs/64x64/Swin2SR_ClassicalSR_X4_64.png differ diff --git a/tests/images/outputs/64x64/realesr-general-x4v3.png b/tests/images/outputs/64x64/realesr-general-x4v3.png new file mode 100644 index 00000000..d0b40dc0 Binary files /dev/null and b/tests/images/outputs/64x64/realesr-general-x4v3.png differ diff --git a/tests/images/outputs/64x64/swift_srgan_2x.png b/tests/images/outputs/64x64/swift_srgan_2x.png new file mode 100644 index 00000000..e94a9d25 Binary files /dev/null and b/tests/images/outputs/64x64/swift_srgan_2x.png differ diff --git a/tests/images/outputs/64x64/swift_srgan_4x.png b/tests/images/outputs/64x64/swift_srgan_4x.png new file mode 100644 index 00000000..8011ce0d Binary files /dev/null and b/tests/images/outputs/64x64/swift_srgan_4x.png differ diff --git a/tests/test_Compact.py b/tests/test_Compact.py index c38ce0bd..25c51108 100644 --- a/tests/test_Compact.py +++ b/tests/test_Compact.py @@ -1,7 +1,7 @@ from spandrel import ModelLoader from spandrel.architectures.Compact import SRVGGNetCompact -from .util import ModelFile, disallowed_props +from .util import ModelFile, TestImage, assert_image_inference, disallowed_props def test_Compact_realesr_general_x4v3(snapshot): @@ -11,6 +11,11 @@ def test_Compact_realesr_general_x4v3(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SRVGGNetCompact) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_Compact_community(snapshot): @@ -20,3 +25,8 @@ def test_Compact_community(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SRVGGNetCompact) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) diff --git a/tests/test_ESRGAN.py b/tests/test_ESRGAN.py index 2ce71402..0e87ce5f 100644 --- a/tests/test_ESRGAN.py +++ b/tests/test_ESRGAN.py @@ -1,7 +1,7 @@ from spandrel import ModelLoader from spandrel.architectures.ESRGAN import RRDBNet -from .util import ModelFile, disallowed_props +from .util import ModelFile, TestImage, assert_image_inference, disallowed_props def test_ESRGAN_community(snapshot): @@ -11,6 +11,11 @@ def test_ESRGAN_community(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_BSRGAN(snapshot): @@ -20,6 +25,11 @@ def test_BSRGAN(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_BSRGAN_2x(snapshot): @@ -29,6 +39,11 @@ def test_BSRGAN_2x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_RealSR_DPED(snapshot): @@ -38,6 +53,11 @@ def test_RealSR_DPED(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_RealSR_JPEG(snapshot): @@ -47,6 +67,11 @@ def test_RealSR_JPEG(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_RealESRGAN_x4plus(snapshot): @@ -56,6 +81,11 @@ def test_RealESRGAN_x4plus(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_RealESRGAN_x2plus(snapshot): @@ -65,6 +95,11 @@ def test_RealESRGAN_x2plus(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_RealESRGAN_x4plus_anime_6B(snapshot): @@ -74,6 +109,11 @@ def test_RealESRGAN_x4plus_anime_6B(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_RealESRNet_x4plus(snapshot): @@ -83,3 +123,8 @@ def test_RealESRNet_x4plus(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, RRDBNet) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) diff --git a/tests/test_HAT.py b/tests/test_HAT.py index fc7df015..3891c9a1 100644 --- a/tests/test_HAT.py +++ b/tests/test_HAT.py @@ -1,7 +1,7 @@ from spandrel import ModelLoader from spandrel.architectures.HAT import HAT -from .util import ModelFile, disallowed_props +from .util import ModelFile, TestImage, assert_image_inference, disallowed_props def test_HAT_community1(snapshot): @@ -11,6 +11,11 @@ def test_HAT_community1(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, HAT) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) # TODO: We don't support HAT-S models yet diff --git a/tests/test_OmniSR.py b/tests/test_OmniSR.py index 2fc28c57..4c5d2ca2 100644 --- a/tests/test_OmniSR.py +++ b/tests/test_OmniSR.py @@ -1,7 +1,7 @@ from spandrel import ModelLoader from spandrel.architectures.OmniSR import OmniSR -from .util import ModelFile, disallowed_props +from .util import ModelFile, TestImage, assert_image_inference, disallowed_props def test_OmniSR_community1(snapshot): @@ -11,6 +11,11 @@ def test_OmniSR_community1(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, OmniSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_OmniSR_community2(snapshot): @@ -20,3 +25,8 @@ def test_OmniSR_community2(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, OmniSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) diff --git a/tests/test_SwiftSRGAN.py b/tests/test_SwiftSRGAN.py index fd895beb..a0ac35e6 100644 --- a/tests/test_SwiftSRGAN.py +++ b/tests/test_SwiftSRGAN.py @@ -1,7 +1,7 @@ from spandrel import ModelLoader from spandrel.architectures.SwiftSRGAN import SwiftSRGAN -from .util import ModelFile, disallowed_props +from .util import ModelFile, TestImage, assert_image_inference, disallowed_props def test_SwiftSRGan_2x(snapshot): @@ -11,6 +11,11 @@ def test_SwiftSRGan_2x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SwiftSRGAN) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_SwiftSRGan_4x(snapshot): @@ -20,3 +25,8 @@ def test_SwiftSRGan_4x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SwiftSRGAN) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) diff --git a/tests/test_Swin2SR.py b/tests/test_Swin2SR.py index 380676d8..25d26aef 100644 --- a/tests/test_Swin2SR.py +++ b/tests/test_Swin2SR.py @@ -1,7 +1,7 @@ from spandrel import ModelLoader from spandrel.architectures.Swin2SR import Swin2SR -from .util import ModelFile, disallowed_props +from .util import ModelFile, TestImage, assert_image_inference, disallowed_props def test_Swin2SR_4x(snapshot): @@ -11,3 +11,8 @@ def test_Swin2SR_4x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, Swin2SR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) diff --git a/tests/test_SwinIR.py b/tests/test_SwinIR.py index ba085e78..66f23c11 100644 --- a/tests/test_SwinIR.py +++ b/tests/test_SwinIR.py @@ -1,7 +1,7 @@ from spandrel import ModelLoader from spandrel.architectures.SwinIR import SwinIR -from .util import ModelFile, disallowed_props +from .util import ModelFile, TestImage, assert_image_inference, disallowed_props def test_SwinIR_M_s64w8_2x(snapshot): @@ -11,6 +11,11 @@ def test_SwinIR_M_s64w8_2x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SwinIR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_SwinIR_M_s48w8_4x(snapshot): @@ -20,6 +25,11 @@ def test_SwinIR_M_s48w8_4x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SwinIR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_SwinIR_S_2x(snapshot): @@ -29,6 +39,11 @@ def test_SwinIR_S_2x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SwinIR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) def test_SwinIR_L_4x(snapshot): @@ -38,3 +53,8 @@ def test_SwinIR_L_4x(snapshot): model = ModelLoader().load_from_file(file.path) assert model == snapshot(exclude=disallowed_props) assert isinstance(model.model, SwinIR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) diff --git a/tests/util.py b/tests/util.py index 99e0594d..0aeca667 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,14 +1,22 @@ from __future__ import annotations +import sys from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum from pathlib import Path from urllib.parse import unquote, urlparse from urllib.request import urlretrieve -from syrupy.filters import props # type: ignore +import cv2 +import numpy as np +import torch +from syrupy.filters import props + +from spandrel.__helpers.model_descriptor import ModelDescriptor MODEL_DIR = Path("./tests/models/") +IMAGE_DIR = Path("./tests/images/") def download_model(url: str, name: str | None = None) -> str: @@ -56,3 +64,105 @@ def expect_error(snapshot): if not did_error: raise AssertionError("Expected an error, but none was raised") + + +def read_image(path: str | Path) -> np.ndarray: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + return image + + +def write_image(path: str | Path, image: np.ndarray): + cv2.imwrite(str(path), image) + + +def image_to_tensor(img: np.ndarray) -> torch.Tensor: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = img.astype(np.float32) / 255.0 + img = np.transpose(img, (2, 0, 1)) + tensor = torch.from_numpy(img) + return tensor.unsqueeze(0) + + +def tensor_to_image(tensor: torch.Tensor) -> np.ndarray: + image = tensor.cpu().squeeze().numpy() + image = np.transpose(image, (1, 2, 0)) + image = np.clip((image * 255.0).round(), 0, 255) + image = image.astype(np.uint8) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + return image + + +def image_inference_tensor( + model: torch.nn.Module, tensor: torch.Tensor +) -> torch.Tensor: + model.eval() + with torch.no_grad(): + return model(tensor) + + +def image_inference(model: torch.nn.Module, image: np.ndarray) -> np.ndarray: + return tensor_to_image(image_inference_tensor(model, image_to_tensor(image))) + + +def get_h_w_c(image: np.ndarray) -> tuple[int, int, int]: + if len(image.shape) == 2: + return image.shape[0], image.shape[1], 1 + return image.shape[0], image.shape[1], image.shape[2] + + +class TestImage(Enum): + SR_16 = "16x16.png" + SR_32 = "32x32.png" + SR_64 = "64x64.png" + + +def assert_image_inference( + model_file: ModelFile, + model: ModelDescriptor, + test_images: list[TestImage], +): + test_images.sort(key=lambda image: image.value) + + update_mode = "--snapshot-update" in sys.argv + + for test_image in test_images: + path = IMAGE_DIR / "inputs" / test_image.value + + image = read_image(path) + image_h, image_w, image_c = get_h_w_c(image) + + assert ( + image_c == model.input_channels + ), f"Expected the input image '{test_image.value}' to have {model.input_channels} channels, but it had {image_c} channels." + + output = image_inference(model.model, image) + output_h, output_w, output_c = get_h_w_c(output) + + assert ( + output_c == model.output_channels + ), f"Expected the output of '{test_image.value}' to have {model.output_channels} channels, but it had {output_c} channels." + assert ( + output_w == image_w * model.scale and output_h == image_h * model.scale + ), f"Expected the input image '{test_image.value}' {image_w}x{image_h} to be scaled {model.scale}x, but the output was {output_w}x{output_h}." + + expected_path = ( + IMAGE_DIR / "outputs" / path.stem / f"{model_file.path.stem}.png" + ) + + if update_mode and not expected_path.exists(): + expected_path.parent.mkdir(exist_ok=True, parents=True) + write_image(expected_path, output) + continue + + assert expected_path.exists(), f"Expected {expected_path} to exist." + expected = read_image(expected_path) + + # Assert that the images are the same within a certain tolerance + # The CI for some reason has a bit of FP precision loss compared to my local machine + # Therefore, a tolerance of 1 is fine enough. + close_enough = np.allclose(output, expected, atol=1) + if update_mode and not close_enough: + write_image(expected_path, output) + continue + + assert close_enough, f"Failed on {test_image.value}"