Give your project support for a variety of PyTorch model architectures, including auto-detecting model architecture from just .pth files. spandrel gives you arch support.
Spandrel
This package ports chaiNNer's PyTorch architecture support and model loading functionality into its own package, and wraps it into an easy-to-use API.
After seeing many projects extract chaiNNer's model support into their own codebases, I decided it was probably worth the effort of creating a PyPI package that those developers could use instead.
Slightly selfishly, I'm also hoping this will encourage the community to help add support for more models, so I don't have to do it myself. This will ultimately benefit everyone.
This package does not yet have easy inference code for these model types, but porting that code is planned as well.
Installation
Spandrel is available through pip and can be installed via a simple pip install command:
pip install spandrel
Usage
This package is still in early stages of development, and is subject to change at any time.
To use this package for automatic architecture loading, simply use the ModelLoader class like so:
from spandrel import ModelLoader
import torch
# Initialize the ModelLoader class with an optional preferred torch.device. Defaults to cpu.
model_loader = ModelLoader(torch.device("cuda:0"))
# Load the model from the given path
loaded_model = model_loader.load_from_file(r"/path/to/your/model.pth")
And that's it. The model gets loaded into a helper class (a ModelDescriptor) that holds the model itself along with various helpful bits of information about it.
# The model itself (a torch.nn.Module loaded with the weights)
loaded_model.model
# The state dict of the model (the weights)
loaded_model.state_dict
# The architecture of the model (e.g. "ESRGAN")
loaded_model.architecture
# A list of tags for the model, usually describing the size (e.g. ["64nf", "large"])
loaded_model.tags
# A boolean indicating whether the model supports half precision (fp16)
loaded_model.supports_half
# A boolean indicating whether the model supports bfloat16 precision
loaded_model.supports_bfloat16
# The scale of the model (e.g. 4)
loaded_model.scale
# The number of input channels of the model (e.g. 3)
loaded_model.input_channels
# The number of output channels of the model (e.g. 3)
loaded_model.output_channels
# A SizeRequirements object describing the image size requirements of the model
# i.e. the minimum size, the multiple the size must be divisible by, and whether the model requires a square input
loaded_model.size_requirements
ModelDescriptors also support moving the model to other devices directly, so for convenience you can call .to on the descriptor just as you would on the model itself.
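The size requirements mentioned above (a minimum size, a size multiple, and an optional square constraint) usually translate into padding the input before inference. The following is only a minimal sketch of that arithmetic: `SizeRequirementsSketch`, its field names, and `padded_size` are hypothetical stand-ins written for illustration, not Spandrel's actual classes or API.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SizeRequirementsSketch:
    """Hypothetical stand-in; the field names are assumptions, not Spandrel's API."""

    minimum: int = 0       # smallest allowed width/height
    multiple_of: int = 1   # width/height must be divisible by this
    square: bool = False   # whether width must equal height


def padded_size(width: int, height: int, req: SizeRequirementsSketch) -> Tuple[int, int]:
    """Return the smallest (width, height) >= the input that meets the requirements."""

    def fit(value: int) -> int:
        value = max(value, req.minimum)
        # Round up to the next multiple using integer arithmetic.
        return ((value + req.multiple_of - 1) // req.multiple_of) * req.multiple_of

    new_w, new_h = fit(width), fit(height)
    if req.square:
        # Both sides are already valid multiples, so the larger one is valid for both.
        new_w = new_h = max(new_w, new_h)
    return new_w, new_h
```

For example, under these assumptions a 100x60 image with `minimum=16` and `multiple_of=8` would be padded to 104x64 before being passed to the model, and the output cropped back afterwards.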
Model Architecture Support
Spandrel currently supports a limited number of neural network architectures. It can auto-detect these architectures from their model files alone.
This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work.
PyTorch
Single Image Super Resolution
- ESRGAN (RRDBNet)
- This includes regular ESRGAN, ESRGAN+, "new-arch ESRGAN" (RealSR, BSRGAN), SPSR, and Real-ESRGAN
- Models: Community ESRGAN | ESRGAN+ | BSRGAN | RealSR | Real-ESRGAN
- Real-ESRGAN Compact (SRVGGNet) | Models
- Swift-SRGAN | Models
- SwinIR | Models
- Swin2SR | Models
- HAT | Models
- Omni-SR | Models
- SRFormer | Models
- DAT | Models
Face Restoration
- GFPGAN | 1.2, 1.3, 1.4
- RestoreFormer | Model
- CodeFormer | Model
Inpainting
Denoising
DeJPEG
File type support
Spandrel mainly supports loading .pth files for all supported architectures. This is what you will typically find from official repos and community trained models. However, Spandrel also supports loading TorchScript traced models (.pt), certain types of .ckpt files, as well as any supported model that has been saved as or converted to a .safetensors file.
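As a rough illustration of the file types listed above, a caller might want to check a file's extension before handing it to a loader. This is only a sketch: `FORMAT_BY_SUFFIX` and `guess_format` are hypothetical names, and this is not how Spandrel itself decides how to read a file.

```python
from pathlib import Path

# Hypothetical mapping from file extension to the kind of file described above.
FORMAT_BY_SUFFIX = {
    ".pth": "pickled state dict",
    ".pt": "TorchScript traced model",
    ".ckpt": "checkpoint",
    ".safetensors": "safetensors",
}


def guess_format(path: str) -> str:
    """Return a short description of the file type, based only on the extension."""
    suffix = Path(path).suffix.lower()
    if suffix not in FORMAT_BY_SUFFIX:
        raise ValueError(f"Unsupported model file extension: {suffix!r}")
    return FORMAT_BY_SUFFIX[suffix]
```

A check like this only looks at the name of the file, so it says nothing about whether the contents are a supported architecture.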
Security
As you may know, loading .pth files usually poses a security risk due to Python's pickle module being unsafe and vulnerable to arbitrary code execution (ACE). Because of this, Spandrel uses a custom unpickler function that only allows loading certain types of data out of a .pth file. This ideally prevents ACE and makes loading untrusted files more secure. Note that there still could be the possibility of ACE (though we don't expect this to be the case), so if you're still concerned about security, only load .safetensors models.
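To make the idea of a restricted unpickler concrete, here is a minimal sketch using the pattern from the Python standard library documentation: subclass `pickle.Unpickler` and override `find_class` so only allow-listed globals can be resolved. This is not Spandrel's actual unpickler. The allow-list and the names `ALLOWED_GLOBALS`, `RestrictedUnpickler`, and `restricted_loads` are assumptions for illustration, and a real .pth loader would also have to allow the PyTorch-specific types used to rebuild tensors.

```python
import io
import pickle
from collections import OrderedDict

# Hypothetical allow-list: only these (module, name) globals may be resolved.
ALLOWED_GLOBALS = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called whenever the pickle stream asks to import a global.
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Like pickle.loads, but refuses any global outside the allow-list."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


# A plain container of basic values loads normally...
weights = restricted_loads(pickle.dumps(OrderedDict(weight=[1.0, 2.0])))

# ...while a payload that references any other global is rejected.
try:
    restricted_loads(pickle.dumps(print))
except pickle.UnpicklingError as err:
    rejected_reason = str(err)
```

The design choice here is to deny by default: anything not explicitly listed fails to load, which is why an allow-list is generally preferred over trying to block known-dangerous names.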
Contributing
Feel free to contribute more model architecture support. When I add model support, I usually dig through the .pth file (state dict) keys and weights to find a way to get all the parameters of a model. At some point, I will document that entire process here. For now, there are plenty of examples to reference.
If the model arch you're adding does not have any parameter variants (for example, different scales or layer counts), then it should be fine to add it without any of the param detection. At the very least, you will need to find something uniquely identifiable in your model (usually a unique, really long key) that you can then add to /spandrel/__helpers/main_registry.py (preferably at the bottom) in order to load your model. You will also need to set up the __init__.py file for your arch to include a load method that returns a ModelDescriptor with the model and some metadata about the model and its parameters.
As with the parameter detection, there are plenty of examples there. This might seem like a lot of hardcoding (and it very well is), but it's the only way to identify models based on just the .pth file (or any other weight storage format), since these files contain only the weights of a model. If anybody can figure out a better way to do this, be my guest, but for now this is the best way and it works well.
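The key-based detection described above can be sketched in plain Python. Everything in this block is hypothetical: the registry layout, the architecture names, and the state dict keys are made up for illustration and do not mirror the contents of main_registry.py. It only shows the general idea of matching on unique keys and deriving a parameter (here, a block count) from the key names.

```python
from typing import Callable, Dict, List, Optional, Tuple

StateDict = Dict[str, object]

# Hypothetical registry of (architecture name, detection predicate), checked in order.
REGISTRY: List[Tuple[str, Callable[[StateDict], bool]]] = []


def has_keys(*keys: str) -> Callable[[StateDict], bool]:
    """Build a predicate that matches when every given key is present."""
    return lambda state_dict: all(key in state_dict for key in keys)


# Made-up architectures identified by made-up "unique, really long" keys.
REGISTRY.append(("ExampleArchA", has_keys("example_a.body.0.unique_block.conv.weight")))
REGISTRY.append(("ExampleArchB", has_keys("example_b.head.weight", "example_b.tail.weight")))


def detect_architecture(state_dict: StateDict) -> Optional[str]:
    """Return the first registered architecture whose predicate matches, else None."""
    for name, detect in REGISTRY:
        if detect(state_dict):
            return name
    return None


def count_blocks(state_dict: StateDict, prefix: str) -> int:
    """Count distinct numeric indices directly after a prefix, e.g. 'body.0.', 'body.1.'."""
    indices = set()
    for key in state_dict:
        if key.startswith(prefix):
            head = key[len(prefix):].split(".", 1)[0]
            if head.isdigit():
                indices.add(int(head))
    return len(indices)
```

Because the registry is checked in order, a more specific predicate has to come before a more general one that would also match the same state dict.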
License Notice
This repo is bound by the GPLv3 license. However, all the architectures used in this repository are bound by their own original licenses, which have been included in their respective places in this repo. The state dict parsing (load.py) files are not bound by these original licenses as they are new code.
The original code has also been slightly modified and formatted to fit the needs of this repo. If you want to use these architectures in your own codebase (but why would you if you have this package 😉), I recommend grabbing them from their original sources.