Zuko - Normalizing flows in PyTorch
Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot send their internal tensors to GPU with .to('cuda') or retrieve their parameters with .parameters(). Worse, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express.
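To see the limitation concretely, the short check below uses plain PyTorch (no zuko). It builds a Normal distribution and an AffineTransform with arbitrary parameter values and confirms that neither is an nn.Module.

```python
import torch
from torch.distributions import Normal
from torch.distributions.transforms import AffineTransform

# Arbitrary parameter values, for illustration only
dist = Normal(torch.zeros(3), torch.ones(3))
affine = AffineTransform(loc=torch.zeros(3), scale=2 * torch.ones(3))

# Neither class derives from torch.nn.Module ...
print(isinstance(dist, torch.nn.Module))    # False
print(isinstance(affine, torch.nn.Module))  # False

# ... so a distribution offers no .parameters() for an optimizer and no .to() for devices.
print(hasattr(dist, "parameters"), hasattr(dist, "to"))  # False False
```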
To solve these problems, zuko defines two concepts: the LazyDistribution and the LazyTransform, which are any modules whose forward pass returns a Distribution or a Transform, respectively. Because the creation of the actual distribution or transformation is delayed, an optional condition can easily be taken into account. This design enables lazy distributions, including normalizing flows, to act like distributions while retaining features inherent to modules, such as trainable parameters. It also makes the implementations easy to understand and extend.
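The idea can be sketched with plain PyTorch. The classes LazyAffine and LazyFlow below are hypothetical illustrations of the pattern, not zuko's own implementation: each is an nn.Module whose forward pass takes a condition c and only then builds a torch Transform or Distribution. The single conditional affine layer and the standard normal base are arbitrary choices.

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, Transform


class LazyAffine(nn.Module):
    """Hypothetical lazy transform: forward(c) returns a Transform that depends on c."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # Trainable weights are registered on the module, so .parameters() and .to() work.
        self.net = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Transform:
        shift, log_scale = self.net(c).chunk(2, dim=-1)
        # event_dim=1 treats the last dimension as one multivariate event.
        return AffineTransform(shift, log_scale.exp(), event_dim=1)


class LazyFlow(nn.Module):
    """Hypothetical lazy distribution: forward(c) returns p(x | c)."""

    def __init__(self, features: int, context: int):
        super().__init__()
        self.transform = LazyAffine(features, context)
        # Base tensors are buffers: they follow .to() but are not trained.
        self.register_buffer("loc", torch.zeros(features))
        self.register_buffer("scale", torch.ones(features))

    def forward(self, c: torch.Tensor) -> Distribution:
        # The actual distribution is only built once the condition c is known.
        base = Independent(Normal(self.loc, self.scale), 1)
        return TransformedDistribution(base, [self.transform(c)])


flow = LazyFlow(3, 5)
dist = flow(torch.randn(5))
x = dist.sample((64,))
print(x.shape)                       # torch.Size([64, 3])
print(dist.log_prob(x).shape)        # torch.Size([64])
print(len(list(flow.parameters())))  # 2 (weight and bias of the linear layer)
```

Because the module owns the tensors and the distribution object is rebuilt on every call, moving flow to a device or handing flow.parameters() to an optimizer works as for any other module.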
Acknowledgements
Zuko takes significant inspiration from nflows and Stefan Webb's work in Pyro and FlowTorch.
Installation
The zuko package is available on PyPI, which means it is installable via pip.
pip install zuko
Alternatively, if you need the latest features, you can install it from the repository.
pip install git+https://github.com/probabilists/zuko
Getting started
Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Then, feeding a context $c$ to the flow returns a conditional distribution $p(x | c)$, which can be evaluated and sampled from.
import torch
import zuko
# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)
# Placeholder data for illustration: batches of (x, c) pairs with 3 and 5 features
trainset = [(torch.randn(64, 3), torch.randn(64, 5)) for _ in range(100)]
# Train to maximize the log-likelihood
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
for x, c in trainset:
    loss = -flow(c).log_prob(x)  # -log p(x | c)
    loss = loss.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Sample 64 points x ~ p(x | c*) for an arbitrary context c*
c_star = torch.randn(5)
x = flow(c_star).sample((64,))
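Evaluating $\log p(x | c)$ relies on the change-of-variables formula: if $f$ is an invertible map from a sample $x$ to the base space with density $p_Z$, then $\log p(x) = \log p_Z(f(x)) + \log |\det J_f(x)|$. The torch-only sketch below (no zuko, no context) checks this identity for the arbitrary elementwise map $x = 1 + 2z$ against torch's TransformedDistribution.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

# Standard normal base distribution p_Z over 3 features
base = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)

# Generative map g: z -> x = 1 + 2z (arbitrary illustrative choice);
# its inverse f = g^{-1}: x -> z is what density evaluation uses.
g = AffineTransform(loc=1.0, scale=2.0, event_dim=1)
p_x = TransformedDistribution(base, [g])

x = torch.randn(10, 3)
z = g.inv(x)  # z = f(x) = (x - 1) / 2

# log p(x) = log p_Z(f(x)) + log|det J_f(x)|, with log|det J_f(x)| = -log|det J_g(z)|
manual = base.log_prob(z) - g.log_abs_det_jacobian(z, x)
print(torch.allclose(manual, p_x.log_prob(x)))  # True
```

A trained flow replaces the fixed affine map with learned, context-dependent transformations, but the density is assembled from the same two terms.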
Alternatively, flows can be built as custom Flow objects.
import torch
from zuko.flows import Flow, UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform
# Two conditional autoregressive transforms (3 features, 5 context features) around a rotation
flow = Flow(
    transform=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    # Diagonal standard normal base; buffer=True stores its tensors as non-trainable buffers
    base=UnconditionalDistribution(
        DiagNormal,
        torch.zeros(3),
        torch.ones(3),
        buffer=True,
    ),
)
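The transformations listed in transform act as one chained transformation, and a useful property of such chains is that their log-Jacobian determinants add up. The torch-only sketch below illustrates this with torch.distributions.transforms.ComposeTransform rather than zuko's own classes; the affine and tanh parts are arbitrary stand-ins for the autoregressive and rotation layers above.

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, TanhTransform

# Two elementwise parts standing in for the layers of a flow
t1 = AffineTransform(loc=0.5, scale=3.0)
t2 = TanhTransform()
chain = ComposeTransform([t1, t2])  # applies t1, then t2

x = torch.randn(8)
h = t1(x)  # intermediate value
y = t2(h)

# The chained map equals applying the parts one after another ...
print(torch.allclose(chain(x), y))  # True
# ... and its log|det J| is the sum of the parts' log|det J|.
total = t1.log_abs_det_jacobian(x, h) + t2.log_abs_det_jacobian(h, y)
print(torch.allclose(chain.log_abs_det_jacobian(x, y), total))  # True
```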
For more information, check out the documentation and tutorials at zuko.readthedocs.io.
Available flows
| Class | Year | Reference |
|---|---|---|
| GMM | - | Gaussian Mixture Model |
| NICE | 2014 | Non-linear Independent Components Estimation |
| MAF | 2017 | Masked Autoregressive Flow for Density Estimation |
| NSF | 2019 | Neural Spline Flows |
| NCSF | 2020 | Normalizing Flows on Tori and Spheres |
| SOSPF | 2019 | Sum-of-Squares Polynomial Flow |
| NAF | 2018 | Neural Autoregressive Flows |
| UNAF | 2019 | Unconstrained Monotonic Neural Networks |
| CNF | 2018 | Neural Ordinary Differential Equations |
| GF | 2020 | Gaussianization Flows |
| BPF | 2020 | Bernstein-Polynomial Normalizing Flows |
Contributing
If you have a question, an issue or would like to contribute, please read our contributing guidelines.
File details
Details for the file zuko-1.3.0.tar.gz.
File metadata
- Download URL: zuko-1.3.0.tar.gz
- Upload date:
- Size: 38.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7797b6c069b23a420c839fb274552f1dba4847e9a38292b63ce13cee4dac871a |
| MD5 | db4b30f100aaeb66f0ed90806e63adee |
| BLAKE2b-256 | aa9e333bd6ce329e56aecd2a04a6ce42a3221c40bc849909fdcf70d10f2cc1a0 |
File details
Details for the file zuko-1.3.0-py3-none-any.whl.
File metadata
- Download URL: zuko-1.3.0-py3-none-any.whl
- Upload date:
- Size: 42.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2ecdc5a8b19ca47e7671e0c86101793c3aeb899bf02467c36db85dd3ba6402e3 |
| MD5 | f47788781cc3978dfe5e072169e3fdd6 |
| BLAKE2b-256 | 68d8c72ca94e2dc33652ccedb819095e335b115c26092c59071bf03aaca7d8a3 |