
Give your project support for a variety of PyTorch model architectures, including auto-detecting model architecture from just .pth files. spandrel gives you arch support.


Spandrel



This package ports chaiNNer's PyTorch architecture support and model loading functionality into its own package, and wraps it into an easy-to-use API.

After seeing many projects extract chaiNNer's model support into their own codebases, I decided it was probably worth the effort of creating a PyPI package that those developers could use instead.

I'm also hoping that by having a central package anyone can use, the community will be encouraged to help add support for more models. This will ultimately benefit everyone.

This package does not yet have full inference code (such as converting images to and from tensors), but porting that code is planned as well.

Installation

Spandrel is available through pip and can be installed via a simple pip install command:

pip install spandrel

Usage

This package is still in early stages of development, and is subject to change at any time.

To use this package for automatic architecture loading, simply use the ModelLoader class like so:

from spandrel import ModelLoader
import torch

# Initialize the ModelLoader class with an optional preferred torch.device. Defaults to cpu.
model_loader = ModelLoader(torch.device("cuda:0"))

# Load the model from the given path
loaded_model = model_loader.load_from_file(r"/path/to/your/model.pth")

And that's it. The model is loaded into a helper class called a ModelDescriptor, which holds the model itself along with various helpful bits of information about it.

# The model itself (a torch.nn.Module loaded with the weights)
loaded_model.model

# The architecture of the model (e.g. "ESRGAN")
loaded_model.architecture

# A list of tags for the model, usually describing the size (e.g. ["64nf", "large"])
loaded_model.tags

# A boolean indicating whether the model supports half precision (fp16)
loaded_model.supports_half

# A boolean indicating whether the model supports bfloat16 precision
loaded_model.supports_bfloat16

# The scale of the model (e.g. 4)
loaded_model.scale

# The number of input channels of the model (e.g. 3)
loaded_model.input_channels

# The number of output channels of the model (e.g. 3)
loaded_model.output_channels

# A SizeRequirements object describing the image size requirements of the model,
# i.e. the minimum size, the multiple the size must be divisible by, and whether the model requires a square input
loaded_model.size_requirements
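The size requirements above can be used to pad an input before running the model. The sketch below only computes the padded width and height. It assumes the requirements expose fields named `minimum`, `multiple_of`, and `square` (field names are an assumption here), mirrored by a hypothetical stand-in dataclass `SizeReq` so the example runs without spandrel installed. Actually padding the tensor (e.g. with reflection) and cropping the output back is left to the caller.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SizeReq:
    """Hypothetical stand-in for spandrel's SizeRequirements (field names assumed)."""
    minimum: int = 0
    multiple_of: int = 1
    square: bool = False


def round_up(value: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= value."""
    return -(-value // multiple) * multiple


def padded_size(width: int, height: int, req: SizeReq) -> Tuple[int, int]:
    """Smallest (width, height), each >= the input, that satisfies `req`."""
    w = max(width, req.minimum)
    h = max(height, req.minimum)
    if req.square:
        # Make both sides equal before rounding so the result stays square
        w = h = max(w, h)
    return round_up(w, req.multiple_of), round_up(h, req.multiple_of)
```

For example, a 100x50 input with `minimum=64` and `multiple_of=8` becomes 104x64, or 104x104 when `square=True`.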

ModelDescriptors also support basic inference, with per-descriptor parameters to keep everything simple. For example, an ImageModelDescriptor (used for super-resolution and restoration) takes in a single image tensor and returns a single image tensor, whereas a MaskedModelDescriptor (used for inpainting) takes in an image tensor and a mask tensor and returns a single image tensor.

NOTE: This is not an inference wrapper in the sense that it will convert an image to a tensor for you. It only makes the forward passes of these models more convenient to use, since the actual forward passes are not always as simple as image in/image out.
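Because converting images to tensors is left to the caller, here is a minimal NumPy sketch of that step. It assumes the model expects a float batch in NCHW layout with values in [0, 1]; check this for your model, and note that channel order (RGB vs. BGR) is also the caller's responsibility. `image_to_batch` and `batch_to_image` are hypothetical helpers, not part of spandrel.

```python
import numpy as np


def image_to_batch(img: np.ndarray) -> np.ndarray:
    """HxWxC (or HxW) uint8 image -> contiguous 1xCxHxW float32 array in [0, 1]."""
    if img.ndim == 2:
        # Grayscale HxW -> HxWx1 so every image has a channel axis
        img = img[:, :, None]
    chw = img.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Add the batch dimension and make the memory layout contiguous
    return np.ascontiguousarray(chw[None, ...])


def batch_to_image(batch: np.ndarray) -> np.ndarray:
    """1xCxHxW float array in [0, 1] -> HxWxC uint8 image."""
    chw = np.clip(batch[0], 0.0, 1.0)
    return (chw.transpose(1, 2, 0) * 255.0).round().astype(np.uint8)
```

In a PyTorch pipeline, `torch.from_numpy(image_to_batch(img))` would produce the input tensor, and calling `.cpu().numpy()` on the model's output gives an array that `batch_to_image` can turn back into an image (larger by the model's `scale` for super-resolution models).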

ModelDescriptors also have a few convenience methods to make them more similar to regular torch.nn.Modules: .to, .train, and .eval.

Example:

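import torch
from torch import Tensor
from spandrel import ModelLoader

model = ModelLoader().load_from_file(r"/path/to/your/model.pth")
model.to("cuda:0")
model.eval()

def process(tensor: Tensor) -> Tensor:
    # Inference only, so skip gradient tracking
    with torch.no_grad():
        return model(tensor)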
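The example above only moves the model to a device. A caller who also wants reduced precision could consult the descriptor's `supports_half` and `supports_bfloat16` flags first. The hypothetical helper below encodes one possible policy (fp16 on CUDA when supported, then bf16, otherwise fp32) and returns a dtype name. Mapping it to a real dtype, for example with `getattr(torch, name)`, and converting both the model and its input tensors is left to the caller. This is not a policy spandrel itself applies.

```python
def choose_precision(supports_half: bool, supports_bfloat16: bool, on_cuda: bool) -> str:
    """Pick a torch dtype name for inference (hypothetical policy, not spandrel's)."""
    if not on_cuda:
        # Conservative choice: keep full precision when not running on a GPU
        return "float32"
    if supports_half:
        return "float16"
    if supports_bfloat16:
        return "bfloat16"
    return "float32"
```

With a loaded descriptor this might be called as `choose_precision(model.supports_half, model.supports_bfloat16, on_cuda=True)`.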

Model Architecture Support

Spandrel currently supports a limited number of neural network architectures. It can auto-detect these architectures from their model files alone.
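One way to detect an architecture from a file alone is to inspect the state dict: which parameter names exist and what shapes they have. The toy sketch below illustrates that idea. The architecture names `ToyArchA` and `ToyArchB` and their key patterns are invented for illustration. Spandrel's real detection rules are per-architecture and far more specific. The sketch works on a plain mapping from parameter names to shapes, so it runs without PyTorch.

```python
from typing import Mapping, Optional, Sequence, Tuple

Shape = Tuple[int, ...]

# Hypothetical rules: (architecture name, parameter names that must all be present).
# Rules are checked in order, so more specific rules should come first.
RULES: Sequence[Tuple[str, Tuple[str, ...]]] = (
    ("ToyArchA", ("conv_first.weight", "body.0.weight")),
    ("ToyArchB", ("encoder.0.weight",)),
)


def detect_architecture(shapes: Mapping[str, Shape]) -> Optional[str]:
    """Return the first toy architecture whose required keys are all present."""
    for name, required in RULES:
        if all(key in shapes for key in required):
            return name
    return None


def conv_channels(shapes: Mapping[str, Shape], key: str) -> Tuple[int, int]:
    """Read (out_channels, in_channels) from an ungrouped Conv2d weight (out, in, kH, kW)."""
    out_ch, in_ch, _kh, _kw = shapes[key]
    return out_ch, in_ch
```

Hyperparameters can often be read off weight shapes the same way. A first convolution shaped `(64, 3, 3, 3)` gives 64 features (compare a tag like "64nf") and 3 input channels.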

NOTE: By its very nature, Spandrel will never be able to support every model architecture. The goal is just to support as many as is realistically possible.

This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work.

PyTorch

Single Image Super Resolution

Face Restoration

Inpainting

Denoising

DeJPEG

Colorization

File type support

Spandrel mainly supports loading .pth files for all supported architectures. This is what you will typically find in official repos and community-trained models. However, Spandrel also supports loading TorchScript traced models (.pt), certain types of .ckpt files, and .safetensors files for any supported architecture saved in one of these formats.
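Part of what makes .safetensors files safer than pickled .pth files is their simple layout. A file starts with an 8-byte little-endian unsigned integer giving the header size. That header is JSON mapping each tensor name to its `dtype`, `shape`, and `data_offsets`, with an optional `__metadata__` entry, and the raw tensor bytes follow. Reading it never executes code. The standard-library sketch below parses only the header. `read_safetensors_header` is a hypothetical helper, not part of spandrel, and real loading should go through the official `safetensors` package.

```python
import json
import struct


def read_safetensors_header(data: bytes) -> dict:
    """Parse the JSON header of a .safetensors blob without touching tensor data."""
    if len(data) < 8:
        raise ValueError("too short to be a safetensors file")
    # First 8 bytes: header length as an unsigned 64-bit little-endian integer
    (header_len,) = struct.unpack("<Q", data[:8])
    if 8 + header_len > len(data):
        raise ValueError("header length exceeds file size")
    header = json.loads(data[8 : 8 + header_len].decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError("header is not a JSON object")
    # "__metadata__" holds free-form string metadata, not a tensor entry
    header.pop("__metadata__", None)
    return header
```

Because only JSON and raw bytes are involved, a malformed file can at worst fail to parse. It cannot run code the way a malicious pickle can.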

Security

As you may know, loading .pth files usually poses a security risk because Python's pickle module is unsafe and vulnerable to arbitrary code execution (ACE). Because of this, Spandrel uses a custom unpickler function that only allows loading certain types of data out of a .pth file. This is intended to prevent ACE and makes loading untrusted files more secure. Note that ACE may still be possible (though we don't expect it to be), so if you're still concerned about security, only load .safetensors models.
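A common way to build such an unpickler is the "Restricting Globals" pattern from Python's `pickle` documentation: subclass `pickle.Unpickler` and override `find_class` so that only allow-listed globals can be resolved. The sketch below shows that generic technique. It is not spandrel's implementation, and its allow-list is a hypothetical minimum. Loading real .pth files would also require allowing the torch helpers used to rebuild tensors. One option for that is handing a module containing such an unpickler to `torch.load` through its `pickle_module` argument.

```python
import io
import pickle

# Hypothetical allow-list of (module, name) globals the unpickler may resolve.
# Plain dicts, lists, strings and numbers never reach find_class at all.
ALLOWED_GLOBALS = frozenset({
    ("collections", "OrderedDict"),
})


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that refuses any global not in ALLOWED_GLOBALS."""

    def find_class(self, module: str, name: str):
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        # Refusing here means callables like os.system are never looked up
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Drop-in for pickle.loads that only resolves allow-listed globals."""
    return RestrictedUnpickler(io.BytesIO(data)).load()
```

With this in place, an `OrderedDict` of plain values round-trips. A payload whose `__reduce__` points at `os.system` raises `UnpicklingError` before anything is executed.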

License Notice

This repo is licensed under GPLv3. However, all the architectures used in this repository are bound by their own original licenses, which have been included in their respective places in this repo. The state dict parsing files (load.py) are not bound by these original licenses, as they are new code.

The original code has also been slightly modified and formatted to fit the needs of this repo. If you want to use these architectures in your own codebase (but why would you if you have this package 😉), I recommend grabbing them from their original sources.



Download files

Source Distributions: no source distribution files are available for this release.

Built Distribution: spandrel-0.1.4-py3-none-any.whl (271.6 kB), uploaded for Python 3
