PyTorch implementation of normalizing flows
Normalizing Flows
This is a PyTorch implementation of several normalizing flows, as well as a variational autoencoder. It is used in the articles *A Gradient Based Strategy for Hamiltonian Monte Carlo Hyperparameter Optimization* and *Resampling Base Distributions of Normalizing Flows*.
Implemented Flows
- Planar Flow (Rezende & Mohamed, 2015)
- Radial Flow (Rezende & Mohamed, 2015)
- NICE (Dinh et al., 2014)
- Real NVP (Dinh et al., 2016)
- Glow (Kingma & Dhariwal, 2018)
- Masked Autoregressive Flow (Papamakarios et al., 2017)
- Neural Spline Flow (Durkan et al., 2019)
- Circular Neural Spline Flow (Rezende et al., 2020)
- Residual Flow (Chen et al., 2019)
- Stochastic Normalizing Flows (Wu et al., 2020)
Note that Neural Spline Flows with circular and non-circular coordinates are also supported.
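As a preview of the interface described in the Usage section below, the simpler flows from this list can be stacked into a model in just a few lines. This is a minimal sketch; it assumes that the constructors of nf.flows.Planar and nf.flows.Radial take the event shape as their first argument, so check the API documentation of the version you installed:
import normflows as nf

# 2D Gaussian base distribution
base = nf.distributions.base.DiagGaussian(2)

# Alternate planar and radial layers (constructor arguments assumed, see above)
flows = []
for _ in range(8):
    flows.append(nf.flows.Planar((2,)))
    flows.append(nf.flows.Radial((2,)))

model = nf.NormalizingFlow(base, flows)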
Installation
The latest version of the package can be installed via pip
pip install normflows
At least Python 3.7 is required. If you want to use a GPU, make sure that PyTorch is set up correctly by following the instructions at the PyTorch website.
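To check that the installation worked and whether PyTorch can see a GPU, a quick sanity check like the following can be run; it only assumes that the package imports under the name normflows and uses PyTorch's standard CUDA query:
import torch
import normflows as nf  # should succeed after pip install normflows

# Standard PyTorch check; prints True only if a CUDA build and a GPU are available
print(torch.cuda.is_available())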
To run the example notebooks, clone the repository first
git clone https://github.com/VincentStimper/normalizing-flows.git
and then install the dependencies.
pip install -r requirements_examples.txt
Usage
A normalizing flow consists of a base distribution, defined in nf.distributions.base, and a list of flows, given in nf.flows.
Let's assume our target is a 2D distribution. We pick a diagonal Gaussian
base distribution, which is the most popular choice. Our flow shall be a
Real NVP model, so we need to define a neural network for computing the
parameters of the affine coupling map. One dimension is used to compute the
scale and shift parameters for the other dimension; after each coupling layer,
we swap their roles.
import normflows as nf
# Define 2D base distribution
base = nf.distributions.base.DiagGaussian(2)
# Define list of flows
num_layers = 16
flows = []
for i in range(num_layers):
    # Neural network with two hidden layers having 32 units each
    # Last layer is initialized by zeros making training more stable
    param_map = nf.nets.MLP([1, 32, 32, 2], init_zeros=True)
    # Add flow layer
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap dimensions
    flows.append(nf.flows.Permute(2, mode='swap'))
Once they are set up, we can define a nf.NormalizingFlow model. If the target density is available, it can be added to the model to be used during training. Sample target distributions are given in nf.distributions.target.
# If the target density is not given
model = nf.NormalizingFlow(base, flows)
# If the target density is given
target = nf.distributions.target.TwoMoons()
model = nf.NormalizingFlow(base, flows, target)
The loss can be computed with the model's methods and then minimized with a standard PyTorch optimizer.
# When doing maximum likelihood learning, i.e. minimizing the forward KLD
# with no target distribution given; x is a batch of samples from the target
loss = model.forward_kld(x)
# When minimizing the reverse KLD based on the given target distribution
loss = model.reverse_kld(num_samples=1024)
# Optimization as usual
loss.backward()
optimizer.step()
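Putting these pieces together, a minimal training loop might look as follows; model is the NormalizingFlow defined above, and the optimizer, learning rate, batch size, and number of iterations are illustrative choices rather than package defaults:
import torch

# Train by minimizing the reverse KLD w.r.t. the given target distribution
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for it in range(1000):
    optimizer.zero_grad()
    loss = model.reverse_kld(num_samples=512)
    loss.backward()
    optimizer.step()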
For more illustrative examples of how to use the package, see the example directory. More advanced experiments can be run with the scripts listed in the repository on resampled base distributions; see its experiments folder.
File details
Details for the file normflows-1.4.tar.gz.
File metadata
- Download URL: normflows-1.4.tar.gz
- Upload date:
- Size: 50.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.7.11
File hashes
Algorithm | Hash digest
---|---
SHA256 | fba1f63a515cfa35b42edb9f06354240b63c82bfc5502aab1ce0cbf5094b94d3
MD5 | 4458bdafecaa4434d27d09407f319541
BLAKE2b-256 | 941c1966ac14ec1363e757fda7a6da75f10b92040b99440b24c6c0f5c4ab8af6