Give your project support for a variety of PyTorch model architectures, including auto-detecting model architecture from just .pth files. spandrel gives you arch support.
Spandrel
This package ports chaiNNer's PyTorch architecture support and model loading functionality into its own package, and wraps it into an easy-to-use API.
After seeing many projects extract chaiNNer's model support into their own codebases, I decided it was worth the effort to create a PyPI package that those developers could use instead.
I'm also hoping that by having a central package anyone can use, the community will be encouraged to help add support for more models. This will ultimately benefit everyone.
This package does not yet have easy inference code, but porting that code is planned as well.
Installation
Spandrel is available through pip and can be installed via a simple pip install command:
pip install spandrel
Usage
This package is still in the early stages of development, and its API is subject to change at any time.
To use this package for automatic architecture loading, simply use the ModelLoader class like so:
from spandrel import ModelLoader
import torch
# Initialize the ModelLoader class with an optional preferred torch.device. Defaults to cpu.
model_loader = ModelLoader(torch.device("cuda:0"))
# Load the model from the given path
loaded_model = model_loader.load_from_file(r"/path/to/your/model.pth")
And that's it. The model gets loaded into a helper class called a ModelDescriptor with various helpful bits of information, as well as the actual model information.
# The model itself (a torch.nn.Module loaded with the weights)
loaded_model.model
# The architecture of the model (e.g. "ESRGAN")
loaded_model.architecture
# A list of tags for the model, usually describing the size (e.g. ["64nf", "large"])
loaded_model.tags
# A boolean indicating whether the model supports half precision (fp16)
loaded_model.supports_half
# A boolean indicating whether the model supports bfloat16 precision
loaded_model.supports_bfloat16
# The scale of the model (e.g. 4)
loaded_model.scale
# The number of input channels of the model (e.g. 3)
loaded_model.input_channels
# The number of output channels of the model (e.g. 3)
loaded_model.output_channels
# A SizeRequirements object describing the image size requirements of the model
# i.e. the minimum size, the multiple of size, and whether the model requires a square input
loaded_model.size_requirements
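The size requirements described above can be turned into concrete padding targets before an image is passed to a model. The following is a minimal sketch, assuming a hypothetical stand-in class `SizeReq` with fields `minimum`, `multiple_of`, and `square`. These names are illustrative and may not match the attributes of Spandrel's actual SizeRequirements object.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SizeReq:
    """Hypothetical stand-in for a size-requirements object (field names are assumptions)."""
    minimum: int = 0       # smallest allowed width/height
    multiple_of: int = 1   # width/height must be a multiple of this
    square: bool = False   # whether width must equal height


def padded_size(width: int, height: int, req: SizeReq) -> tuple[int, int]:
    """Return the smallest (width, height) >= the input that satisfies req."""

    def fit(n: int) -> int:
        n = max(n, req.minimum)
        # Round up to the next multiple of req.multiple_of.
        return -(-n // req.multiple_of) * req.multiple_of

    w, h = fit(width), fit(height)
    if req.square:
        # Both values are already valid multiples, so the larger one fits both.
        w = h = max(w, h)
    return w, h
```

For example, `padded_size(100, 50, SizeReq(minimum=16, multiple_of=8))` gives `(104, 56)`. The image would then be padded to that size before inference and cropped back afterwards.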
ModelDescriptors also support basic inference, with per-descriptor parameters to keep everything simple. For example, an ImageModelDescriptor (used for super-resolution and restoration) takes in a single image tensor and returns a single image tensor, whereas a MaskedModelDescriptor (used for inpainting) takes in an image tensor and a mask tensor and returns a single image tensor.
NOTE: This is not an inference wrapper in the sense that it will convert an image to a tensor for you. It only makes the forward passes of these models more convenient to use, since the actual forward passes are not always as simple as image in/image out.
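Because the descriptors work on tensors rather than image files, the conversion has to happen in your own code. The following is a minimal sketch, assuming the common convention of a batched, channels-first float tensor with values in [0, 1]; check what your particular model expects. The helper names `image_to_tensor` and `tensor_to_image` are hypothetical and are not part of Spandrel.

```python
import torch


def image_to_tensor(image) -> torch.Tensor:
    """Convert an HxWxC uint8 image (array-like) to a 1xCxHxW float tensor in [0, 1]."""
    t = torch.as_tensor(image)
    if t.ndim != 3:
        raise ValueError("expected an image of shape (height, width, channels)")
    # Channels-last -> channels-first, add a batch dimension, scale to [0, 1].
    return t.permute(2, 0, 1).unsqueeze(0).float() / 255.0


def tensor_to_image(tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 1xCxHxW float tensor in [0, 1] back to an HxWxC uint8 tensor."""
    t = tensor.squeeze(0).permute(1, 2, 0)
    return (t.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
```

An image loaded as a NumPy array (for instance with an imaging library of your choice) can be passed directly to `image_to_tensor`, since `torch.as_tensor` accepts array-like inputs.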
ModelDescriptors also have a few convenience methods to make them more similar to regular torch.nn.Modules: .to, .train, and .eval.
Example:
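import torch
from torch import Tensor
from spandrel import ModelLoader

model = ModelLoader().load_from_file(r"/path/to/your/model.pth")
model.to("cuda:0")
model.eval()

def process(tensor: Tensor) -> Tensor:
    # Disable gradient tracking during inference
    with torch.no_grad():
        return model(tensor)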
Model Architecture Support
Spandrel currently supports a limited number of neural network architectures. It can auto-detect these architectures from their model files alone.
NOTE: By its very nature, Spandrel will never be able to support every model architecture. The goal is just to support as many as is realistically possible.
This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work.
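How Spandrel detects architectures internally is not described here. One common approach, shown below purely as an illustration, is to look for characteristic parameter names among the keys of a checkpoint's state dict. The `SIGNATURES` table, the architecture labels, and the key names are all hypothetical and are not taken from Spandrel's code.

```python
from typing import Iterable, Optional

# Hypothetical signatures: architecture label -> keys that must all be present.
# Real detection logic is more involved; these key names are illustrative only.
SIGNATURES = {
    "ArchA": ("conv_first.weight", "body.0.rdb1.conv1.weight"),
    "ArchB": ("layers.0.blocks.0.attn.qkv.weight",),
}


def detect_architecture(state_dict_keys: Iterable[str]) -> Optional[str]:
    """Return the first architecture whose signature keys are all present."""
    keys = set(state_dict_keys)
    for name, required in SIGNATURES.items():
        if all(k in keys for k in required):
            return name
    return None  # unknown or unsupported architecture
```

This also illustrates why unofficial variants may fail to load: if a modified architecture renames or removes the parameters a signature relies on, detection no longer matches.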
PyTorch
Single Image Super Resolution
- ESRGAN (RRDBNet)
- This includes regular ESRGAN, ESRGAN+, "new-arch ESRGAN" (RealSR, BSRGAN), SPSR, and Real-ESRGAN
- Models: Community ESRGAN | ESRGAN+ | BSRGAN | RealSR | Real-ESRGAN
- Real-ESRGAN Compact (SRVGGNet) | Models
- Swift-SRGAN | Models
- SwinIR | Models
- Swin2SR | Models
- HAT | Models
- Omni-SR | Models
- SRFormer | Models
- DAT | Models
- FeMaSR | Models
- GRLIR | Models
- DITN | Models
Face Restoration
- GFPGAN | 1.2, 1.3, 1.4
- RestoreFormer | Model
- CodeFormer | Model
Inpainting
Denoising
DeJPEG
File type support
Spandrel mainly supports loading .pth files for all supported architectures. This is what you will typically find from official repos and community trained models. However, Spandrel also supports loading TorchScript traced models (.pt), certain types of .ckpt files, and .safetensors files for any supported architecture saved in one of these formats.
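When scanning a folder of downloaded models, it can help to filter by extension before attempting to load anything. A minimal sketch follows; the helper name `is_supported_model_file` is hypothetical and not part of Spandrel, and passing this check does not guarantee a successful load, since only certain types of .ckpt files and only supported architectures can be loaded.

```python
from pathlib import Path

# Extensions named in the text above.
SUPPORTED_EXTENSIONS = {".pth", ".pt", ".ckpt", ".safetensors"}


def is_supported_model_file(path: str) -> bool:
    """Return True if the file extension is one of the formats listed above."""
    return Path(path).suffix.lower() in SUPPORTED_EXTENSIONS
```

Files that pass the check can then be handed to `ModelLoader().load_from_file(...)` as shown in the Usage section.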
Security
As you may know, loading .pth files usually poses a security risk because Python's pickle module is unsafe and vulnerable to arbitrary code execution (ACE). For this reason, Spandrel uses a custom unpickler function that only allows loading certain types of data out of a .pth file. This ideally prevents ACE and makes loading untrusted files more secure. Note that ACE may still be possible (though we don't expect this to be the case), so if you're still concerned about security, only load .safetensors models.
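The general technique behind a restricted unpickler can be sketched with the standard library alone, by overriding `pickle.Unpickler.find_class` so that only allow-listed globals can be resolved. This is not Spandrel's actual implementation; the names `ALLOWED_GLOBALS`, `RestrictedUnpickler`, and `restricted_loads` are hypothetical, and the allow-list here is deliberately tiny.

```python
import io
import pickle

# Illustrative allow-list. A real loader for .pth files would also need to
# allow the specific types used to rebuild tensors.
ALLOWED_GLOBALS = {
    ("collections", "OrderedDict"),
}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that refuses any global not on the allow-list."""

    def find_class(self, module: str, name: str):
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Deserialize pickle bytes, resolving only allow-listed globals."""
    return RestrictedUnpickler(io.BytesIO(data)).load()
```

With this in place, a pickle that references an arbitrary callable (for example `builtins.print` or `os.system`) raises `pickle.UnpicklingError` instead of being resolved, while plain containers such as an `OrderedDict` of primitive values load normally.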
License Notice
This repo is bound by the GPLv3 license. However, all the architectures used in this repository are bound by their own original licenses, which have been included in their respective places in this repo. The state dict parsing (load.py) files are not bound by these original licenses, as they are new code.
The original code has also been slightly modified and formatted to fit the needs of this repo. If you want to use these architectures in your own codebase (but why would you if you have this package 😉), I recommend grabbing them from their original sources.