Skip to content

Commit

Permalink
Add a convenience install() function for extra architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Jul 11, 2024
1 parent 400e517 commit e7e4126
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 18 deletions.
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ Note that `model` is a [`ModelDescriptor`](https://chainner.app/spandrel/#ModelD
If you are working on a non-commercial open-source project or a private project, you should use `spandrel` and `spandrel_extra_arches` to get everything spandrel has to offer. The `spandrel` package only contains architectures with [permissive and public domain licenses](https://en.wikipedia.org/wiki/Permissive_software_license) (MIT, Apache 2.0, public domain), so it is fit for every use case. Architectures with restrictive licenses (e.g. non-commercial) are implemented in the `spandrel_extra_arches` package.

```python
from spandrel import ImageModelDescriptor, MAIN_REGISTRY, ModelLoader
from spandrel_extra_arches import EXTRA_REGISTRY
import torch
import spandrel
import spandrel_extra_arches

# add extra architectures before `ModelLoader` is used
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
spandrel_extra_arches.install()

# load a model from disk
model = ModelLoader().load_from_file(r"path/to/model.pth")
model = spandrel.ModelLoader().load_from_file(r"path/to/model.pth")

... # use model
```
Expand Down
8 changes: 4 additions & 4 deletions libs/spandrel_extra_arches/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ pip install spandrel spandrel_extra_arches
## Basic usage

```python
from spandrel import MAIN_REGISTRY, ModelLoader
from spandrel_extra_arches import EXTRA_REGISTRY
import spandrel
import spandrel_extra_arches

# add extra architectures before `ModelLoader` is used
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
spandrel_extra_arches.install()

# load a model from disk
model = ModelLoader().load_from_file(r"path/to/model.pth")
model = spandrel.ModelLoader().load_from_file(r"path/to/model.pth")

... # use model
```
Expand Down
23 changes: 22 additions & 1 deletion libs/spandrel_extra_arches/spandrel_extra_arches/__helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from spandrel import ArchRegistry, ArchSupport
from spandrel import (
MAIN_REGISTRY,
ArchRegistry,
ArchSupport,
)

from .architectures import (
MAT,
Expand Down Expand Up @@ -27,3 +31,20 @@
ArchSupport.from_architecture(MPRNet.MPRNetArch()),
ArchSupport.from_architecture(MIRNet2.MIRNet2Arch()),
)


def install(*, ignore_duplicates: bool = False) -> list:
"""
Try to install the extra architectures into the main registry.
If `ignore_duplicates` is True, the function will not raise an error
if the installation fails due to any of the architectures having already
been installed (but they won't be replaced by ones from this package).
"""
return MAIN_REGISTRY.add(*EXTRA_REGISTRY, ignore_duplicates=ignore_duplicates)


__all__ = [
"EXTRA_REGISTRY",
"install",
]
7 changes: 5 additions & 2 deletions libs/spandrel_extra_arches/spandrel_extra_arches/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from .__helper import EXTRA_REGISTRY
from .__helper import EXTRA_REGISTRY, install

__version__ = "0.1.1"

__all__ = ["EXTRA_REGISTRY"]
__all__ = [
"EXTRA_REGISTRY",
"install",
]
6 changes: 3 additions & 3 deletions scripts/dump_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@
sys.path.insert(0, __file__ + "/../../libs/spandrel")
sys.path.insert(0, __file__ + "/../../libs/spandrel_extra_arches")

from spandrel import MAIN_REGISTRY, ModelLoader # noqa: E402
from spandrel_extra_arches import EXTRA_REGISTRY # noqa: E402
import spandrel_extra_arches # noqa: E402
from spandrel import ModelLoader # noqa: E402
except ImportError:
print("Unable to import spandrel.")
print("Follow the contributing guide to set up editable installs.")
raise

MAIN_REGISTRY.add(*EXTRA_REGISTRY)
spandrel_extra_arches.install()

State = Dict[str, object]

Expand Down
9 changes: 6 additions & 3 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from bs4 import BeautifulSoup, Tag
from syrupy.filters import props

import spandrel_extra_arches
from spandrel import (
MAIN_REGISTRY,
Architecture,
ImageModelDescriptor,
ModelDescriptor,
Expand All @@ -40,9 +40,12 @@
)
from spandrel.__helpers.model_descriptor import StateDict
from spandrel.util import KeyCondition
from spandrel_extra_arches import EXTRA_REGISTRY

MAIN_REGISTRY.add(*EXTRA_REGISTRY)
# The asserts check that the install function first does return
# the newly installed architectures and then does not return them
# when requested to ignore duplicates.
assert spandrel_extra_arches.install()
assert not spandrel_extra_arches.install(ignore_duplicates=True)

TEST_DIR = Path("./tests/").resolve()
MODEL_DIR = TEST_DIR / "models"
Expand Down

0 comments on commit e7e4126

Please sign in to comment.