Parametrizations for PyTorch

pytorch-parametrizations

Spectral Parametrization

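This module provides a PyTorch implementation of the spectral parametrization of the weights of 2D convolutional layers, as introduced in the paper "Efficient Nonlinear Transforms for Lossy Image Compression" by Johannes Ballé (PCS 2018). In the paper, each convolution kernel is parameterized by its discrete Fourier transform (DFT) coefficients instead of its spatial values, and the spatial kernel used by the convolution is recovered with an inverse DFT.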
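The change of basis behind a spectral parametrization can be sketched with `torch.fft` alone. This is an illustration that assumes the kernel is stored as its 2D real DFT (`torch.fft.rfft2`) over the two spatial dimensions; the package's internal layout may differ. The round trip shows that no information is lost: given the original spatial size `s`, the inverse transform `torch.fft.irfft2` reproduces the kernel.

```python
import torch

# A conv weight of shape (out_channels, in_channels, kh, kw), as in the usage example.
kernel = torch.randn(64, 3, 3, 3)

# Spatial kernel -> spectral coefficients: 2D real DFT over the last two (spatial) dims.
# The last dim shrinks to kw // 2 + 1 = 2 because the input is real-valued.
coeffs = torch.fft.rfft2(kernel, norm="ortho")

# Spectral coefficients -> spatial kernel; `s` restores the odd spatial size 3x3.
recovered = torch.fft.irfft2(coeffs, s=kernel.shape[-2:], norm="ortho")

print(coeffs.shape, recovered.shape)
# torch.Size([64, 3, 3, 2]) torch.Size([64, 3, 3, 3])
```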

Usage

import torch
import torch.nn.utils.parametrize as parametrize

from pytorch_parametrizations import SpectralParametrization

# Create a 2D convolutional layer
conv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)

# Register the spectral parametrization for the layer's weight.
# unsafe=True disables PyTorch's shape and dtype consistency checks for the parametrization.
parametrize.register_parametrization(
    conv, 'weight', SpectralParametrization(conv.kernel_size), unsafe=True
)

print(conv.parametrizations.weight)
# Output:
# ParametrizationList(
#   (0): SpectralParametrization()
# )
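To sketch what a registered parametrization does during training, the example below uses a hypothetical stand-in class, `RDFTKernel`, built from `torch.fft`. It is not this package's `SpectralParametrization`, whose internals may differ. The stand-in stores the kernel's real DFT coefficients with real and imaginary parts stacked in a trailing dimension, so the stored tensor's shape differs from the weight's; that shape change is the kind of case `unsafe=True` allows. The example shows that the optimizer updates `conv.parametrizations.weight.original`, and that `parametrize.cached()` reuses one computed kernel across several forward passes.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class RDFTKernel(nn.Module):
    """Hypothetical stand-in for a spectral parametrization (not the package's class)."""

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = tuple(kernel_size)

    def forward(self, coeffs):
        # Stored real tensor (..., kh, kw // 2 + 1, 2) -> complex spectrum -> spatial kernel.
        spectrum = torch.view_as_complex(coeffs)
        return torch.fft.irfft2(spectrum, s=self.kernel_size, norm="ortho")

    def right_inverse(self, kernel):
        # Spatial kernel -> real DFT, stored as a contiguous real-valued tensor.
        return torch.view_as_real(torch.fft.rfft2(kernel, norm="ortho")).clone()


conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
parametrize.register_parametrization(conv, "weight", RDFTKernel(conv.kernel_size), unsafe=True)

# The optimizer sees the stored spectral coefficients, not the spatial kernel.
optimizer = torch.optim.Adam(conv.parameters(), lr=1e-3)
x = torch.randn(4, 3, 32, 32)

loss = conv(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()  # updates conv.parametrizations.weight.original

# Inside cached(), conv.weight is computed once and reused by every forward pass.
with parametrize.cached():
    outputs = [conv(x) for _ in range(3)]
```

Keeping real and imaginary parts in a real tensor lets standard optimizers work unchanged. The stacked representation has more entries than the spatial kernel (2304 versus 1728 here), so it is redundant; a faithful implementation of the paper may handle that differently.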

register_spectral_parametrization

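This function registers the spectral parametrization on the weight of every Conv2d and ConvTranspose2d layer in a given module. Calling it with undo=True removes the parametrizations again, restoring plain Conv2d and ConvTranspose2d layers.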

Usage

import torch
from pytorch_parametrizations.spectral.utils import register_spectral_parametrization

# Create a module
module = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1),
    torch.nn.Sigmoid()
)

# Register the spectral parametrization for every Conv2d and ConvTranspose2d layer in the module
register_spectral_parametrization(module)

print(module)
# Output:
# Sequential(
#   (0): ParametrizedConv2d(
#     3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0): SpectralParametrization()
#       )
#     )
#   )
#   (1): ReLU()
#   (2): ParametrizedConv2d(
#     64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0): SpectralParametrization()
#       )
#     )
#   )
#   (3): ReLU()
#   (4): ParametrizedConv2d(
#     64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0): SpectralParametrization()
#       )
#     )
#   )
#   (5): Sigmoid()
# )

# Unregister the spectral parametrization for every Conv2d and ConvTranspose2d layer in the module
register_spectral_parametrization(module, undo=True)

print(module)
# Output:
# Sequential(
#   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (1): ReLU()
#   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (3): ReLU()
#   (4): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (5): Sigmoid()
# )
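A function like `register_spectral_parametrization` can be approximated with PyTorch's public `parametrize` utilities. The helper `apply_to_conv_weights` and the placeholder `IdentityStandIn` below are hypothetical and only illustrate the pattern; the factory argument could instead be the package's `SpectralParametrization`, which, per the usage example above, takes the layer's `kernel_size`. Whether the real function keeps the current effective kernel on undo (as `leave_parametrized=True` does here) is an assumption.

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


def apply_to_conv_weights(module, make_parametrization, undo=False):
    """Hypothetical helper: (un)register a parametrization on every conv weight."""
    # Collect the layers first, so registering (which swaps each layer's class and
    # adds a `parametrizations` submodule) does not disturb the traversal.
    convs = [m for m in module.modules() if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))]
    for conv in convs:
        if undo:
            if parametrize.is_parametrized(conv, "weight"):
                # Keep the current effective kernel as a plain `weight` Parameter.
                parametrize.remove_parametrizations(conv, "weight", leave_parametrized=True)
        elif not parametrize.is_parametrized(conv, "weight"):
            parametrize.register_parametrization(
                conv, "weight", make_parametrization(conv.kernel_size), unsafe=True
            )
    return module


class IdentityStandIn(nn.Module):
    """Hypothetical placeholder parametrization; kernel_size is accepted but unused."""

    def __init__(self, kernel_size):
        super().__init__()

    def forward(self, weight):
        return weight


net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(16, 3, kernel_size=3, padding=1),
)
apply_to_conv_weights(net, IdentityStandIn)
print(type(net[0]).__name__, type(net[2]).__name__)
# ParametrizedConv2d ParametrizedConvTranspose2d
apply_to_conv_weights(net, IdentityStandIn, undo=True)
print(type(net[0]).__name__, type(net[2]).__name__)
# Conv2d ConvTranspose2d
```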

Download files

Source distribution: pytorch_parametrizations-1.0.0.tar.gz (3.7 kB)

Built distribution: pytorch_parametrizations-1.0.0-py3-none-any.whl

Hashes for pytorch_parametrizations-1.0.0.tar.gz
Algorithm Hash digest
SHA256 3867cdbbfdf211d6f1102fe242b8c30184c6e0818e562a4928a98d0997cef88d
MD5 7364cf93611af8cf466ed135996b0f07
BLAKE2b-256 4c5910a55c2f907a265d230671db932936b9b7a95582ca74a5f162e118b32e5c

Hashes for pytorch_parametrizations-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 0ebf38211f21c89efe98ce840bd63d95ddcbaa6428ea1dc07d139585ecdf615a
MD5 1e1dcfe06d3902879274b4106ffd598e
BLAKE2b-256 49a25e475bb794cfe200d72abab6d815f1134047448dd44b65715f4738aea221
