Normalizing Flows in JAX

Implementations of normalizing flows (RealNVP, Glow, MAF) in the JAX deep learning framework.

What are normalizing flows?

Normalizing flow models are generative models, i.e. they infer the underlying probability distribution of an observed dataset. With that distribution we can do a number of interesting things, namely sample new realistic points and query probability densities.

Why JAX?

A few reasons!

  1. JAX encourages a functional style. When writing a layer, I didn't want people to worry about PyTorch or TensorFlow boilerplate and how their code has to fit into "the system" (e.g., do I have to keep track of self.training here?). All you have to worry about is writing a vanilla Python function which, given an ndarray, returns the correct set of outputs. You could develop your own layers with effectively no knowledge of the encompassing framework.

  2. JAX's random number generation system places reproducibility first. To get a sense for why this matters: once you start to parallelize a system, centralized state-based PRNG schemes à la torch.manual_seed() or tf.random.set_seed() start to yield inconsistent results. Given that randomness is such a central component of work in this area, I thought that uncompromising reproducibility would be a nice feature.

  3. JAX has a really flexible automatic differentiation system. So flexible, in fact, that you can (basically) write arbitrary Python functions (including for loops, if statements, etc.) and automatically compute their Jacobian with a call to jax.jacfwd. So, in theory, you could write a normalizing flow layer and automatically compute its Jacobian's log determinant without having to do so manually (although we're not quite there yet). A rough sketch of the idea follows below.
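
To illustrate that last point, here's what computing a transformation's Jacobian log determinant with jax.jacfwd can look like. The function shift_and_scale is a made-up example, not part of the library:

import jax.numpy as jnp
from jax import jacfwd

def shift_and_scale(x):
    # An arbitrary invertible transformation written as plain Python.
    return 2.0 * x + 1.0

x = jnp.array([0.5, -1.3, 2.0])
jacobian = jacfwd(shift_and_scale)(x)       # full Jacobian matrix of the transformation
_, log_det = jnp.linalg.slogdet(jacobian)   # its log determinant, computed automatically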

How do things work?

Here's an introduction! But for a more comprehensive description, check out the documentation.

Bijections

A bijection is a parameterized invertible function.
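
The snippets that follow assume a small amount of setup that isn't shown: the library imported as flows, a JAX PRNG key named rng, and a toy batch of inputs. One possible version (the data here is made up purely for illustration):

import numpy as np
from jax import random
import flows

rng = random.PRNGKey(0)              # reproducible JAX PRNG key
inputs = np.random.randn(64, 5)      # hypothetical batch of 64 five-dimensional points
num_samples = 10                     # how many points to draw in the sampling examples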

init_fun = flows.InvertibleLinear()

params, direct_fun, inverse_fun = init_fun(rng, input_dim=5)

# Transform inputs
transformed_inputs, log_det_jacobian_direct = direct_fun(params, inputs)

# Reconstruct original inputs
reconstructed_inputs, log_det_jacobian_inverse = inverse_fun(params, transformed_inputs)

assert np.allclose(inputs, reconstructed_inputs)  # equal up to floating point error

We can construct a sequence of bijections using flows.Serial. The result is just another bijection, and adheres to the exact same interface.

init_fun = flows.Serial(
    flows.AffineCoupling(),
    flows.InvertibleLinear(),
    flows.ActNorm(),
)

params, direct_fun, inverse_fun = init_fun(rng, input_dim=5)
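
Under the hood, a composition like this applies each bijection's transform in turn and accumulates the log determinants of their Jacobians, which is why the result is itself a valid bijection. A rough sketch of the forward direction (not the library's actual implementation; the names here are illustrative):

def make_serial_direct_fun(constituent_direct_funs):
    # Each element of constituent_direct_funs is a bijection's direct_fun.
    def direct_fun(all_params, inputs):
        outputs, total_log_det = inputs, 0.0
        for constituent, params in zip(constituent_direct_funs, all_params):
            outputs, log_det = constituent(params, outputs)
            total_log_det = total_log_det + log_det
        return outputs, total_log_det
    return direct_fun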

Distributions

A distribution is characterized by a probability density querying function, a sampling function, and its parameters.

init_fun = flows.Normal()

params, log_pdf, sample = init_fun(rng, input_dim=5)

# Query probability density of points
log_pdfs = log_pdf(params, inputs)

# Draw new points
samples = sample(rng, params, num_samples)
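
One JAX-specific detail: the sampling function is deterministic given the key you pass it, so calling it twice with the same rng returns identical draws. The usual pattern is to split the key before each call, roughly:

from jax import random

rng, sample_rng = random.split(rng)          # derive a fresh subkey for this draw
samples = sample(sample_rng, params, num_samples)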

Normalizing Flow Models

Under this definition, a normalizing flow model is just a distribution. But to construct one, we have to give it a bijection and another distribution to act as a prior.

bijection = flows.Serial(
    flows.AffineCoupling(),
    flows.InvertibleLinear(),
    flows.ActNorm(),
    flows.AffineCoupling(),
    flows.InvertibleLinear(),
    flows.ActNorm(),
)

prior = flows.Normal()

init_fun = flows.Flow(bijection, prior)

params, log_pdf, sample = init_fun(rng, input_dim=5)
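
The returned log_pdf and sample follow the usual change-of-variables recipe: evaluate a point's density by mapping it into the prior's space and adding the bijection's log det Jacobian, and sample by drawing from the prior and mapping back through the inverse. A rough sketch of both directions in terms of the interfaces shown above (the names and direction convention are assumptions here, not the library's actual code):

def flow_log_pdf(bijection_params, prior_params, inputs):
    # Change of variables: log p(x) = log p_prior(f(x)) + log |det J_f(x)|.
    latents, log_det = direct_fun(bijection_params, inputs)
    return prior_log_pdf(prior_params, latents) + log_det

def flow_sample(rng, bijection_params, prior_params, num_samples):
    # Draw from the prior, then push the draws back through the inverse bijection.
    latents = prior_sample(rng, prior_params, num_samples)
    samples, _ = inverse_fun(bijection_params, latents)
    return samples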

How do I train a model?

The same as you always would in JAX! First, define an appropriate loss function and parameter update step.

def loss(params, inputs):
    return -log_pdf(params, inputs).mean()

@jit
def step(i, opt_state, inputs):
    params = get_params(opt_state)
    gradient = grad(loss)(params, inputs)
    return opt_update(i, gradient, opt_state)
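
The update step above, and the training loop below, assume a few pieces of standard setup that aren't shown: jit and grad imported from jax, an optimizer triple from jax.example_libraries.optimizers (jax.experimental.optimizers on older JAX versions), and your training data in an array X. A minimal sketch, with an arbitrarily chosen learning rate:

import itertools
import numpy.random as npr

from jax import jit, grad
from jax.example_libraries import optimizers  # jax.experimental.optimizers on older JAX

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)   # params returned by init_fun above

num_epochs = 100               # hypothetical; tune for your dataset
# X: your training data as an (N, input_dim) ndarray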

Then execute a standard JAX training loop.

batch_size = 32
itercount = itertools.count()
for epoch in range(num_epochs):
    npr.shuffle(X)
    for batch_index in range(0, len(X), batch_size):
        opt_state = step(
            next(itercount),
            opt_state,
            X[batch_index:batch_index+batch_size]
        )

optimized_params = get_params(opt_state)

Now that we have our trained model parameters, we can query and sample as usual.

log_pdfs = log_pdf(optimized_params, inputs)

samples = sample(rng, optimized_params, num_samples)

Magic!

Interested in contributing?

Yay! Check out our contributing guidelines.

Inspiration

This repository is largely modeled after the pytorch-flows repository by Ilya Kostrikov, the nf-jax repository by Eric Jang, and the normalizing-flows repository by Tony Duan.

The implementations are modeled after the work of the following papers:

NICE: Non-linear Independent Components Estimation
Laurent Dinh, David Krueger, Yoshua Bengio
arXiv:1410.8516

Density estimation using Real NVP
Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio
arXiv:1605.08803

Improving Variational Inference with Inverse Autoregressive Flow
Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
arXiv:1606.04934

Glow: Generative Flow with Invertible 1x1 Convolutions
Diederik P. Kingma, Prafulla Dhariwal
arXiv:1807.03039

Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design
Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel
OpenReview:Hyg74h05tX

Masked Autoregressive Flow for Density Estimation
George Papamakarios, Theo Pavlakou, Iain Murray
arXiv:1705.07057

Neural Spline Flows
Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios
arXiv:1906.04032

And by association the following surveys:

Normalizing Flows: An Introduction and Review of Current Methods
Ivan Kobyzev, Simon Prince, Marcus A. Brubaker
arXiv:1908.09257

Normalizing Flows for Probabilistic Modeling and Inference
George Papamakarios, Eric Nalisnick, Danilo Jimenez Rezende, Shakir Mohamed, Balaji Lakshminarayanan
arXiv:1912.02762

