
Give your project support for a variety of PyTorch model architectures, including auto-detecting the model architecture from just a .pth file. Spandrel gives you arch support.


Spandrel


This package ports chaiNNer's PyTorch architecture support and model loading functionality into its own package and wraps it in an easy-to-use API.

After seeing many projects extract chaiNNer's model support into their own codebases, I decided it was probably worth the effort of creating a PyPI package that those developers could use instead.

Slightly selfishly, I'm also hoping this will encourage the community to help add support for more models, so I don't have to do it myself. This will ultimately benefit everyone.

This package does not yet have easy inference code for these model types, but porting that code is planned as well.

Installation

Spandrel is available through pip and can be installed via a simple pip install command:

pip install spandrel

Usage

This package is still in early stages of development, and is subject to change at any time.

To use this package for automatic architecture loading, simply use the ModelLoader class like so:

from spandrel import ModelLoader
import torch

# Initialize the ModelLoader class with an optional preferred torch.device. Defaults to cpu.
model_loader = ModelLoader(torch.device("cuda:0"))

# Load the model from the given path
loaded_model = model_loader.load_from_file(r"/path/to/your/model.pth")

And that's it. The model is loaded into a helper class (a ModelDescriptor) that holds the model itself along with various useful bits of metadata about it:

# The model itself (a torch.nn.Module loaded with the weights)
loaded_model.model

# The state dict of the model (the weights)
loaded_model.state_dict

# The architecture of the model (e.g. "ESRGAN")
loaded_model.architecture

# A list of tags for the model, usually describing the size (e.g. ["64nf", "large"])
loaded_model.tags

# A boolean indicating whether the model supports half precision (fp16)
loaded_model.supports_half

# A boolean indicating whether the model supports bfloat16 precision
loaded_model.supports_bfloat16

# The scale of the model (e.g. 4)
loaded_model.scale

# The number of input channels of the model (e.g. 3)
loaded_model.input_channels

# The number of output channels of the model (e.g. 3)
loaded_model.output_channels

# A SizeRequirements object describing the image size requirements of the model
# i.e. the minimum size, the multiple the size must be divisible by, and whether the model requires a square input
loaded_model.size_requirements
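Models with size requirements generally need their input padded before inference (with the output cropped back afterwards). The sketch below shows one way to compute the padded size from the three constraints described above. The SizeReq class and its field names (minimum, multiple_of, square) are hypothetical stand-ins, not Spandrel's actual SizeRequirements API; check the real attribute names before adapting it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SizeReq:
    """Hypothetical stand-in for a model's size requirements."""

    minimum: int = 0  # smallest allowed width/height
    multiple_of: int = 1  # width/height must be divisible by this
    square: bool = False  # whether width must equal height


def padded_size(width: int, height: int, req: SizeReq) -> tuple[int, int]:
    """Return the smallest (width, height) >= the input that satisfies req."""

    def fix(n: int) -> int:
        n = max(n, req.minimum)
        # Round up to the next multiple using ceiling division.
        return -(-n // req.multiple_of) * req.multiple_of

    w, h = fix(width), fix(height)
    if req.square:
        # Both sides are already valid, so the larger one works for both.
        w = h = max(w, h)
    return w, h


print(padded_size(100, 50, SizeReq(minimum=16, multiple_of=8, square=True)))
```

After running the model on the padded image, the output would be cropped back to the original width and height multiplied by the model's scale.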

For convenience, ModelDescriptors also support moving the model to other devices directly: you can call .to() on the descriptor just as you would on the underlying model.
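The forwarding pattern behind this convenience can be sketched as follows. This is an illustration under assumptions, not Spandrel's code: FakeModule is a hypothetical stand-in for a torch.nn.Module that only records a device name, and Descriptor is a hypothetical wrapper. Returning the wrapper from .to() lets calls be chained, mirroring how nn.Module.to() returns the module.

```python
class FakeModule:
    """Hypothetical stand-in for a torch.nn.Module; it only tracks a device name."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "FakeModule":
        self.device = device
        return self


class Descriptor:
    """Hypothetical wrapper showing how .to() can be forwarded to the model."""

    def __init__(self, model: FakeModule) -> None:
        self.model = model

    def to(self, device: str) -> "Descriptor":
        # Move the wrapped model, then return the wrapper so calls can be chained.
        self.model = self.model.to(device)
        return self


d = Descriptor(FakeModule()).to("cuda:0")
print(d.model.device)
```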

Model Architecture Support

Spandrel currently supports a limited number of neural network architectures, and it can auto-detect each of them from the model file alone.

This has only been tested with the models that are linked here, and any unofficial variants (especially if changes are made to their architectures) are not guaranteed to work.

PyTorch

Single Image Super Resolution

Face Restoration

Inpainting

Denoising

DeJPEG

File type support

Spandrel mainly supports loading .pth files for all supported architectures; this is the format you will typically find in official repos and community-trained models. However, Spandrel also supports loading TorchScript traced models (.pt), certain types of .ckpt files, and any supported model that has been saved as or converted to a .safetensors file.
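Supporting several file types usually comes down to dispatching on the file extension. The sketch below is a hypothetical illustration, not Spandrel's loader: the LOADER_BY_SUFFIX mapping and the loader_kind function are made-up names, and the returned strings simply label which kind of loading each format above would need.

```python
from pathlib import Path

# Hypothetical mapping from file extension to the kind of loader needed.
# Only some .ckpt layouts would actually be loadable, per the text above.
LOADER_BY_SUFFIX = {
    ".pth": "pickle state dict",
    ".pt": "TorchScript",
    ".ckpt": "checkpoint",
    ".safetensors": "safetensors",
}


def loader_kind(path: str) -> str:
    """Pick a loader kind from the file extension (case-insensitive)."""
    suffix = Path(path).suffix.lower()
    try:
        return LOADER_BY_SUFFIX[suffix]
    except KeyError:
        raise ValueError(f"Unsupported model file type: {suffix or '(none)'}") from None


print(loader_kind("models/4x_model.PTH"))
```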

Security

As you may know, loading .pth files usually poses a security risk because Python's pickle module is unsafe and vulnerable to arbitrary code execution (ACE). Because of this, Spandrel uses a custom unpickler function that only allows loading certain types of data out of a .pth file. This is intended to prevent ACE and makes loading untrusted files safer. Note that ACE could still be possible (though we don't expect it to be), so if you're still concerned about security, only load .safetensors models.
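The general technique of a restricted unpickler, as described in the Python pickle documentation, is to override pickle.Unpickler.find_class so that only allowlisted globals can be reconstructed. The sketch below uses the standard library only; its SAFE_GLOBALS allowlist is hypothetical and is not Spandrel's actual list. A real .pth loader would also need to allow torch's tensor-rebuilding helpers, which this sketch omits.

```python
import builtins
import collections
import io
import pickle

# Hypothetical allowlist of (module, name) pairs that may be reconstructed.
SAFE_GLOBALS = {
    ("collections", "OrderedDict"): collections.OrderedDict,
    ("builtins", "set"): builtins.set,
}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # find_class is called for every global the pickle references;
        # refusing anything not allowlisted blocks payloads like os.system.
        try:
            return SAFE_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden") from None


def restricted_loads(data: bytes):
    """Unpickle data, allowing only the globals in SAFE_GLOBALS."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


# Plain containers need no globals; an OrderedDict is explicitly allowed.
print(restricted_loads(pickle.dumps(collections.OrderedDict(a=1))))
```

A payload such as the classic b"cos\nsystem\n(S'echo hello world'\ntR." is rejected with pickle.UnpicklingError before os.system is ever looked up.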

Contributing

Feel free to contribute support for more model architectures. When I add support for a model, I usually dig through the .pth file's state dict keys and weights to find a way to recover all of the model's parameters. At some point, I will document that entire process here. For now, there are plenty of examples to reference.
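As a rough illustration of that process, hyperparameters can often be read off weight shapes and counted from repeated key patterns. In this sketch the key names ("head.weight", "body.<i>.conv.weight") and the detect_params function are hypothetical; real architectures use their own keys. Shapes are given as tuples so it runs without torch; with real tensors you would read tensor.shape instead.

```python
import re


def detect_params(state_dict: dict[str, tuple[int, ...]]) -> dict[str, int]:
    """Infer hyperparameters from hypothetical key names and weight shapes."""
    # Assumed layout: a conv weight shaped (num_feat, in_ch, k, k).
    num_feat, in_ch = state_dict["head.weight"][:2]
    # Count distinct residual block indices among the matching keys.
    block_ids = {
        int(m.group(1))
        for key in state_dict
        if (m := re.fullmatch(r"body\.(\d+)\.conv\.weight", key))
    }
    return {"in_ch": in_ch, "num_feat": num_feat, "num_blocks": len(block_ids)}


sd = {
    "head.weight": (64, 3, 3, 3),
    "body.0.conv.weight": (64, 64, 3, 3),
    "body.1.conv.weight": (64, 64, 3, 3),
    "body.1.conv.bias": (64,),
    "tail.weight": (3, 64, 3, 3),
}
print(detect_params(sd))
```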

If the model arch you're adding does not have any parameter variants (for example, different scales or layer counts), it should be fine to add it without any parameter detection. At the very least, you will need to find something uniquely identifiable in your model (usually a unique, very long key) that you can then add to /spandrel/__helpers/main_registry.py in order to load your model (preferably at the bottom). You will also need to set up the __init__.py file for your arch to include a load method that returns a ModelDescriptor containing the model and some metadata about the model and its parameters.

As with the parameter detection, there are plenty of examples there. This might seem like a lot of hardcoding (and it is), but it's the only way to identify models from just the .pth file (or any other weight storage format), since these files contain only the weights of a model. If anybody can figure out a better way to do this, be my guest, but for now this is the best approach and it works well.
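The key-based detection described above can be pictured as a list of (detector, loader) entries checked in order. This is a minimal sketch under assumptions, not the contents of main_registry.py: ArchEntry, has_keys, register, REGISTRY and load_state_dict are hypothetical names, first-match-wins ordering is an assumption of the sketch, and the loaders return placeholders where Spandrel's load methods would return a ModelDescriptor.

```python
from dataclasses import dataclass
from typing import Any, Callable, Mapping

StateDict = Mapping[str, Any]


@dataclass
class ArchEntry:
    """Hypothetical registry entry: a detector plus a loader."""

    name: str
    detect: Callable[[StateDict], bool]
    load: Callable[[StateDict], Any]


def has_keys(*keys: str) -> Callable[[StateDict], bool]:
    # Detect an arch by its uniquely identifying keys: all must be present.
    return lambda sd: all(k in sd for k in keys)


REGISTRY: list[ArchEntry] = []


def register(entry: ArchEntry) -> None:
    REGISTRY.append(entry)


def load_state_dict(sd: StateDict) -> Any:
    # Entries are checked in registration order; the first match wins.
    for entry in REGISTRY:
        if entry.detect(sd):
            return entry.load(sd)
    raise ValueError("Unsupported model: no registered architecture matched")
```

Usage: register an entry per architecture, for example register(ArchEntry("ArchA", has_keys("some.unique.very.long.key"), load_arch_a)), where the key and load_arch_a are placeholders for your arch's own.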

License Notice

This repo is licensed under GPLv3. However, all the architectures used in this repository are bound by their own original licenses, which have been included in their respective places in this repo. The state dict parsing (load.py) files are not bound by these original licenses, as they are new code.

The original code has also been slightly modified and formatted to fit the needs of this repo. If you want to use these architectures in your own codebase (but why would you if you have this package 😉), I recommend grabbing them from their original sources.



Download files


Source Distributions

No source distribution files available for this release.

Built Distribution

spandrel-0.0.2-py3-none-any.whl (180.8 kB), uploaded for Python 3
