README improvements (#2)
joeyballentine authored Nov 17, 2023
1 parent 88d8f5d commit cdb0790
Showing 1 changed file with 55 additions and 13 deletions.

This package does not yet have easy inference code for these model types, but porting that code is planned as well.

## Installation

Spandrel is available on PyPI and can be installed with a simple pip command:

```shell
pip install spandrel
```

## Usage

**This package is still in early stages of development, and is subject to change at any time.**

To use this package for automatic architecture loading, simply use the ModelLoader class like so:

```python
from spandrel import ModelLoader
import torch

# Initialize the ModelLoader class with an optional preferred torch.device. Defaults to cpu.
model_loader = ModelLoader(torch.device("cuda:0"))

# Load the model from the given path
loaded_model = model_loader.load_from_file(r"/path/to/your/model.pth")
```

And that's it. The model gets loaded into a helper class that gives you access to the actual model as well as various useful pieces of information about it.

```py
# The model itself (a torch.nn.Module loaded with the weights)
loaded_model.model

# The state dict of the model (the weights)
loaded_model.state_dict

# The architecture of the model (e.g. "ESRGAN")
loaded_model.architecture

# A list of tags for the model, usually describing the size (e.g. ["64nf", "large"])
loaded_model.tags

# A boolean indicating whether the model supports half precision (fp16)
loaded_model.supports_half

# A boolean indicating whether the model supports bfloat16 precision
loaded_model.supports_bfloat16

# The scale of the model (e.g. 4)
loaded_model.scale

# The number of input channels of the model (e.g. 3)
loaded_model.input_channels

# The number of output channels of the model (e.g. 3)
loaded_model.output_channels

# A SizeRequirements object describing the image size requirements of the model
# i.e. the minimum size, the multiple the dimensions must be divisible by, and whether the model requires a square input
loaded_model.size
```

You can also use this helper class for inference the same way you would use the `model` directly. For example, `result = loaded_model(img)` automatically calls the model's forward method. The helper also supports moving to other devices, so you can call `.to` on it just like you would on the underlying model.
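As a minimal sketch of what that looks like in practice (here `img` is assumed to already be a normalized `(1, C, H, W)` float tensor; decoding an image file into that shape is left out):

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The helper class can be moved between devices just like a plain torch module
loaded_model.to(device)
loaded_model.model.eval()

# img is assumed to be a float tensor of shape (1, C, H, W) with values in [0, 1],
# where C matches loaded_model.input_channels
img = img.to(device)

with torch.no_grad():
    output = loaded_model(img)

print(output.shape)  # spatial dimensions are the input's multiplied by loaded_model.scale
```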

## Model Architecture Support

Spandrel currently supports a limited number of neural network architectures. It can auto-detect these architectures from their files alone.

This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work.
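In practice, auto-detection means you never have to name the architecture yourself. A quick sketch (the file names below are hypothetical; any supported model file works the same way):

```python
from spandrel import ModelLoader

loader = ModelLoader()  # defaults to CPU when no device is given

# Hypothetical file names, purely for illustration
for path in ["4x_ESRGAN.pth", "2x_SwinIR.pth"]:
    model = loader.load_from_file(path)
    print(path, "->", model.architecture, model.tags, f"{model.scale}x")
```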

### PyTorch


## Contributing

Feel free to contribute more model architecture support. When I add model support, I usually dig through the .pth file (state dict) keys and weights to find a way to get all the parameters of a model. At some point, I will document that entire process here. For now, there are plenty of examples to reference.

If the model arch you're adding does not have any parameter variants (for example, different scales or layer counts), it should be fine to add it without any of the param detection. At the very least, you will need to find something uniquely identifiable in your model (usually a unique, really long key) that you can then add to `/spandrel/__helpers/main_registry.py` in order to load your model (preferably at the bottom). You will also need to set up the `__init__.py` file for your arch to include a `load` method, returning a ModelDescriptor with the model and some metadata about the model and its parameters.

Like with the parameter detection, there are plenty of examples there. This might seem like a lot of hardcoding (and it very well is), but it's the only way to identify models based on just the .pth file (or any other weight storage format), since these files are just the weights of a model. If anybody can figure out a better way to do this, be my guest, but for now this is the best way and it works well.
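As a rough, hypothetical illustration of the idea (this is not the actual registry code, and the key names below are invented for the example), detection boils down to looking for a state-dict key that only your architecture produces, then reading parameters off the weight shapes:

```python
# Hypothetical sketch of key-based detection and parameter inference.
# The key names are invented; real architectures have their own unique keys.
import torch

state_dict = torch.load("/path/to/your/model.pth", map_location="cpu")

# 1. Detect the architecture via a key that only it would ever contain
is_my_arch = "body.0.rdb1.conv1.weight" in state_dict

# 2. Infer parameters from weight shapes instead of hardcoding them
if is_my_arch:
    num_features = state_dict["conv_first.weight"].shape[0]
    in_channels = state_dict["conv_first.weight"].shape[1]
    print(f"MyArch detected: num_features={num_features}, in_channels={in_channels}")
```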

## License Notice

