
Normalizing flows in PyTorch



Zuko - Normalizing flows in PyTorch

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot send their internal tensors to the GPU with .to('cuda') or retrieve their parameters with .parameters(). Worse, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express.
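The limitation is easy to observe in plain PyTorch. The minimal check below (plain torch only, no zuko) wraps a trainable tensor in a standard Normal distribution and inspects the resulting object:

```python
import torch
from torch.distributions import Normal

# A trainable tensor wrapped in a standard torch Distribution
loc = torch.zeros(3, requires_grad=True)
dist = Normal(loc, torch.ones(3))

# Distribution is a plain Python object, not a torch.nn.Module, so it has
# neither .to(...) nor .parameters(); `loc` is invisible to code that
# collects trainable parameters from modules.
print(isinstance(dist, torch.nn.Module))  # False
print(hasattr(dist, "to"), hasattr(dist, "parameters"))  # False False
```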

To solve these problems, zuko defines two concepts: the LazyDistribution and the LazyTransform, which are modules whose forward pass returns a Distribution or a Transform, respectively. Because the creation of the actual distribution or transformation is delayed, an optional condition can easily be taken into account. This design enables lazy distributions, including normalizing flows, to act like distributions while retaining features inherent to modules, such as trainable parameters. It also makes the implementations easy to understand and extend.
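The "lazy" pattern can be sketched in plain PyTorch. The class below is hypothetical and is not zuko's own LazyDistribution; it only illustrates the idea of a module whose forward pass takes a context $c$ and returns a torch Distribution $p(x | c)$, so the trainable weights stay inside the module while the distribution is built on demand:

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal


class ConditionalDiagNormal(nn.Module):
    """Hypothetical lazy distribution: forward(c) returns p(x | c)."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # Trainable weights live in the module, so .parameters() and .to()
        # work as for any other nn.Module
        self.hyper = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Independent:
        # Predict the location and log-scale from the context
        loc, log_scale = self.hyper(c).chunk(2, dim=-1)
        # The Distribution object is only created here, once c is known
        return Independent(Normal(loc, log_scale.exp()), 1)


lazy = ConditionalDiagNormal(features=3, context=5)
c = torch.randn(5)
dist = lazy(c)             # a regular torch Distribution over 3 features
x = dist.sample((64,))     # shape (64, 3)
log_p = dist.log_prob(x)   # shape (64,)
```

Because the condition only enters through `forward`, the same module yields a different distribution for every context, which is the property the paragraph above describes.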

In the Avatar cartoon, Zuko is a powerful firebender 🔥

Acknowledgements

Zuko takes significant inspiration from nflows and Stefan Webb's work in Pyro and FlowTorch.

Installation

The zuko package is available on PyPI, which means it is installable via pip.

pip install zuko

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/probabilists/zuko

Getting started

Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Then, feeding a context $c$ to the flow returns a conditional distribution $p(x | c)$, which can be evaluated and sampled from.

import torch
import zuko

# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)

# Train to maximize the log-likelihood
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

for x, c in trainset:  # batches of x with shape (B, 3) and c with shape (B, 5)
    loss = -flow(c).log_prob(x)  # -log p(x | c)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample 64 points x ~ p(x | c*) for a single context c_star of shape (5,)
x = flow(c_star).sample((64,))
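The snippet above assumes a `trainset` that yields `(x, c)` batches and a context `c_star`, neither of which is defined. For a quick smoke test, one hypothetical way to provide them with synthetic data is a `TensorDataset` wrapped in a `DataLoader` (the data-generating rule below is an arbitrary assumption):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical synthetic data: 1024 pairs with 3 sample and 5 context features
c_all = torch.randn(1024, 5)
x_all = c_all[:, :3] + 0.1 * torch.randn(1024, 3)  # x depends on c

# Iterating over the loader yields (x, c) batches, as the training loop expects
trainset = DataLoader(TensorDataset(x_all, c_all), batch_size=64, shuffle=True)

# A single context vector to condition on at sampling time
c_star = torch.randn(5)
```

With real data, only the construction of `x_all` and `c_all` would change; the loop itself stays the same.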

Alternatively, flows can be built as custom Flow objects.

from zuko.flows import Flow, UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform

flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    base=UnconditionalDistribution(
        DiagNormal,
        loc=torch.zeros(3),
        scale=torch.ones(3),
        buffer=True,
    ),
)
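In the usual formulation of a normalizing flow, the transforms map a sample $x$ to a latent $z = T(x)$ distributed under the base, and the density follows from the change-of-variables formula $\log p(x) = \log p_\text{base}(T(x)) + \sum_i \log |\det J_{T_i}|$, where the log-determinants of the stacked transforms add up. The minimal 1-D sketch below uses two hypothetical affine transforms in plain Python (not zuko's API); since their composition is itself affine, the result can be checked against a closed-form Gaussian:

```python
import math


def std_normal_log_pdf(z: float) -> float:
    # log N(z; 0, 1), the base distribution
    return -0.5 * (z * z + math.log(2 * math.pi))


# Hypothetical affine transforms t(x) = a * x + b, in the order applied to x
transforms = [(2.0, 1.0), (1.5, -3.0)]


def flow_log_prob(x: float) -> float:
    z, log_det = x, 0.0
    for a, b in transforms:
        z = a * z + b
        log_det += math.log(abs(a))  # log |dt/dx| of each step accumulates
    return std_normal_log_pdf(z) + log_det


def gaussian_log_pdf(x: float, mu: float, sigma: float) -> float:
    # log N(x; mu, sigma^2)
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


# The composition is z = 3x - 1.5, so x ~ N(0.5, (1/3)^2)
for x in (-1.0, 0.0, 0.5, 2.0):
    print(x, flow_log_prob(x), gaussian_log_pdf(x, 0.5, 1 / 3))
```

For non-affine transforms such as splines or autoregressive networks there is no closed form to compare against, but the bookkeeping is the same: evaluate each transform and accumulate its log-determinant.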

For more information, check out the documentation and tutorials at zuko.readthedocs.io.

Available flows

| Class | Year | Reference |
| --- | --- | --- |
| GMM | - | Gaussian Mixture Model |
| NICE | 2014 | Non-linear Independent Components Estimation |
| RealNVP | 2016 | Density estimation using Real NVP |
| MAF | 2017 | Masked Autoregressive Flow for Density Estimation |
| NSF | 2019 | Neural Spline Flows |
| NCSF | 2020 | Normalizing Flows on Tori and Spheres |
| SOSPF | 2019 | Sum-of-Squares Polynomial Flow |
| NAF | 2018 | Neural Autoregressive Flows |
| UNAF | 2019 | Unconstrained Monotonic Neural Networks |
| CNF | 2018 | Neural Ordinary Differential Equations |
| GF | 2020 | Gaussianization Flows |
| BPF | 2020 | Bernstein-Polynomial Normalizing Flows |

Contributing

If you have a question, an issue or would like to contribute, please read our contributing guidelines.



Download files


Source Distribution

zuko-1.6.0.tar.gz (45.2 kB)

Built Distribution


zuko-1.6.0-py3-none-any.whl (48.0 kB)

File details

Details for the file zuko-1.6.0.tar.gz.

File metadata

  • Download URL: zuko-1.6.0.tar.gz
  • Size: 45.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.22

File hashes

Hashes for zuko-1.6.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | edc516e51bbbf9d64e7663b617cf9293c6e1e6bbfcb39559bc383383e6663b04 |
| MD5 | 2cb3ae9a75ad1ae21c292b0fb4050b08 |
| BLAKE2b-256 | 8c41ddbe72cb64996d7826ba427c675252be0c38fbd9fbf8920d0fcdaf5d8e38 |


File details

Details for the file zuko-1.6.0-py3-none-any.whl.

File metadata

  • Download URL: zuko-1.6.0-py3-none-any.whl
  • Size: 48.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.22

File hashes

Hashes for zuko-1.6.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 5c073b613a84a7cd65470ddb94855169020ac49432f73b85f25207e377248a4a |
| MD5 | ec9cffdc6bb0cc11dc1b4d1d982079a3 |
| BLAKE2b-256 | 9510ff159867f522cd98e039e748c5e9777446e8b797a79be691c6b730676094 |
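A downloaded file can be checked against the published SHA256 digests with Python's standard `hashlib`. The sketch below streams the file in chunks so large archives need not fit in memory; the `EXPECTED` mapping simply copies the SHA256 values listed for the two release files, and the helper names are hypothetical:

```python
import hashlib
from pathlib import Path

# SHA256 digests published for the zuko 1.6.0 release files
EXPECTED = {
    "zuko-1.6.0.tar.gz": "edc516e51bbbf9d64e7663b617cf9293c6e1e6bbfcb39559bc383383e6663b04",
    "zuko-1.6.0-py3-none-any.whl": "5c073b613a84a7cd65470ddb94855169020ac49432f73b85f25207e377248a4a",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path) -> bool:
    """True if the file's SHA256 matches the published digest for its name."""
    return sha256_of(path) == EXPECTED[path.name]
```

pip can also enforce this on its own through its hash-checking mode, by pinning `zuko==1.6.0` with a `--hash=sha256:...` entry in a requirements file.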

