From d19b7cdae9b12e580ab03f0bfcbf26abe286b201 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jul 2024 09:22:41 +0300 Subject: [PATCH 1/4] Registry: raise a specific error on duplicate architectures --- libs/spandrel/spandrel/__helpers/registry.py | 10 +++++++++- libs/spandrel/spandrel/__init__.py | 8 +++++++- tests/__snapshots__/test_registry.ambr | 4 ++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/libs/spandrel/spandrel/__helpers/registry.py b/libs/spandrel/spandrel/__helpers/registry.py index 939ee36b..828457b2 100644 --- a/libs/spandrel/spandrel/__helpers/registry.py +++ b/libs/spandrel/spandrel/__helpers/registry.py @@ -15,6 +15,12 @@ class UnsupportedModelError(Exception): """ +class DuplicateArchitectureError(ValueError): + """ + An error that will be thrown by `ArchRegistry` if the same architecture is added twice. + """ + + @dataclass(frozen=True) class ArchSupport: """ @@ -119,7 +125,9 @@ def add(self, *architectures: ArchSupport): new_by_id = dict(self._by_id) for arch in architectures: if arch.architecture.id in new_by_id: - raise ValueError(f"Duplicate architecture: {arch.architecture.id}") + raise DuplicateArchitectureError( + f"Duplicate architecture: {arch.architecture.id}" + ) new_architectures.append(arch) new_by_id[arch.architecture.id] = arch diff --git a/libs/spandrel/spandrel/__init__.py b/libs/spandrel/spandrel/__init__.py index 0a73ed39..6efff150 100644 --- a/libs/spandrel/spandrel/__init__.py +++ b/libs/spandrel/spandrel/__init__.py @@ -20,7 +20,12 @@ StateDict, UnsupportedDtypeError, ) -from .__helpers.registry import ArchRegistry, ArchSupport, UnsupportedModelError +from .__helpers.registry import ( + ArchRegistry, + ArchSupport, + DuplicateArchitectureError, + UnsupportedModelError, +) __all__ = [ "ArchId", @@ -28,6 +33,7 @@ "ArchRegistry", "ArchSupport", "canonicalize_state_dict", + "DuplicateArchitectureError", "ImageModelDescriptor", "MAIN_REGISTRY", "MaskedImageModelDescriptor", diff --git a/tests/__snapshots__/test_registry.ambr b/tests/__snapshots__/test_registry.ambr index 9a63484e..b99afb94 100644 --- a/tests/__snapshots__/test_registry.ambr +++ b/tests/__snapshots__/test_registry.ambr @@ -1,9 +1,9 @@ # serializer version: 1 # name: test_registry_add_invalid - ValueError('Duplicate architecture: b') + DuplicateArchitectureError('Duplicate architecture: b') # --- # name: test_registry_add_invalid.1 - ValueError('Duplicate architecture: test') + DuplicateArchitectureError('Duplicate architecture: test') # --- # name: test_registry_add_invalid.2 ValueError('Circular dependency in architecture detection: 1 -> 2 -> 1') From 400e51708977f93e42b80a7e1fa1ef2e721df495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 Jul 2024 08:44:05 +0300 Subject: [PATCH 2/4] Add `ignore_duplicates` to ArchRegistry --- libs/spandrel/spandrel/__helpers/registry.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/libs/spandrel/spandrel/__helpers/registry.py b/libs/spandrel/spandrel/__helpers/registry.py index 828457b2..dfdc461e 100644 --- a/libs/spandrel/spandrel/__helpers/registry.py +++ b/libs/spandrel/spandrel/__helpers/registry.py @@ -111,32 +111,45 @@ def architectures( else: raise ValueError(f"Invalid order: {order}") - def add(self, *architectures: ArchSupport): + def add( + self, + *architectures: ArchSupport, + ignore_duplicates: bool = False, + ) -> list[ArchSupport]: """ Adds the given architectures to the registry. - Throws an error if an architecture with the same ID already exists. + Throws an error if an architecture with the same ID already exists, + unless `ignore_duplicates` is True, in which case the old architecture is retained. + Throws an error if a circular dependency of `before` references is detected. If an error is thrown, the registry is left unchanged. + + Returns a list of architectures that were added. """ new_architectures = list(self._architectures) new_by_id = dict(self._by_id) + added = [] for arch in architectures: if arch.architecture.id in new_by_id: + if ignore_duplicates: + continue raise DuplicateArchitectureError( f"Duplicate architecture: {arch.architecture.id}" ) new_architectures.append(arch) new_by_id[arch.architecture.id] = arch + added.append(arch) new_ordered = ArchRegistry._get_ordered(new_architectures) self._architectures = new_architectures self._ordered = new_ordered self._by_id = new_by_id + return added @staticmethod def _get_ordered(architectures: list[ArchSupport]) -> list[ArchSupport]: From e7e41269115ec00bff4a38c0a9808f1838069c5c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jul 2024 09:24:06 +0300 Subject: [PATCH 3/4] Add a convenience `install()` function for extra architectures --- README.md | 9 ++++---- libs/spandrel_extra_arches/README.md | 8 +++---- .../spandrel_extra_arches/__helper.py | 23 ++++++++++++++++++- .../spandrel_extra_arches/__init__.py | 7 ++++-- scripts/dump_state_dict.py | 6 ++--- tests/util.py | 9 +++++--- 6 files changed, 44 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index d3c78828..26115c59 100644 --- a/README.md +++ b/README.md @@ -56,15 +56,14 @@ Note that `model` is a [`ModelDescriptor`](https://chainner.app/spandrel/#ModelD If you are working on a non-commercial open-source project or a private project, you should use `spandrel` and `spandrel_extra_arches` to get everything spandrel has to offer. The `spandrel` package only contains architectures with [permissive and public domain licenses](https://en.wikipedia.org/wiki/Permissive_software_license) (MIT, Apache 2.0, public domain), so it is fit for every use case. Architectures with restrictive licenses (e.g. non-commercial) are implemented in the `spandrel_extra_arches` package. ```python -from spandrel import ImageModelDescriptor, MAIN_REGISTRY, ModelLoader -from spandrel_extra_arches import EXTRA_REGISTRY -import torch +import spandrel +import spandrel_extra_arches # add extra architectures before `ModelLoader` is used -MAIN_REGISTRY.add(*EXTRA_REGISTRY) +spandrel_extra_arches.install() # load a model from disk -model = ModelLoader().load_from_file(r"path/to/model.pth") +model = spandrel.ModelLoader().load_from_file(r"path/to/model.pth") ... # use model ``` diff --git a/libs/spandrel_extra_arches/README.md b/libs/spandrel_extra_arches/README.md index 7d763dae..b41943be 100644 --- a/libs/spandrel_extra_arches/README.md +++ b/libs/spandrel_extra_arches/README.md @@ -20,14 +20,14 @@ pip install spandrel spandrel_extra_arches ## Basic usage ```python -from spandrel import MAIN_REGISTRY, ModelLoader -from spandrel_extra_arches import EXTRA_REGISTRY +import spandrel +import spandrel_extra_arches # add extra architectures before `ModelLoader` is used -MAIN_REGISTRY.add(*EXTRA_REGISTRY) +spandrel_extra_arches.install() # load a model from disk -model = ModelLoader().load_from_file(r"path/to/model.pth") +model = spandrel.ModelLoader().load_from_file(r"path/to/model.pth") ... # use model ``` diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py index d759ec6b..42f8b670 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py @@ -1,4 +1,8 @@ -from spandrel import ArchRegistry, ArchSupport +from spandrel import ( + MAIN_REGISTRY, + ArchRegistry, + ArchSupport, +) from .architectures import ( MAT, @@ -27,3 +31,20 @@ ArchSupport.from_architecture(MPRNet.MPRNetArch()), ArchSupport.from_architecture(MIRNet2.MIRNet2Arch()), ) + + +def install(*, ignore_duplicates: bool = False) -> list: + """ + Try to install the extra architectures into the main registry. + + If `ignore_duplicates` is True, the function will not raise an error + if the installation fails due to any of the architectures having already + been installed (but they won't be replaced by ones from this package). + """ + return MAIN_REGISTRY.add(*EXTRA_REGISTRY, ignore_duplicates=ignore_duplicates) + + +__all__ = [ + "EXTRA_REGISTRY", + "install", +] diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py index 2a639290..66764019 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py @@ -1,5 +1,8 @@ -from .__helper import EXTRA_REGISTRY +from .__helper import EXTRA_REGISTRY, install __version__ = "0.1.1" -__all__ = ["EXTRA_REGISTRY"] +__all__ = [ + "EXTRA_REGISTRY", + "install", +] diff --git a/scripts/dump_state_dict.py b/scripts/dump_state_dict.py index 9883cfaa..b0b9982b 100644 --- a/scripts/dump_state_dict.py +++ b/scripts/dump_state_dict.py @@ -54,14 +54,14 @@ sys.path.insert(0, __file__ + "/../../libs/spandrel") sys.path.insert(0, __file__ + "/../../libs/spandrel_extra_arches") - from spandrel import MAIN_REGISTRY, ModelLoader # noqa: E402 - from spandrel_extra_arches import EXTRA_REGISTRY # noqa: E402 + import spandrel_extra_arches # noqa: E402 + from spandrel import ModelLoader # noqa: E402 except ImportError: print("Unable to import spandrel.") print("Follow the contributing guide to set up editable installs.") raise -MAIN_REGISTRY.add(*EXTRA_REGISTRY) +spandrel_extra_arches.install() State = Dict[str, object] diff --git a/tests/util.py b/tests/util.py index 21a3b566..2143b5f7 100644 --- a/tests/util.py +++ b/tests/util.py @@ -29,8 +29,8 @@ from bs4 import BeautifulSoup, Tag from syrupy.filters import props +import spandrel_extra_arches from spandrel import ( - MAIN_REGISTRY, Architecture, ImageModelDescriptor, ModelDescriptor, @@ -40,9 +40,12 @@ ) from spandrel.__helpers.model_descriptor import StateDict from spandrel.util import KeyCondition -from spandrel_extra_arches import EXTRA_REGISTRY -MAIN_REGISTRY.add(*EXTRA_REGISTRY) +# The asserts check that the install function first does return +# the newly installed architectures and then does not return them +# when requested to ignore duplicates. +assert spandrel_extra_arches.install() +assert not spandrel_extra_arches.install(ignore_duplicates=True) TEST_DIR = Path("./tests/").resolve() MODEL_DIR = TEST_DIR / "models" From 11c86dc8234be3dbf87e093649a006cf73a5ae73 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 Jul 2024 13:27:01 +0300 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Michael Schmidt --- libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py | 4 +++- tests/util.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py index 42f8b670..fa825dad 100644 --- a/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py +++ b/libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from spandrel import ( MAIN_REGISTRY, ArchRegistry, @@ -33,7 +35,7 @@ ) -def install(*, ignore_duplicates: bool = False) -> list: +def install(*, ignore_duplicates: bool = False) -> list[ArchSupport]: """ Try to install the extra architectures into the main registry. diff --git a/tests/util.py b/tests/util.py index 2143b5f7..72e40f76 100644 --- a/tests/util.py +++ b/tests/util.py @@ -44,8 +44,9 @@ # The asserts check that the install function first does return # the newly installed architectures and then does not return them # when requested to ignore duplicates. -assert spandrel_extra_arches.install() -assert not spandrel_extra_arches.install(ignore_duplicates=True) +_installed_extras = spandrel_extra_arches.install() +assert len(_installed_extras) > 0 +assert len(spandrel_extra_arches.install(ignore_duplicates=True)) == 0 TEST_DIR = Path("./tests/").resolve() MODEL_DIR = TEST_DIR / "models"