Skip to content

Commit

Permalink
Rename size -> size_requirements (#7)
Browse files Browse the repository at this point in the history
* Rename `size` -> `size_requirements`

* Rename field in readme

---------

Co-authored-by: Joey Ballentine <[email protected]>
  • Loading branch information
RunDevelopment and joeyballentine authored Nov 18, 2023
1 parent 35dae6f commit 8bad32b
Show file tree
Hide file tree
Showing 29 changed files with 56 additions and 54 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ loaded_model.output_channels

# A SizeRequirements object describing the image size requirements of the model
# i.e the minimum size, the multiple of size, and whether the model requires a square input
loaded_model.size
loaded_model.size_requirements
```

You can also just use this helper class for inference the same way you would with the `model` directly, so for example you could do `result = loaded_model(img)` and it will automatically call the forward method of the model. It also supports moving it to other devices, so you can call `.to` on it just like you would the direct model.
Expand Down
14 changes: 8 additions & 6 deletions src/spandrel/__helpers/model_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
scale: int,
input_channels: int,
output_channels: int,
size: SizeRequirements | None = None,
size_requirements: SizeRequirements | None = None,
):
self.model: T = model
"""
Expand Down Expand Up @@ -115,7 +115,9 @@ def __init__(
The number of output image channels of the model. E.g. 3 for RGB, 1 for grayscale.
"""

self.size: SizeRequirements = size or SizeRequirements()
self.size_requirements: SizeRequirements = (
size_requirements or SizeRequirements()
)
"""
Size requirements for the input image. E.g. minimum size.
"""
Expand Down Expand Up @@ -146,7 +148,7 @@ def __init__(
supports_bfloat16: bool,
input_channels: int,
output_channels: int,
size: SizeRequirements | None = None,
size_requirements: SizeRequirements | None = None,
):
super().__init__(
model,
Expand All @@ -158,7 +160,7 @@ def __init__(
scale=1,
input_channels=input_channels,
output_channels=output_channels,
size=size,
size_requirements=size_requirements,
)


Expand All @@ -173,7 +175,7 @@ def __init__(
supports_bfloat16: bool,
input_channels: int,
output_channels: int,
size: SizeRequirements | None = None,
size_requirements: SizeRequirements | None = None,
):
super().__init__(
model,
Expand All @@ -185,7 +187,7 @@ def __init__(
scale=1,
input_channels=input_channels,
output_channels=output_channels,
size=size,
size_requirements=size_requirements,
)


Expand Down
2 changes: 1 addition & 1 deletion src/spandrel/architectures/CodeFormer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,5 @@ def load(state_dict: StateDict) -> FaceSRModelDescriptor[CodeFormer]:
scale=8,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/DAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[DAT]:
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/FBCNN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ def load(state_dict: StateDict) -> RestorationModelDescriptor[FBCNN]:
supports_bfloat16=True, # TODO
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16), # TODO
size_requirements=SizeRequirements(minimum=16), # TODO
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/GFPGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def load(state_dict: StateDict) -> FaceSRModelDescriptor[GFPGANv1Clean]:
scale=8,
input_channels=3,
output_channels=3,
size=SizeRequirements(minimum=512),
size_requirements=SizeRequirements(minimum=512),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/HAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[HAT]:
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/LaMa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ def load(state_dict: StateDict) -> InpaintModelDescriptor[LaMa]:
supports_bfloat16=True,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/MAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ def load(state_dict: StateDict) -> InpaintModelDescriptor[MAT]:
supports_bfloat16=True,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=512, multiple_of=512, square=True),
size_requirements=SizeRequirements(minimum=512, multiple_of=512, square=True),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/OmniSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[OmniSR]:
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/RestoreFormer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ def load(state_dict: StateDict) -> FaceSRModelDescriptor[RestoreFormer]:
scale=8,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/SCUNet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ def load(state_dict: StateDict) -> RestorationModelDescriptor[SCUNet]:
supports_bfloat16=True,
input_channels=in_nc,
output_channels=in_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/SRFormer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,5 +194,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[SRFormer]:
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/Swin2SR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[Swin2SR]:
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion src/spandrel/architectures/SwinIR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,5 @@ def load(state_dict: StateDict) -> SRModelDescriptor[SwinIR]:
scale=scale,
input_channels=in_nc,
output_channels=out_nc,
size=SizeRequirements(minimum=16),
size_requirements=SizeRequirements(minimum=16),
)
2 changes: 1 addition & 1 deletion tests/__snapshots__/test_CodeFormer.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=8,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand Down
4 changes: 2 additions & 2 deletions tests/__snapshots__/test_Compact.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=2,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -20,7 +20,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand Down
18 changes: 9 additions & 9 deletions tests/__snapshots__/test_ESRGAN.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -20,7 +20,7 @@
input_channels=3,
output_channels=3,
scale=2,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -35,7 +35,7 @@
input_channels=3,
output_channels=3,
scale=1,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -50,7 +50,7 @@
input_channels=3,
output_channels=3,
scale=2,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -65,7 +65,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -80,7 +80,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -95,7 +95,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -110,7 +110,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -125,7 +125,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=None, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=None, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand Down
4 changes: 2 additions & 2 deletions tests/__snapshots__/test_FBCNN.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=1,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -18,7 +18,7 @@
input_channels=1,
output_channels=1,
scale=1,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand Down
6 changes: 3 additions & 3 deletions tests/__snapshots__/test_GFPGAN.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=8,
size=SizeRequirements(minimum=512, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=512, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand All @@ -18,7 +18,7 @@
input_channels=3,
output_channels=3,
scale=8,
size=SizeRequirements(minimum=512, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=512, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand All @@ -31,7 +31,7 @@
input_channels=3,
output_channels=3,
scale=8,
size=SizeRequirements(minimum=512, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=512, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand Down
2 changes: 1 addition & 1 deletion tests/__snapshots__/test_HAT.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand Down
2 changes: 1 addition & 1 deletion tests/__snapshots__/test_LaMa.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=4,
output_channels=3,
scale=1,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand Down
2 changes: 1 addition & 1 deletion tests/__snapshots__/test_MAT.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=1,
size=SizeRequirements(minimum=512, multiple_of=512, square=True),
size_requirements=SizeRequirements(minimum=512, multiple_of=512, square=True),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand Down
4 changes: 2 additions & 2 deletions tests/__snapshots__/test_OmniSR.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=2,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -21,7 +21,7 @@
input_channels=3,
output_channels=3,
scale=4,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand Down
2 changes: 1 addition & 1 deletion tests/__snapshots__/test_RestoreFormer.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=8,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
Expand Down
8 changes: 4 additions & 4 deletions tests/__snapshots__/test_SCUNet.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input_channels=3,
output_channels=3,
scale=1,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -18,7 +18,7 @@
input_channels=3,
output_channels=3,
scale=1,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -31,7 +31,7 @@
input_channels=3,
output_channels=3,
scale=1,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand All @@ -44,7 +44,7 @@
input_channels=1,
output_channels=1,
scale=1,
size=SizeRequirements(minimum=16, multiple_of=None, square=False),
size_requirements=SizeRequirements(minimum=16, multiple_of=None, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
Expand Down
Loading

0 comments on commit 8bad32b

Please sign in to comment.