Skip to content

Commit

Permalink
Add ruff's annotations rules + update ruff (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine authored Nov 18, 2023
1 parent 945018f commit f5c2546
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 22 deletions.
24 changes: 13 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dynamic = ["version"]

[project.optional-dependencies]
build = ["setuptools>=46.4.0", "wheel", "build", "twine"]
lint = ["ruff==0.1.4"]
lint = ["ruff==0.1.6"]
typecheck = ["pyright==1.1.335"]
test = ["pytest==7.4.0", "syrupy==4.6.0", "opencv-python==4.8.1.78"]

Expand All @@ -53,13 +53,15 @@ src = ["src"]
# extend-exclude = ["src/architectures/**"]

extend-select = [
"UP", # pyupgrade
"E", # pycodestyle
"W", # pycodestyle
# "F", # pyflakes
"I", # isort
"FA", # flake8-future-annotations
"N", # pep8-naming
"UP", # pyupgrade
"E", # pycodestyle
"W", # pycodestyle
"F", # pyflakes
"I", # isort
"FA", # flake8-future-annotations
"N", # pep8-naming
"ANN001",
"ANN002",
]
ignore = [
"E501", # Line too long
Expand All @@ -68,9 +70,9 @@ ignore = [
]

[tool.ruff.lint.per-file-ignores]
"**/arch/**/*" = ["N"]
"**/__arch_helpers/**/*" = ["N"]
"**/tests/**/*" = ["N802"]
"**/arch/**/*" = ["N", "ANN"]
"**/__arch_helpers/**/*" = ["N", "ANN"]
"**/tests/**/*" = ["N802", "ANN"]

[tool.pytest.ini_options]
filterwarnings = ["ignore::DeprecationWarning", "ignore::UserWarning"]
2 changes: 1 addition & 1 deletion scripts/dump_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def load_state(file: str) -> State:
return state_dict


def indent(lines: list[str], indentation=" "):
def indent(lines: list[str], indentation: str = " "):
def do(line: str) -> str:
return "\n".join(indentation + s for s in line.splitlines())

Expand Down
2 changes: 1 addition & 1 deletion src/spandrel/__helpers/unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
def find_class(self, module: str, name: str):
# Only allow required classes to load state dict
if (module, name) not in safe_list:
raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden")
Expand Down
8 changes: 5 additions & 3 deletions src/spandrel/architectures/Compact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import math

from ...__helpers.model_descriptor import SRModelDescriptor, StateDict
Expand All @@ -8,15 +10,15 @@ def _get_num_conv(highest_num: int) -> int:
return (highest_num - 2) // 2


def _get_num_feats(state, weight_keys) -> int:
def _get_num_feats(state: StateDict, weight_keys: list[str]) -> int:
return state[weight_keys[0]].shape[0]


def _get_in_nc(state, weight_keys) -> int:
def _get_in_nc(state: StateDict, weight_keys: list[str]) -> int:
return state[weight_keys[0]].shape[1]


def _get_scale(pixelshuffle_shape, out_nc) -> int:
def _get_scale(pixelshuffle_shape: int, out_nc: int) -> int:
scale = math.sqrt(pixelshuffle_shape / out_nc)
if scale - int(scale) > 0:
print(
Expand Down
8 changes: 4 additions & 4 deletions src/spandrel/architectures/ESRGAN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .arch.RRDB import RRDBNet


def _new_to_old_arch(state, state_map, num_blocks):
def _new_to_old_arch(state: StateDict, state_map: dict, num_blocks: int):
"""Convert a new-arch model state dictionary to an old-arch dictionary."""
if "params_ema" in state:
state = state["params_ema"]
Expand Down Expand Up @@ -56,7 +56,7 @@ def _new_to_old_arch(state, state_map, num_blocks):
old_state[f"model.{max_upconv + 4}.bias"] = state[key]

# Sort by first numeric value of each layer
def compare(item1, item2):
def compare(item1: str, item2: str):
parts1 = item1.split(".")
parts2 = item2.split(".")
int1 = int(parts1[1])
Expand All @@ -71,7 +71,7 @@ def compare(item1, item2):
return out_dict


def _get_scale(state, min_part: int = 6) -> int:
def _get_scale(state: StateDict, min_part: int = 6) -> int:
n = 0
for part in list(state):
parts = part.split(".")[1:]
Expand All @@ -82,7 +82,7 @@ def _get_scale(state, min_part: int = 6) -> int:
return 2**n


def _get_num_blocks(state, state_map) -> int:
def _get_num_blocks(state: StateDict, state_map: dict) -> int:
nbs = []
state_keys = state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
Expand Down
4 changes: 2 additions & 2 deletions src/spandrel/architectures/SPSR/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .arch.SPSR import SPSRNet as SPSR


def get_scale(state, min_part: int = 4) -> int:
def get_scale(state: StateDict, min_part: int = 4) -> int:
n = 0
for part in list(state):
parts = part.split(".")
Expand All @@ -13,7 +13,7 @@ def get_scale(state, min_part: int = 4) -> int:
return 2**n


def get_num_blocks(state) -> int:
def get_num_blocks(state: StateDict) -> int:
nb = 0
for part in list(state):
parts = part.split(".")
Expand Down

0 comments on commit f5c2546

Please sign in to comment.