diff --git a/README.md b/README.md index 79a0d2cc..3535174b 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ loaded_model.output_channels # A SizeRequirements object describing the image size requirements of the model # i.e the minimum size, the multiple of size, and whether the model requires a square input -loaded_model.size +loaded_model.size_requirements ``` You can also just use this helper class for inference the same way you would with the `model` directly, so for example you could do `result = loaded_model(img)` and it will automatically call the forward method of the model. It also supports moving it to other devices, so you can call `.to` on it just like you would the direct model. diff --git a/src/spandrel/__helpers/model_descriptor.py b/src/spandrel/__helpers/model_descriptor.py index c8126248..44b46a84 100644 --- a/src/spandrel/__helpers/model_descriptor.py +++ b/src/spandrel/__helpers/model_descriptor.py @@ -66,7 +66,7 @@ def __init__( scale: int, input_channels: int, output_channels: int, - size: SizeRequirements | None = None, + size_requirements: SizeRequirements | None = None, ): self.model: T = model """ @@ -115,7 +115,9 @@ def __init__( The number of output image channels of the model. E.g. 3 for RGB, 1 for grayscale. """ - self.size: SizeRequirements = size or SizeRequirements() + self.size_requirements: SizeRequirements = ( + size_requirements or SizeRequirements() + ) """ Size requirements for the input image. E.g. minimum size. """ @@ -146,7 +148,7 @@ def __init__( supports_bfloat16: bool, input_channels: int, output_channels: int, - size: SizeRequirements | None = None, + size_requirements: SizeRequirements | None = None, ): super().__init__( model, @@ -158,7 +160,7 @@ def __init__( scale=1, input_channels=input_channels, output_channels=output_channels, - size=size, + size_requirements=size_requirements, ) @@ -173,7 +175,7 @@ def __init__( supports_bfloat16: bool, input_channels: int, output_channels: int, - size: SizeRequirements | None = None, + size_requirements: SizeRequirements | None = None, ): super().__init__( model, @@ -185,7 +187,7 @@ def __init__( scale=1, input_channels=input_channels, output_channels=output_channels, - size=size, + size_requirements=size_requirements, ) diff --git a/src/spandrel/architectures/CodeFormer/__init__.py b/src/spandrel/architectures/CodeFormer/__init__.py index 39c0a85c..752b8f89 100644 --- a/src/spandrel/architectures/CodeFormer/__init__.py +++ b/src/spandrel/architectures/CodeFormer/__init__.py @@ -58,5 +58,5 @@ def load(state_dict: StateDict) -> FaceSRModelDescriptor[CodeFormer]: scale=8, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/DAT/__init__.py b/src/spandrel/architectures/DAT/__init__.py index a568fb3b..75a192a5 100644 --- a/src/spandrel/architectures/DAT/__init__.py +++ b/src/spandrel/architectures/DAT/__init__.py @@ -157,5 +157,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[DAT]: scale=scale, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/FBCNN/__init__.py b/src/spandrel/architectures/FBCNN/__init__.py index 72990b39..044c8e97 100644 --- a/src/spandrel/architectures/FBCNN/__init__.py +++ b/src/spandrel/architectures/FBCNN/__init__.py @@ -47,5 +47,5 @@ def load(state_dict: StateDict) -> RestorationModelDescriptor[FBCNN]: supports_bfloat16=True, # TODO input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), # TODO + size_requirements=SizeRequirements(minimum=16), # TODO ) diff --git a/src/spandrel/architectures/GFPGAN/__init__.py b/src/spandrel/architectures/GFPGAN/__init__.py index 8e87c70a..2d8a5172 100644 --- a/src/spandrel/architectures/GFPGAN/__init__.py +++ b/src/spandrel/architectures/GFPGAN/__init__.py @@ -45,5 +45,5 @@ def load(state_dict: StateDict) -> FaceSRModelDescriptor[GFPGANv1Clean]: scale=8, input_channels=3, output_channels=3, - size=SizeRequirements(minimum=512), + size_requirements=SizeRequirements(minimum=512), ) diff --git a/src/spandrel/architectures/HAT/__init__.py b/src/spandrel/architectures/HAT/__init__.py index 633315ce..c2798293 100644 --- a/src/spandrel/architectures/HAT/__init__.py +++ b/src/spandrel/architectures/HAT/__init__.py @@ -181,5 +181,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[HAT]: scale=scale, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/LaMa/__init__.py b/src/spandrel/architectures/LaMa/__init__.py index a2b66fb3..fec6dbb3 100644 --- a/src/spandrel/architectures/LaMa/__init__.py +++ b/src/spandrel/architectures/LaMa/__init__.py @@ -28,5 +28,5 @@ def load(state_dict: StateDict) -> InpaintModelDescriptor[LaMa]: supports_bfloat16=True, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/MAT/__init__.py b/src/spandrel/architectures/MAT/__init__.py index 66e5a100..0f3cee8c 100644 --- a/src/spandrel/architectures/MAT/__init__.py +++ b/src/spandrel/architectures/MAT/__init__.py @@ -26,5 +26,5 @@ def load(state_dict: StateDict) -> InpaintModelDescriptor[MAT]: supports_bfloat16=True, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=512, multiple_of=512, square=True), + size_requirements=SizeRequirements(minimum=512, multiple_of=512, square=True), ) diff --git a/src/spandrel/architectures/OmniSR/__init__.py b/src/spandrel/architectures/OmniSR/__init__.py index 0db22ad5..2c864d00 100644 --- a/src/spandrel/architectures/OmniSR/__init__.py +++ b/src/spandrel/architectures/OmniSR/__init__.py @@ -77,5 +77,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]: scale=scale, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/RestoreFormer/__init__.py b/src/spandrel/architectures/RestoreFormer/__init__.py index 860b5db8..321f7bd9 100644 --- a/src/spandrel/architectures/RestoreFormer/__init__.py +++ b/src/spandrel/architectures/RestoreFormer/__init__.py @@ -25,5 +25,5 @@ def load(state_dict: StateDict) -> FaceSRModelDescriptor[RestoreFormer]: scale=8, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/SCUNet/__init__.py b/src/spandrel/architectures/SCUNet/__init__.py index c622b609..cb7755b0 100644 --- a/src/spandrel/architectures/SCUNet/__init__.py +++ b/src/spandrel/architectures/SCUNet/__init__.py @@ -33,5 +33,5 @@ def load(state_dict: StateDict) -> RestorationModelDescriptor[SCUNet]: supports_bfloat16=True, input_channels=in_nc, output_channels=in_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/SRFormer/__init__.py b/src/spandrel/architectures/SRFormer/__init__.py index 398d3d5d..7490b53c 100644 --- a/src/spandrel/architectures/SRFormer/__init__.py +++ b/src/spandrel/architectures/SRFormer/__init__.py @@ -194,5 +194,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[SRFormer]: scale=scale, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/Swin2SR/__init__.py b/src/spandrel/architectures/Swin2SR/__init__.py index e559f586..a6c56774 100644 --- a/src/spandrel/architectures/Swin2SR/__init__.py +++ b/src/spandrel/architectures/Swin2SR/__init__.py @@ -197,5 +197,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]: scale=scale, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/src/spandrel/architectures/SwinIR/__init__.py b/src/spandrel/architectures/SwinIR/__init__.py index 9f7df070..8a89fccd 100644 --- a/src/spandrel/architectures/SwinIR/__init__.py +++ b/src/spandrel/architectures/SwinIR/__init__.py @@ -206,5 +206,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[SwinIR]: scale=scale, input_channels=in_nc, output_channels=out_nc, - size=SizeRequirements(minimum=16), + size_requirements=SizeRequirements(minimum=16), ) diff --git a/tests/__snapshots__/test_CodeFormer.ambr b/tests/__snapshots__/test_CodeFormer.ambr index f8f12fc9..848df1a7 100644 --- a/tests/__snapshots__/test_CodeFormer.ambr +++ b/tests/__snapshots__/test_CodeFormer.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=8, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/__snapshots__/test_Compact.ambr b/tests/__snapshots__/test_Compact.ambr index f3018295..c32cc2a8 100644 --- a/tests/__snapshots__/test_Compact.ambr +++ b/tests/__snapshots__/test_Compact.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=2, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -20,7 +20,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ diff --git a/tests/__snapshots__/test_ESRGAN.ambr b/tests/__snapshots__/test_ESRGAN.ambr index 2fd4799a..dc5b8b86 100644 --- a/tests/__snapshots__/test_ESRGAN.ambr +++ b/tests/__snapshots__/test_ESRGAN.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -20,7 +20,7 @@ input_channels=3, output_channels=3, scale=2, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -35,7 +35,7 @@ input_channels=3, output_channels=3, scale=1, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -50,7 +50,7 @@ input_channels=3, output_channels=3, scale=2, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -65,7 +65,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -80,7 +80,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -95,7 +95,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -110,7 +110,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -125,7 +125,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ diff --git a/tests/__snapshots__/test_FBCNN.ambr b/tests/__snapshots__/test_FBCNN.ambr index 06cfa4b1..0802763a 100644 --- a/tests/__snapshots__/test_FBCNN.ambr +++ b/tests/__snapshots__/test_FBCNN.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=1, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -18,7 +18,7 @@ input_channels=1, output_channels=1, scale=1, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ diff --git a/tests/__snapshots__/test_GFPGAN.ambr b/tests/__snapshots__/test_GFPGAN.ambr index e8be00fa..c09a1482 100644 --- a/tests/__snapshots__/test_GFPGAN.ambr +++ b/tests/__snapshots__/test_GFPGAN.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=8, - size=SizeRequirements(minimum=512, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=512, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -18,7 +18,7 @@ input_channels=3, output_channels=3, scale=8, - size=SizeRequirements(minimum=512, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=512, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -31,7 +31,7 @@ input_channels=3, output_channels=3, scale=8, - size=SizeRequirements(minimum=512, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=512, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/__snapshots__/test_HAT.ambr b/tests/__snapshots__/test_HAT.ambr index 4076093b..500f1e72 100644 --- a/tests/__snapshots__/test_HAT.ambr +++ b/tests/__snapshots__/test_HAT.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/__snapshots__/test_LaMa.ambr b/tests/__snapshots__/test_LaMa.ambr index e9a369e2..971362fe 100644 --- a/tests/__snapshots__/test_LaMa.ambr +++ b/tests/__snapshots__/test_LaMa.ambr @@ -5,7 +5,7 @@ input_channels=4, output_channels=3, scale=1, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/__snapshots__/test_MAT.ambr b/tests/__snapshots__/test_MAT.ambr index 0f86b8a8..0237b9fe 100644 --- a/tests/__snapshots__/test_MAT.ambr +++ b/tests/__snapshots__/test_MAT.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=1, - size=SizeRequirements(minimum=512, multiple_of=512, square=True), + size_requirements=SizeRequirements(minimum=512, multiple_of=512, square=True), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/__snapshots__/test_OmniSR.ambr b/tests/__snapshots__/test_OmniSR.ambr index 2e3110ed..9bacf011 100644 --- a/tests/__snapshots__/test_OmniSR.ambr +++ b/tests/__snapshots__/test_OmniSR.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=2, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -21,7 +21,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ diff --git a/tests/__snapshots__/test_RestoreFormer.ambr b/tests/__snapshots__/test_RestoreFormer.ambr index aa5755a0..f870c37e 100644 --- a/tests/__snapshots__/test_RestoreFormer.ambr +++ b/tests/__snapshots__/test_RestoreFormer.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=8, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/__snapshots__/test_SCUNet.ambr b/tests/__snapshots__/test_SCUNet.ambr index fa0f0472..6f2a81eb 100644 --- a/tests/__snapshots__/test_SCUNet.ambr +++ b/tests/__snapshots__/test_SCUNet.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=1, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -18,7 +18,7 @@ input_channels=3, output_channels=3, scale=1, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -31,7 +31,7 @@ input_channels=3, output_channels=3, scale=1, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -44,7 +44,7 @@ input_channels=1, output_channels=1, scale=1, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ diff --git a/tests/__snapshots__/test_SwiftSRGAN.ambr b/tests/__snapshots__/test_SwiftSRGAN.ambr index 83705e89..abaf0030 100644 --- a/tests/__snapshots__/test_SwiftSRGAN.ambr +++ b/tests/__snapshots__/test_SwiftSRGAN.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=2, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ @@ -20,7 +20,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=None, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False), supports_bfloat16=True, supports_half=True, tags=list([ diff --git a/tests/__snapshots__/test_Swin2SR.ambr b/tests/__snapshots__/test_Swin2SR.ambr index 84954008..2bab4c11 100644 --- a/tests/__snapshots__/test_Swin2SR.ambr +++ b/tests/__snapshots__/test_Swin2SR.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ diff --git a/tests/__snapshots__/test_SwinIR.ambr b/tests/__snapshots__/test_SwinIR.ambr index 9fc8c43d..7bfa3872 100644 --- a/tests/__snapshots__/test_SwinIR.ambr +++ b/tests/__snapshots__/test_SwinIR.ambr @@ -5,7 +5,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -23,7 +23,7 @@ input_channels=3, output_channels=3, scale=4, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -41,7 +41,7 @@ input_channels=3, output_channels=3, scale=2, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([ @@ -59,7 +59,7 @@ input_channels=3, output_channels=3, scale=2, - size=SizeRequirements(minimum=16, multiple_of=None, square=False), + size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False), supports_bfloat16=True, supports_half=False, tags=list([