Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve initialization for extra arches #281

Merged
merged 4 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ Note that `model` is a [`ModelDescriptor`](https://chainner.app/spandrel/#ModelD
If you are working on a non-commercial open-source project or a private project, you should use `spandrel` and `spandrel_extra_arches` to get everything spandrel has to offer. The `spandrel` package only contains architectures with [permissive and public domain licenses](https://en.wikipedia.org/wiki/Permissive_software_license) (MIT, Apache 2.0, public domain), so it is fit for every use case. Architectures with restrictive licenses (e.g. non-commercial) are implemented in the `spandrel_extra_arches` package.

```python
from spandrel import ImageModelDescriptor, MAIN_REGISTRY, ModelLoader
from spandrel_extra_arches import EXTRA_REGISTRY
import torch
import spandrel
import spandrel_extra_arches

# add extra architectures before `ModelLoader` is used
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
spandrel_extra_arches.install()

# load a model from disk
model = ModelLoader().load_from_file(r"path/to/model.pth")
model = spandrel.ModelLoader().load_from_file(r"path/to/model.pth")

... # use model
```
Expand Down
27 changes: 24 additions & 3 deletions libs/spandrel/spandrel/__helpers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ class UnsupportedModelError(Exception):
"""


class DuplicateArchitectureError(ValueError):
"""
An error that will be thrown by `ArchRegistry` if the same architecture is added twice.
"""


@dataclass(frozen=True)
class ArchSupport:
"""
Expand Down Expand Up @@ -105,30 +111,45 @@ def architectures(
else:
raise ValueError(f"Invalid order: {order}")

def add(self, *architectures: ArchSupport):
def add(
self,
*architectures: ArchSupport,
ignore_duplicates: bool = False,
) -> list[ArchSupport]:
"""
Adds the given architectures to the registry.

Throws an error if an architecture with the same ID already exists.
Throws an error if an architecture with the same ID already exists,
unless `ignore_duplicates` is True, in which case the old architecture is retained.

Throws an error if a circular dependency of `before` references is detected.

If an error is thrown, the registry is left unchanged.

Returns a list of architectures that were added.
"""

new_architectures = list(self._architectures)
new_by_id = dict(self._by_id)
added = []
for arch in architectures:
if arch.architecture.id in new_by_id:
raise ValueError(f"Duplicate architecture: {arch.architecture.id}")
if ignore_duplicates:
continue
raise DuplicateArchitectureError(
f"Duplicate architecture: {arch.architecture.id}"
)

new_architectures.append(arch)
new_by_id[arch.architecture.id] = arch
added.append(arch)

new_ordered = ArchRegistry._get_ordered(new_architectures)

self._architectures = new_architectures
self._ordered = new_ordered
self._by_id = new_by_id
return added

@staticmethod
def _get_ordered(architectures: list[ArchSupport]) -> list[ArchSupport]:
Expand Down
8 changes: 7 additions & 1 deletion libs/spandrel/spandrel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@
StateDict,
UnsupportedDtypeError,
)
from .__helpers.registry import ArchRegistry, ArchSupport, UnsupportedModelError
from .__helpers.registry import (
ArchRegistry,
ArchSupport,
DuplicateArchitectureError,
UnsupportedModelError,
)

__all__ = [
"ArchId",
"Architecture",
"ArchRegistry",
"ArchSupport",
"canonicalize_state_dict",
"DuplicateArchitectureError",
"ImageModelDescriptor",
"MAIN_REGISTRY",
"MaskedImageModelDescriptor",
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel_extra_arches/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ pip install spandrel spandrel_extra_arches
## Basic usage

```python
from spandrel import MAIN_REGISTRY, ModelLoader
from spandrel_extra_arches import EXTRA_REGISTRY
import spandrel
import spandrel_extra_arches

# add extra architectures before `ModelLoader` is used
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
spandrel_extra_arches.install()

# load a model from disk
model = ModelLoader().load_from_file(r"path/to/model.pth")
model = spandrel.ModelLoader().load_from_file(r"path/to/model.pth")

... # use model
```
Expand Down
25 changes: 24 additions & 1 deletion libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from spandrel import ArchRegistry, ArchSupport
from __future__ import annotations

from spandrel import (
MAIN_REGISTRY,
ArchRegistry,
ArchSupport,
)

from .architectures import (
MAT,
Expand Down Expand Up @@ -27,3 +33,20 @@
ArchSupport.from_architecture(MPRNet.MPRNetArch()),
ArchSupport.from_architecture(MIRNet2.MIRNet2Arch()),
)


def install(*, ignore_duplicates: bool = False) -> list[ArchSupport]:
"""
Try to install the extra architectures into the main registry.

If `ignore_duplicates` is True, the function will not raise an error
if the installation fails due to any of the architectures having already
been installed (but they won't be replaced by ones from this package).
"""
return MAIN_REGISTRY.add(*EXTRA_REGISTRY, ignore_duplicates=ignore_duplicates)


__all__ = [
"EXTRA_REGISTRY",
"install",
]
7 changes: 5 additions & 2 deletions libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from .__helper import EXTRA_REGISTRY
from .__helper import EXTRA_REGISTRY, install

__version__ = "0.1.1"

__all__ = ["EXTRA_REGISTRY"]
__all__ = [
"EXTRA_REGISTRY",
"install",
]
6 changes: 3 additions & 3 deletions scripts/dump_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@
sys.path.insert(0, __file__ + "/../../libs/spandrel")
sys.path.insert(0, __file__ + "/../../libs/spandrel_extra_arches")

from spandrel import MAIN_REGISTRY, ModelLoader # noqa: E402
from spandrel_extra_arches import EXTRA_REGISTRY # noqa: E402
import spandrel_extra_arches # noqa: E402
from spandrel import ModelLoader # noqa: E402
except ImportError:
print("Unable to import spandrel.")
print("Follow the contributing guide to set up editable installs.")
raise

MAIN_REGISTRY.add(*EXTRA_REGISTRY)
spandrel_extra_arches.install()

State = Dict[str, object]

Expand Down
4 changes: 2 additions & 2 deletions tests/__snapshots__/test_registry.ambr
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# serializer version: 1
# name: test_registry_add_invalid
ValueError('Duplicate architecture: b')
DuplicateArchitectureError('Duplicate architecture: b')
# ---
# name: test_registry_add_invalid.1
ValueError('Duplicate architecture: test')
DuplicateArchitectureError('Duplicate architecture: test')
# ---
# name: test_registry_add_invalid.2
ValueError('Circular dependency in architecture detection: 1 -> 2 -> 1')
Expand Down
10 changes: 7 additions & 3 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from bs4 import BeautifulSoup, Tag
from syrupy.filters import props

import spandrel_extra_arches
from spandrel import (
MAIN_REGISTRY,
Architecture,
ImageModelDescriptor,
ModelDescriptor,
Expand All @@ -40,9 +40,13 @@
)
from spandrel.__helpers.model_descriptor import StateDict
from spandrel.util import KeyCondition
from spandrel_extra_arches import EXTRA_REGISTRY

MAIN_REGISTRY.add(*EXTRA_REGISTRY)
# The asserts check that the install function first does return
# the newly installed architectures and then does not return them
# when requested to ignore duplicates.
_installed_extras = spandrel_extra_arches.install()
assert len(_installed_extras) > 0
assert len(spandrel_extra_arches.install(ignore_duplicates=True)) == 0

TEST_DIR = Path("./tests/").resolve()
MODEL_DIR = TEST_DIR / "models"
Expand Down
Loading