
PyTorch implementation of normalizing flows

Project description

Normalizing Flows

Code style: Black | License: MIT

This is a PyTorch implementation of several normalizing flows, as well as a variational autoencoder. It is used in the articles "A Gradient Based Strategy for Hamiltonian Monte Carlo Hyperparameter Optimization" and "Resampling Base Distributions of Normalizing Flows".

Implemented Flows

Note that Neural Spline Flows with circular and non-circular coordinates are also supported.

Methods of Installation

The latest version of the package can be installed via pip:

pip install normflows

At least Python 3.7 is required. If you want to use a GPU, make sure that PyTorch is set up correctly by following the instructions at the PyTorch website.
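Both requirements can be checked from Python before running any flows. The sketch below is a minimal illustration; the helper `check_environment` is hypothetical and relies only on `sys.version_info` and `torch.cuda.is_available()`, reporting `None` for the GPU status if PyTorch is not installed yet.

```python
import sys

def check_environment():
    """Hypothetical helper: check the Python version and whether PyTorch sees a GPU.

    Returns (python_ok, gpu_available); gpu_available is None if PyTorch
    is not installed.
    """
    python_ok = sys.version_info >= (3, 7)
    try:
        import torch  # imported lazily so the check also runs without PyTorch
    except ImportError:
        return python_ok, None
    return python_ok, torch.cuda.is_available()

python_ok, gpu_available = check_environment()
print("Python >= 3.7:", python_ok)
print("CUDA GPU available:", gpu_available)
```

If the GPU status is `False` on a machine with a CUDA device, the installed PyTorch build is likely CPU-only, which is what the PyTorch installation instructions help resolve.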

To run the example notebooks, clone the repository first

git clone https://github.com/VincentStimper/normalizing-flows.git

and then install the dependencies.

pip install -r requirements_examples.txt

Usage

A normalizing flow consists of a base distribution, defined in nf.distributions.base, and a list of flows, given in nf.flows. Let's assume our target is a 2D distribution. We pick a diagonal Gaussian base distribution, which is the most popular choice. Our flow shall be a Real NVP model and, therefore, we need to define a neural network for computing the parameters of the affine coupling map. One dimension is used to compute the scale and shift parameters for the other dimension. After each coupling layer, we swap their roles.

import normflows as nf

# Define 2D base distribution
base = nf.distributions.base.DiagGaussian(2)

# Define list of flows
num_layers = 16
flows = []
for i in range(num_layers):
    # Neural network with two hidden layers having 32 units each
    # Last layer is initialized with zeros, which makes training more stable
    param_map = nf.nets.MLP([1, 32, 32, 2], init_zeros=True)
    # Add flow layer
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap dimensions
    flows.append(nf.flows.Permute(2, mode='swap'))
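To make the coupling step concrete, the pure-Python sketch below applies a Real NVP-style affine coupling to a single 2D point and checks that it is invertible. It is a generic illustration of the technique, not the normflows source: `coupling_forward`, `coupling_inverse`, `swap` and the toy parameter maps are hypothetical, and the sketch assumes the scale is parametrized as exp(s), which may differ from what nf.flows.AffineCouplingBlock does internally.

```python
import math

def coupling_forward(z, param_map):
    """Affine coupling on a 2D point: keep z[0], transform z[1] using z[0].

    Returns the transformed point and log|det J| of the map.
    """
    z1, z2 = z
    s, t = param_map(z1)       # log-scale and shift depend on z1 only
    y2 = z2 * math.exp(s) + t  # affine map of the second coordinate
    return (z1, y2), s         # triangular Jacobian, so log|det J| = s

def coupling_inverse(y, param_map):
    """Exact inverse: y[0] equals z[0], so s and t can be recomputed from it."""
    y1, y2 = y
    s, t = param_map(y1)
    z2 = (y2 - t) * math.exp(-s)
    return (y1, z2), -s

def swap(z):
    """Swap the coordinates; a permutation has log|det J| = 0."""
    return (z[1], z[0])

# Hypothetical stand-in for the MLP parameter map
def toy_param_map(z1):
    return 0.5 * math.tanh(z1), 0.3 * z1

# With a zero-initialized last layer the map outputs s = t = 0 at the start
def zero_param_map(z1):
    return 0.0, 0.0

z = (0.7, -1.2)
y, log_det = coupling_forward(z, toy_param_map)
z_rec, inv_log_det = coupling_inverse(y, toy_param_map)
print(z_rec)  # equals z up to floating-point error

# Chain coupling -> swap -> coupling, mirroring the flow list above
h, ld1 = coupling_forward(z, toy_param_map)
h, ld2 = coupling_forward(swap(h), toy_param_map)
print(h, ld1 + ld2)

# Zero-initialized coupling starts out as the identity map
print(coupling_forward(z, zero_param_map))  # ((0.7, -1.2), 0.0)
```

Under this exp parametrization, a zero-initialized last layer makes every coupling start as the identity map, which is one plausible reading of why it makes training more stable.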

Once the base distribution and the flows are set up, we can define a nf.NormalizingFlow model. If the target density is available, it can be added to the model to be used during training. Example target distributions are given in nf.distributions.target.

# If the target density is not given
model = nf.NormalizingFlow(base, flows)

# If the target density is given
target = nf.distributions.target.TwoMoons()
model = nf.NormalizingFlow(base, flows, target)
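Conceptually, a flow model ties the base distribution and the flow list together through the change-of-variables formula: sampling draws z from the base and pushes it through the flows, while density evaluation runs the flows in reverse, so that log q(x) = log p_base(z) + the sum of the inverse layers' log|det J|. The pure-Python sketch below illustrates this for 2D points. The class `TinyFlow` and its layer format (a forward/inverse pair, each returning the new point and its log-determinant) are hypothetical and not the normflows API, and the base is assumed to be a fixed standard normal, whereas the normflows DiagGaussian base may carry its own trainable mean and scale.

```python
import math
import random

def std_normal_log_prob(z):
    """Log density of a standard normal with diagonal unit covariance."""
    return -0.5 * sum(v * v for v in z) - len(z) * 0.5 * math.log(2 * math.pi)

class TinyFlow:
    """Hypothetical stand-in for a flow model: a base density plus invertible layers.

    Each layer is a pair (forward, inverse); both return (point, log|det J|).
    """

    def __init__(self, layers):
        self.layers = layers

    def sample(self, rng):
        # Draw from the (assumed) standard normal base, then apply every layer
        z = (rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
        log_q = std_normal_log_prob(z)
        for forward, _ in self.layers:
            z, log_det = forward(z)
            log_q -= log_det  # density shrinks where the map expands volume
        return z, log_q

    def log_prob(self, x):
        # Invert the layers in reverse order and accumulate inverse log-dets
        log_det_sum = 0.0
        for _, inverse in reversed(self.layers):
            x, log_det = inverse(x)
            log_det_sum += log_det
        return std_normal_log_prob(x) + log_det_sum

# Example layers: scale the first coordinate by 2, then swap coordinates
scale = (lambda z: ((2.0 * z[0], z[1]), math.log(2.0)),
         lambda y: ((y[0] / 2.0, y[1]), -math.log(2.0)))
swap = (lambda z: ((z[1], z[0]), 0.0),
        lambda y: ((y[1], y[0]), 0.0))

flow = TinyFlow([scale, swap])
x, log_q = flow.sample(random.Random(0))
print(x, log_q, flow.log_prob(x))  # the two log densities agree
```

Both directions must agree on the density of the same point, which is why every layer in a flow has to provide a tractable inverse and log-determinant.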

The loss can be computed with the model's methods and then minimized with a standard PyTorch optimizer.

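import torch

# Optimizer over all model parameters (Adam is one common choice)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
optimizer.zero_grad()

# When doing maximum likelihood learning, i.e. minimizing the forward KLD
# with no target distribution given; x is a batch of training samples,
# e.g. a tensor of shape (batch_size, 2)
loss = model.forward_kld(x)

# When minimizing the reverse KLD based on the given target distribution
loss = model.reverse_kld(num_samples=1024)

# Optimization as usual
loss.backward()
optimizer.step()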
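The two objectives differ in where the expectation is taken. Up to a constant that does not depend on the model, the forward KLD is the negative log-likelihood of target samples under the flow, -E_p[log q(x)], while the reverse KLD, E_q[log q(x) - log p(x)], is estimated from the flow's own samples and only needs the target's log density (an unnormalized one shifts it by a constant). The generic Monte Carlo sketch below illustrates both in pure Python, assuming a hypothetical 1D "flow" x = mu + sigma * z with z ~ N(0, 1) and a hypothetical target N(2, 1); its functions are not the normflows methods of the same names.

```python
import math
import random

def normal_log_prob(x, mu, sigma):
    """Log density of a 1D Gaussian N(mu, sigma^2)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

# Hypothetical 1D flow: x = mu + sigma * z with z ~ N(0, 1), so q = N(mu, sigma^2)
def flow_sample(mu, sigma, n, rng):
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]

# Hypothetical target p = N(2, 1)
def target_log_prob(x):
    return normal_log_prob(x, 2.0, 1.0)

def forward_kld(data, mu, sigma):
    """Hypothetical MC estimate of -E_p[log q(x)] from target samples (the NLL)."""
    return -sum(normal_log_prob(x, mu, sigma) for x in data) / len(data)

def reverse_kld(mu, sigma, n, rng):
    """Hypothetical MC estimate of E_q[log q(x) - log p(x)] from flow samples."""
    xs = flow_sample(mu, sigma, n, rng)
    return sum(normal_log_prob(x, mu, sigma) - target_log_prob(x) for x in xs) / n

rng = random.Random(0)
data = [2.0 + rng.gauss(0.0, 1.0) for _ in range(1024)]  # stand-in for training samples x

print(forward_kld(data, 2.0, 1.0), forward_kld(data, 0.0, 1.0))  # matched model scores lower
print(reverse_kld(2.0, 1.0, 1024, rng))  # exactly 0 when q equals p
print(reverse_kld(0.0, 1.0, 1024, rng))  # positive when q differs from p
```

The forward objective needs samples from the target but no target density, while the reverse objective needs the target density but no target samples, which matches the two usage cases above.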

For more illustrative examples of how to use the package, see the example directory. More advanced experiments can be run with the scripts in the repository on resampled base distributions; see its experiments folder.



Download files


Source Distribution

normflows-1.4.tar.gz (50.0 kB)

File details

Details for the file normflows-1.4.tar.gz.

File metadata

  • Download URL: normflows-1.4.tar.gz
  • Size: 50.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.7.11

File hashes

Hashes for normflows-1.4.tar.gz

  • SHA256: fba1f63a515cfa35b42edb9f06354240b63c82bfc5502aab1ce0cbf5094b94d3
  • MD5: 4458bdafecaa4434d27d09407f319541
  • BLAKE2b-256: 941c1966ac14ec1363e757fda7a6da75f10b92040b99440b24c6c0f5c4ab8af6
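A downloaded archive can be checked against the SHA256 digest listed above using Python's standard hashlib module. This is a minimal sketch; the helper names `sha256_of_file` and `verify_download` are hypothetical, and the file is read in chunks so large archives need not fit in memory.

```python
import hashlib

# SHA256 digest listed above for normflows-1.4.tar.gz
EXPECTED_SHA256 = "fba1f63a515cfa35b42edb9f06354240b63c82bfc5502aab1ce0cbf5094b94d3"

def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_download(path, expected=EXPECTED_SHA256):
    """Hypothetical helper: True if the file's SHA256 matches the expected digest."""
    return sha256_of_file(path) == expected.lower()
```

For example, `verify_download("normflows-1.4.tar.gz")`, run in the directory holding the archive, should return True for an intact copy.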

