Parametrizations for PyTorch
pytorch-parametrizations
Spectral Parametrization
This module provides a PyTorch implementation of the spectral parametrization of the weights of 2D convolutional layers, as introduced by Johannes Ballé in "Efficient Nonlinear Transforms for Lossy Image Compression" (Picture Coding Symposium, PCS 2018).
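The following equations show one plausible reading of the parametrization. They assume a 2D discrete Fourier transform over each filter's spatial dimensions with orthonormal scaling; the package may normalize or store the coefficients differently. For each channel pair $(o, i)$ of a layer with kernel size $(K_1, K_2)$, the spatial filter $w_{o,i} \in \mathbb{R}^{K_1 \times K_2}$ is not trained directly. The trainable variable is its spectrum $\hat w_{o,i}$, and the kernel is recomputed from it by the inverse DFT ($j = \sqrt{-1}$):

```latex
$$
w_{o,i}[m,n] = \frac{1}{\sqrt{K_1 K_2}} \sum_{k=0}^{K_1-1} \sum_{l=0}^{K_2-1}
  \hat w_{o,i}[k,l]\, e^{\,2\pi j \left( \frac{km}{K_1} + \frac{ln}{K_2} \right)},
\qquad
\hat w_{o,i}[k,l] = \frac{1}{\sqrt{K_1 K_2}} \sum_{m=0}^{K_1-1} \sum_{n=0}^{K_2-1}
  w_{o,i}[m,n]\, e^{-2\pi j \left( \frac{km}{K_1} + \frac{ln}{K_2} \right)}
$$
```

Because $w_{o,i}$ is real, $\hat w_{o,i}[k,l] = \overline{\hat w_{o,i}[(K_1-k) \bmod K_1,\ (K_2-l) \bmod K_2]}$, so only the columns $l = 0, \dots, \lfloor K_2/2 \rfloor$ need to be stored. The map is linear and invertible. The set of representable kernels is therefore unchanged, and only the coordinates in which gradient steps are taken differ.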
Usage
import torch
import torch.nn.utils.parametrize as parametrize
from pytorch_parametrizations import SpectralParametrization
# Create a 2D convolutional layer
conv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
# Register the spectral parametrization for the weights of the layer
parametrize.register_parametrization(conv, 'weight', SpectralParametrization(conv.kernel_size), unsafe=True)
print(conv.parametrizations.weight)
# Output:
# ParametrizationList(
#   (0): SpectralParametrization()
# )
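The package's own `SpectralParametrization` is not shown here, so its internals are unknown. The following is a rough sketch of the idea using only stock PyTorch. The registered module stores each filter's real 2D DFT coefficients (from `torch.fft.rfft2`) and rebuilds the spatial kernel with `torch.fft.irfft2` whenever `conv.weight` is read. The class name `SpectralKernel`, the orthonormal (`norm="ortho"`) scaling, and stacking real and imaginary parts in a trailing dimension are all assumptions, not necessarily what the package does.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SpectralKernel(nn.Module):
    """Hypothetical stand-in, not the package's SpectralParametrization.

    The trainable tensor holds the real 2D DFT coefficients of every filter
    (real and imaginary parts stacked in a trailing dimension of size 2);
    the spatial kernel is rebuilt from them each time `weight` is accessed.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = tuple(kernel_size)

    def forward(self, coeffs):
        # (..., K1, K2 // 2 + 1, 2) real -> (..., K1, K2 // 2 + 1) complex
        spectrum = torch.view_as_complex(coeffs)
        # Inverse real FFT over the two spatial dims -> real (..., K1, K2) kernel
        return torch.fft.irfft2(spectrum, s=self.kernel_size, norm="ortho")

    def right_inverse(self, kernel):
        # Run once at registration: convert the existing spatial kernel into
        # spectral coefficients so the layer initially computes the same output
        return torch.view_as_real(torch.fft.rfft2(kernel, norm="ortho"))


conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
w_before = conv.weight.detach().clone()

# unsafe=True mirrors the usage above; it skips register_parametrization's
# registration-time consistency checks
parametrize.register_parametrization(
    conv, "weight", SpectralKernel(conv.kernel_size), unsafe=True
)

# The optimizer now updates the spectral coefficients; conv.weight is derived
print(conv.parametrizations.weight.original.shape)  # torch.Size([64, 3, 3, 2, 2])
print(torch.allclose(conv.weight, w_before, atol=1e-5))  # True
```

Each access to `conv.weight` runs the inverse FFT again. Wrapping a forward pass in `with parametrize.cached():` computes it only once inside that block.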
register_spectral_parametrization
This function registers the spectral parametrization for every Conv2d and ConvTranspose2d layer in a given module.
Usage
import torch
from pytorch_parametrizations.spectral.utils import register_spectral_parametrization
# Create a module
module = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1),
    torch.nn.Sigmoid()
)
# Register the spectral parametrization for every Conv2d and ConvTranspose2d layer in the module
register_spectral_parametrization(module)
print(module)
# Output:
# Sequential(
#   (0): ParametrizedConv2d(
#     3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0): SpectralParametrization()
#       )
#     )
#   )
#   (1): ReLU()
#   (2): ParametrizedConv2d(
#     64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0): SpectralParametrization()
#       )
#     )
#   )
#   (3): ReLU()
#   (4): ParametrizedConv2d(
#     64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0): SpectralParametrization()
#       )
#     )
#   )
#   (5): Sigmoid()
# )
# Unregister the spectral parametrization for every Conv2d and ConvTranspose2d layer in the module
register_spectral_parametrization(module, undo=True)
print(module)
# Output:
# Sequential(
#   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (1): ReLU()
#   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (3): ReLU()
#   (4): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (5): Sigmoid()
# )
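The register/undo round trip above can be approximated with stock PyTorch. In this sketch, `register_on_convs` and its `make_parametrization` factory argument are hypothetical names, not the package's API. `nn.Identity()` is only a placeholder parametrization for the demo. Undo is implemented with `parametrize.remove_parametrizations(..., leave_parametrized=True)`, which keeps each layer's current kernel as a plain weight. Whether the package's `undo=True` does the same is an assumption.

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


def register_on_convs(module, make_parametrization, undo=False):
    """Hypothetical helper: register (or, with undo=True, remove) a weight
    parametrization on every Conv2d and ConvTranspose2d inside `module`."""
    # Collect the layers first: registering swaps each layer's class and adds
    # a `parametrizations` submodule, so avoid mutating while traversing
    convs = [
        m for m in module.modules()
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    ]
    for conv in convs:
        if undo:
            if parametrize.is_parametrized(conv, "weight"):
                # Keep the current parametrized kernel as a plain weight again
                parametrize.remove_parametrizations(
                    conv, "weight", leave_parametrized=True
                )
        elif not parametrize.is_parametrized(conv, "weight"):
            # One fresh parametrization object per layer
            parametrize.register_parametrization(
                conv, "weight", make_parametrization(conv)
            )
    return module


net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(8, 3, kernel_size=3, padding=1),
)
# nn.Identity() is only a placeholder parametrization for this demo
register_on_convs(net, lambda conv: nn.Identity())
print(type(net[0]).__name__)  # ParametrizedConv2d
register_on_convs(net, lambda conv: nn.Identity(), undo=True)
print(type(net[0]).__name__)  # Conv2d
```

Taking a factory instead of a single parametrization instance gives every layer its own module. This matters when the parametrization depends on the layer, since a spectral one needs each layer's `kernel_size`.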
Download files
File details
Details for the file pytorch_parametrizations-1.0.0.tar.gz.
File metadata
- Download URL: pytorch_parametrizations-1.0.0.tar.gz
- Upload date:
- Size: 3.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3867cdbbfdf211d6f1102fe242b8c30184c6e0818e562a4928a98d0997cef88d
MD5 | 7364cf93611af8cf466ed135996b0f07
BLAKE2b-256 | 4c5910a55c2f907a265d230671db932936b9b7a95582ca74a5f162e118b32e5c
File details
Details for the file pytorch_parametrizations-1.0.0-py3-none-any.whl.
File metadata
- Download URL: pytorch_parametrizations-1.0.0-py3-none-any.whl
- Upload date:
- Size: 5.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0ebf38211f21c89efe98ce840bd63d95ddcbaa6428ea1dc07d139585ecdf615a
MD5 | 1e1dcfe06d3902879274b4106ffd598e
BLAKE2b-256 | 49a25e475bb794cfe200d72abab6d815f1134047448dd44b65715f4738aea221