diff --git a/libs/spandrel/spandrel/architectures/LaMa/__init__.py b/libs/spandrel/spandrel/architectures/LaMa/__init__.py index b3008936..56abc2d0 100644 --- a/libs/spandrel/spandrel/architectures/LaMa/__init__.py +++ b/libs/spandrel/spandrel/architectures/LaMa/__init__.py @@ -49,7 +49,7 @@ def load(self, state_dict: StateDict) -> MaskedImageModelDescriptor[LaMa]: tags=[], supports_half=False, supports_bfloat16=True, - input_channels=in_nc, + input_channels=in_nc - 1, output_channels=out_nc, - size_requirements=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16, multiple_of=8), ) diff --git a/tests/__snapshots__/test_LaMa.ambr b/tests/__snapshots__/test_LaMa.ambr index 73c51f03..84d4e48f 100644 --- a/tests/__snapshots__/test_LaMa.ambr +++ b/tests/__snapshots__/test_LaMa.ambr @@ -5,11 +5,11 @@ id='LaMa', name='LaMa', ), - input_channels=4, + input_channels=3, output_channels=3, purpose='Inpainting', scale=1, - size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=8, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/test_LaMa.py b/tests/test_LaMa.py index 2812353d..1f76d048 100644 --- a/tests/test_LaMa.py +++ b/tests/test_LaMa.py @@ -1,6 +1,12 @@ from spandrel.architectures.LaMa import LaMa, LaMaArch -from .util import ModelFile, assert_loads_correctly, disallowed_props, skip_if_unchanged +from .util import ( + ModelFile, + assert_loads_correctly, + assert_size_requirements, + disallowed_props, + skip_if_unchanged, +) skip_if_unchanged(__file__) @@ -15,6 +21,13 @@ def test_load(): ) +def test_size_requirements(): + file = ModelFile.from_url( + "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt" + ) + assert_size_requirements(file.load_model()) + + def test_LaMa(snapshot): file = ModelFile.from_url( "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt" diff --git a/tests/util.py b/tests/util.py index 21a3b566..5540c3d3 100644 --- a/tests/util.py +++ b/tests/util.py @@ -581,17 +581,21 @@ def assert_size_requirements( max_size: int = 64, max_candidates: int = 8, ) -> None: - assert isinstance(model, ImageModelDescriptor) - device = get_test_device() def test_size(width: int, height: int) -> None: try: - input_tensor = torch.rand(1, model.input_channels, height, width) + input_tensor = torch.rand( + 1, model.input_channels, height, width, device=device + ) model.to(device).eval() with torch.no_grad(): - output_tensor = model(input_tensor.to(device)) + if isinstance(model, ImageModelDescriptor): + output_tensor = model(input_tensor) + else: + mask = torch.rand(1, 1, height, width, device=device).round_() + output_tensor = model(input_tensor, mask) expected_shape = ( 1,