From d541966ef5bf02e6d5a5b3718d5884d3519eebc4 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Thu, 11 Jul 2024 10:10:35 +0200 Subject: [PATCH 1/2] Fixed input channels and size req of LaMa --- .../spandrel/architectures/LaMa/__init__.py | 4 ++-- tests/__snapshots__/test_LaMa.ambr | 4 ++-- tests/test_LaMa.py | 15 ++++++++++++++- tests/util.py | 13 +++++++++---- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/libs/spandrel/spandrel/architectures/LaMa/__init__.py b/libs/spandrel/spandrel/architectures/LaMa/__init__.py index b3008936..56abc2d0 100644 --- a/libs/spandrel/spandrel/architectures/LaMa/__init__.py +++ b/libs/spandrel/spandrel/architectures/LaMa/__init__.py @@ -49,7 +49,7 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[LaMa]: tags=[], supports_half=False, supports_bfloat16=True, - input_channels=in_nc, + input_channels=in_nc - 1, output_channels=out_nc, - size_requirements=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16, multiple_of=8), ) diff --git a/tests/__snapshots__/test_LaMa.ambr b/tests/__snapshots__/test_LaMa.ambr index 73c51f03..84d4e48f 100644 --- a/tests/__snapshots__/test_LaMa.ambr +++ b/tests/__snapshots__/test_LaMa.ambr @@ -5,11 +5,11 @@ id='LaMa', name='LaMa', ), - input_channels=4, + input_channels=3, output_channels=3, purpose='Inpainting', scale=1, - size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=8, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/test_LaMa.py b/tests/test_LaMa.py index 2812353d..1f76d048 100644 --- a/tests/test_LaMa.py +++ b/tests/test_LaMa.py @@ -1,6 +1,12 @@ from spandrel.architectures.LaMa import LaMa, LaMaArch -from .util import ModelFile, assert_loads_correctly, disallowed_props, skip_if_unchanged +from .util import ( + ModelFile, + assert_loads_correctly, + assert_size_requirements, + disallowed_props, + skip_if_unchanged, +) skip_if_unchanged(__file__) @@ -15,6 +21,13 @@ def test_load(): ) +def test_size_requirements(): + file = ModelFile.from_url( + "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt" + ) + assert_size_requirements(file.load_model()) + + def test_LaMa(snapshot): file = ModelFile.from_url( "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt" diff --git a/tests/util.py b/tests/util.py index 21a3b566..0e659f38 100644 --- a/tests/util.py +++ b/tests/util.py @@ -581,17 +581,21 @@ def assert_size_requirements( max_size: int = 64, max_candidates: int = 8, ) -> None: - assert isinstance(model, ImageModelDescriptor) - device = get_test_device() def test_size(width: int, height: int) -> None: try: - input_tensor = torch.rand(1, model.input_channels, height, width) + input_tensor = torch.rand( + 1, model.input_channels, height, width, device=device + ) model.to(device).eval() with torch.no_grad(): - output_tensor = model(input_tensor.to(device)) + if isinstance(model, ImageModelDescriptor): + output_tensor = model(input_tensor) + else: + mask = torch.rand(1, 1, height, width, device=device).round_() + output_tensor = model(input_tensor, mask) expected_shape = ( 1, @@ -603,6 +607,7 @@ def test_size(width: int, height: int) -> None: output_tensor.shape == expected_shape ), f"Expected {expected_shape}, but got {output_tensor.shape}" except Exception as e: + print(str(e)) raise AssertionError( f"Failed size requirement test for {width=} {height=}" ) from e From 11c9cf6e8c1530e41698636c22d6fb224fe42636 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Thu, 11 Jul 2024 10:56:24 +0200 Subject: [PATCH 2/2] Remove debug print statement --- tests/util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/util.py b/tests/util.py index 0e659f38..5540c3d3 100644 --- a/tests/util.py +++ b/tests/util.py @@ -607,7 +607,6 @@ def test_size(width: int, height: int) -> None: output_tensor.shape == expected_shape ), f"Expected {expected_shape}, but got {output_tensor.shape}" except Exception as e: - print(str(e)) raise AssertionError( f"Failed size requirement test for {width=} {height=}" ) from e