
Give your project support for a variety of PyTorch model architectures, including auto-detecting the model architecture from just .pth files. Spandrel gives you arch support.


Spandrel


This package ports chaiNNer's PyTorch architecture support and model loading functionality into its own package, and wraps it into an easy-to-use API.

After seeing many projects copy chaiNNer's model support into their own codebases, I decided it was probably worth the effort of creating a PyPI package that those developers could use instead.

I'm also hoping that by having a central package anyone can use, the community will be encouraged to help add support for more models. This will ultimately benefit everyone.

This package does not yet include easy-to-use inference code (such as converting images to tensors and back), but porting that code is planned as well.

Installation

Spandrel is available through pip and can be installed via a simple pip install command:

pip install spandrel

Usage

This package is still in early stages of development, and is subject to change at any time.

To use this package for automatic architecture loading, simply use the ModelLoader class like so:

from spandrel import ModelLoader
import torch

# Initialize the ModelLoader class with an optional preferred torch.device. Defaults to cpu.
model_loader = ModelLoader(torch.device("cuda:0"))

# Load the model from the given path
loaded_model = model_loader.load_from_file(r"/path/to/your/model.pth")
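The example above hardcodes `cuda:0`, which will not work on machines without a CUDA GPU. A minimal sketch of choosing the device at runtime using only standard PyTorch calls; `pick_device` is a hypothetical helper, not part of Spandrel:

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer the first CUDA GPU, fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


# Usage with the loader from the example above:
# model_loader = ModelLoader(pick_device())
```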

And that's it. The model is loaded into a helper class called a ModelDescriptor, which holds the model itself along with various helpful bits of information about it:

# The model itself (a torch.nn.Module loaded with the weights)
loaded_model.model

# The architecture of the model (e.g. "ESRGAN")
loaded_model.architecture

# A list of tags for the model, usually describing the size (e.g. ["64nf", "large"])
loaded_model.tags

# A boolean indicating whether the model supports half precision (fp16)
loaded_model.supports_half

# A boolean indicating whether the model supports bfloat16 precision
loaded_model.supports_bfloat16

# The scale of the model (e.g. 4)
loaded_model.scale

# The number of input channels of the model (e.g. 3)
loaded_model.input_channels

# The number of output channels of the model (e.g. 3)
loaded_model.output_channels

# A SizeRequirements object describing the image size requirements of the model
# i.e. the minimum size, the value the size must be a multiple of, and whether the model requires a square input
loaded_model.size_requirements
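The size requirements matter because some architectures only accept inputs of certain dimensions. A minimal sketch of computing how much padding an input needs, assuming the three properties described above are exposed as `minimum`, `multiple_of`, and `square`. Those field names, the `SizeReq` stand-in class, and the `required_padding` helper are assumptions for illustration, not confirmed Spandrel API:

```python
from dataclasses import dataclass


@dataclass
class SizeReq:
    """Hypothetical stand-in for a SizeRequirements object."""
    minimum: int = 0       # smallest allowed width/height
    multiple_of: int = 1   # width/height must be a multiple of this
    square: bool = False   # whether width must equal height


def required_padding(width: int, height: int, req: SizeReq) -> tuple[int, int]:
    """Return (pad_width, pad_height) needed for a width x height input to satisfy req."""
    w = max(width, req.minimum)
    h = max(height, req.minimum)
    if req.square:
        w = h = max(w, h)
    # Round both sides up to the next multiple (ceiling division).
    m = req.multiple_of
    w = -(-w // m) * m
    h = -(-h // m) * m
    return w - width, h - height


print(required_padding(50, 70, SizeReq(minimum=64, multiple_of=8)))  # (14, 2)
```

You would then pad the image tensor on the right and bottom (for example with `torch.nn.functional.pad`), run the model, and crop the output back to `width * scale` by `height * scale`.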

ModelDescriptors also support basic inference, with per-descriptor parameters to keep everything simple. For example, an ImageModelDescriptor (used for super-resolution and restoration) takes in a single image tensor and returns a single image tensor, whereas a MaskedModelDescriptor (used for inpainting) takes in an image tensor and a mask tensor and returns a single image tensor.

NOTE: This is not an inference wrapper in the sense that it will convert an image to a tensor for you. It only makes the forward passes of these models more convenient to use, since the actual forward passes are not always as simple as image in/image out.
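Since the descriptor only handles the forward pass, converting images to and from tensors is up to you. A minimal sketch, assuming the image is already an HWC `uint8` tensor (as decoded by your image library) and that the model expects a 1xCxHxW float tensor with values in [0, 1]. That layout is common for super-resolution models but is an assumption here, so check it for your architecture. Both helper names are hypothetical:

```python
import torch


def image_to_tensor(img: torch.Tensor) -> torch.Tensor:
    """HWC uint8 image -> 1xCxHxW float32 tensor in [0, 1] (hypothetical helper)."""
    return img.permute(2, 0, 1).unsqueeze(0).float().div(255.0)


def tensor_to_image(t: torch.Tensor) -> torch.Tensor:
    """1xCxHxW float tensor in [0, 1] -> HWC uint8 image (hypothetical helper)."""
    return (
        t.squeeze(0)
        .clamp(0.0, 1.0)   # model outputs can slightly overshoot [0, 1]
        .mul(255.0)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
    )
```

A typical call would then look like `tensor_to_image(model(image_to_tensor(img).to(device)).cpu())`, wrapped in `torch.no_grad()`.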

ModelDescriptors also have a few convenience methods to make them more similar to regular torch.nn.Modules: .to, .train, and .eval.

Example:

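import torch
from torch import Tensor
from spandrel import ModelLoader

model = ModelLoader().load_from_file(r"/path/to/your/model.pth")
model.to("cuda:0")
model.eval()

def process(tensor: Tensor) -> Tensor:
    # Disable gradient tracking, since we only run inference
    with torch.no_grad():
        return model(tensor)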
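Convenience methods like `.to`, `.train`, and `.eval` can be implemented by forwarding each call to the wrapped `torch.nn.Module`. The sketch below illustrates that general pattern with a hypothetical `DescriptorLike` class; it is not Spandrel's actual implementation, and returning `self` for chaining is a design choice of this sketch:

```python
import torch
from torch import Tensor, nn


class DescriptorLike:
    """Hypothetical wrapper showing how .to/.train/.eval can forward to a module."""

    def __init__(self, model: nn.Module) -> None:
        self.model = model

    def to(self, device) -> "DescriptorLike":
        self.model.to(device)
        return self  # returning self allows chaining, e.g. d.to("cpu").eval()

    def train(self, mode: bool = True) -> "DescriptorLike":
        self.model.train(mode)
        return self

    def eval(self) -> "DescriptorLike":
        self.model.eval()
        return self

    def __call__(self, x: Tensor) -> Tensor:
        # Plain image-in/image-out forward pass
        return self.model(x)
```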

Model Architecture Support

Spandrel currently supports a limited number of neural network architectures, and it can auto-detect them from their model files alone.

NOTE: By its very nature, Spandrel will never be able to support every model architecture. The goal is just to support as many as is realistically possible.

This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work.

PyTorch

Single Image Super Resolution

Face Restoration

Inpainting

Denoising

DeJPEG

File type support

Spandrel mainly supports loading .pth files for all supported architectures. This is what you will typically find from official repos and community trained models. However, Spandrel also supports loading TorchScript traced models (.pt), certain types of .ckpt files, and .safetensors files for any supported architecture saved in one of these formats.
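A loader that accepts several formats has to decide how to read a file before it can detect the architecture. A minimal sketch of extension-based dispatch over the four formats listed above; the `FORMATS` mapping and the `describe_format` helper are hypothetical, and Spandrel's actual detection logic may differ:

```python
from pathlib import Path

# Hypothetical mapping from file extension to how the file would be read.
FORMATS = {
    ".pth": "pickled state dict (restricted unpickler)",
    ".pt": "TorchScript traced model",
    ".ckpt": "checkpoint containing a state dict",
    ".safetensors": "safetensors tensor file (no pickle)",
}


def describe_format(path: str) -> str:
    """Return how a model file would be read, based on its extension."""
    suffix = Path(path).suffix.lower()
    try:
        return FORMATS[suffix]
    except KeyError:
        raise ValueError(f"Unsupported model file type: {suffix!r}") from None
```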

Security

As you may know, loading .pth files usually poses a security risk, because Python's pickle module is unsafe and vulnerable to arbitrary code execution (ACE). Because of this, Spandrel uses a custom unpickler function that only allows loading certain types of data out of a .pth file. This should prevent ACE and make loading untrusted files safer. Note that ACE may still be possible (though we don't expect it to be), so if you are still concerned about security, only load .safetensors models.
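The general technique behind a restricted unpickler is to override `pickle.Unpickler.find_class` so that only an allowlist of globals can be resolved; any other global (such as `os.system`) raises an error instead of being imported. A minimal, stdlib-only sketch of that technique follows. The `ALLOWED` set is a hypothetical example, not Spandrel's actual allowlist, and loading real .pth files would need considerably more (for example, PyTorch's tensor-rebuilding functions):

```python
import io
import pickle

# Hypothetical allowlist of (module, name) pairs the unpickler may resolve.
ALLOWED = {
    ("collections", "OrderedDict"),
}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module: str, name: str):
        # Called whenever the pickle references a global (class or function).
        if (module, name) in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Unpickle data, refusing any global that is not in ALLOWED."""
    return RestrictedUnpickler(io.BytesIO(data)).load()
```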

License Notice

This repo is bound by the GPLv3 license. However, all the architectures used in this repository are bound by their own original licenses, which have been included in their respective places in this repo. The state dict parsing files (load.py) are not bound by these original licenses, as they are new code.

The original code has also been slightly modified and formatted to fit the needs of this repo. If you want to use these architectures in your own codebase (but why would you if you have this package 😉), I recommend grabbing them from their original sources.


Download files

Source Distributions

No source distribution files are available for this release.

Built Distribution

spandrel-0.1.0-py3-none-any.whl (245.3 kB)

Uploaded Python 3

File details

Details for the file spandrel-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: spandrel-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 245.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

Hashes for spandrel-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 0105d42e59ff9c4dcbdd6a855d0d4aa562511726227a166bfff60e4d1ba24342
MD5 a7679ba9007d8c963e86dc418e2137f8
BLAKE2b-256 667a3aa61de453a19845c058dd012235e764d76eae0ca106e8e5f44e3d39b987
