Normalizing Flows in JAX

Implementations of normalizing flows (RealNVP, GLOW, MAF) in the JAX deep learning framework.

What are normalizing flows?

Normalizing flow models are generative models. That is, they infer the probability distribution of a given dataset. With that distribution we can do a number of interesting things, namely query the likelihood of given points as well as sample new realistic points.
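Likelihood queries rest on the change-of-variables formula, log p_X(x) = log p_Z(f(x)) + log |det df/dx|, where f is an invertible map into a simple base distribution. The following is a minimal sketch of that identity in plain JAX, not code from this library; the affine map and the values of `shift` and `scale` are assumptions chosen for illustration.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical invertible map f(x) = (x - shift) / scale, chosen for illustration.
shift, scale = 2.0, 3.0

def log_likelihood(x):
  z = (x - shift) / scale          # map the data point into the base space
  log_det = -jnp.log(scale)        # log |df/dx| for this affine map
  return norm.logpdf(z) + log_det  # change of variables

x = jnp.array([-1.0, 2.0, 4.5])
flow_log_probs = log_likelihood(x)

# The same density written directly as a Gaussian with mean 2 and std 3.
exact_log_probs = norm.logpdf(x, loc=shift, scale=scale)
```

Both arrays should agree up to floating-point error, which is the sense in which a flow "is" a distribution.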

How are things structured?

Transformations

A transformation is a parameterized invertible function.

init_fun = flows.MADE()

params, direct_fun, inverse_fun = init_fun(rng, input_shape)

# Transform some inputs
transformed_inputs, log_det_direct = direct_fun(params, inputs)

# Reconstruct the original inputs by inverting the transformed ones
reconstructed_inputs, log_det_inverse = inverse_fun(params, transformed_inputs)

# Compare with a tolerance, since floating-point round trips are rarely bit-exact
assert np.allclose(inputs, reconstructed_inputs, atol=1e-5)
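To make the interface concrete, here is a minimal sketch of a transformation written from scratch. `AffineSketch` is a hypothetical name and is not part of the library; the sketch also assumes that `input_shape` is the shape of a single example and that the log-determinant is returned once per batch row.

```python
import jax.numpy as jnp
from jax import random

def AffineSketch():
  """Hypothetical elementwise affine transformation (not a library layer)."""
  def init_fun(rng, input_shape):
    scale_rng, shift_rng = random.split(rng)
    log_scale = 0.1 * random.normal(scale_rng, input_shape)
    shift = 0.1 * random.normal(shift_rng, input_shape)
    params = (log_scale, shift)

    def direct_fun(params, inputs):
      log_scale, shift = params
      outputs = inputs * jnp.exp(log_scale) + shift
      # The Jacobian is diagonal, so its log-determinant is the sum of log scales.
      log_det = jnp.full(inputs.shape[0], log_scale.sum())
      return outputs, log_det

    def inverse_fun(params, inputs):
      log_scale, shift = params
      outputs = (inputs - shift) * jnp.exp(-log_scale)
      log_det = jnp.full(inputs.shape[0], -log_scale.sum())
      return outputs, log_det

    return params, direct_fun, inverse_fun
  return init_fun

rng = random.PRNGKey(0)
inputs = random.normal(rng, (5, 3))
params, direct_fun, inverse_fun = AffineSketch()(rng, (3,))

transformed, log_det_direct = direct_fun(params, inputs)
reconstructed, log_det_inverse = inverse_fun(params, transformed)
```

The inverse is applied to the transformed values, and the two log-determinants should cancel.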

We can construct a larger meta-transformation by composing a sequence of sub-transformations using flows.serial. The resulting transformation adheres to the exact same interface and is indistinguishable from any other regular transformation.

init_fun = flows.serial(
  flows.MADE(),
  flows.BatchNorm(),
  flows.Reverse()
)

params, direct_fun, inverse_fun = init_fun(rng, input_shape)
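One way such a combinator could work is sketched below: the direct pass applies the layers in order and sums their log-determinants, while the inverse pass walks the layers in reverse. `serial_sketch` and `FixedAffine` are hypothetical names used only for this illustration, and the library's own `flows.serial` may be implemented differently.

```python
import jax.numpy as jnp
from jax import random

def FixedAffine(scale, shift):
  """Hypothetical parameter-free layer computing y = scale * x + shift."""
  def init_fun(rng, input_shape):
    def direct_fun(params, inputs):
      log_det = jnp.full(inputs.shape[0], inputs.shape[1] * jnp.log(abs(scale)))
      return inputs * scale + shift, log_det

    def inverse_fun(params, inputs):
      log_det = jnp.full(inputs.shape[0], -inputs.shape[1] * jnp.log(abs(scale)))
      return (inputs - shift) / scale, log_det

    return (), direct_fun, inverse_fun
  return init_fun

def serial_sketch(*init_funs):
  """Hypothetical composition of transformations sharing the same interface."""
  def init_fun(rng, input_shape):
    all_params, direct_funs, inverse_funs = [], [], []
    for layer_init in init_funs:
      rng, layer_rng = random.split(rng)
      layer_params, direct, inverse = layer_init(layer_rng, input_shape)
      all_params.append(layer_params)
      direct_funs.append(direct)
      inverse_funs.append(inverse)

    def direct_fun(params, inputs):
      log_det = jnp.zeros(inputs.shape[0])
      for layer_params, direct in zip(params, direct_funs):
        inputs, layer_log_det = direct(layer_params, inputs)
        log_det = log_det + layer_log_det
      return inputs, log_det

    def inverse_fun(params, inputs):
      log_det = jnp.zeros(inputs.shape[0])
      # Undo the layers in the opposite order from the direct pass.
      for layer_params, inverse in zip(reversed(params), reversed(inverse_funs)):
        inputs, layer_log_det = inverse(layer_params, inputs)
        log_det = log_det + layer_log_det
      return inputs, log_det

    return all_params, direct_fun, inverse_fun
  return init_fun

init_fun = serial_sketch(FixedAffine(2.0, 1.0), FixedAffine(3.0, 0.0))
params, direct_fun, inverse_fun = init_fun(random.PRNGKey(0), (3,))

inputs = jnp.ones((4, 3))
transformed, log_det_direct = direct_fun(params, inputs)
reconstructed, log_det_inverse = inverse_fun(params, transformed)
```

Because the composite returns the same `(params, direct_fun, inverse_fun)` triple, it can itself be passed to `serial_sketch` as a layer.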

Distributions

A distribution has a similarly simple interface. It is characterized by a set of parameters, a function for querying the log of the pdf at a given point, and a sampling function.

init_fun = Normal()

params, log_pdf, sample = init_fun(rng, input_shape)

log_pdfs = log_pdf(params, inputs)

samples = sample(rng, params, num_samples)
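As an illustration of this interface, a standard normal distribution could be written roughly as follows. `NormalSketch` is a hypothetical stand-in, not the library's `Normal`; it assumes `input_shape` describes a single example and that inputs are batched along the first axis.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

def NormalSketch():
  """Hypothetical standard normal with the (params, log_pdf, sample) interface."""
  def init_fun(rng, input_shape):
    params = ()  # a fixed standard normal has nothing to learn

    def log_pdf(params, inputs):
      # Independent dimensions: sum the per-dimension log densities.
      return norm.logpdf(inputs).sum(axis=1)

    def sample(rng, params, num_samples=1):
      return random.normal(rng, (num_samples,) + tuple(input_shape))

    return params, log_pdf, sample
  return init_fun

rng = random.PRNGKey(0)
params, log_pdf, sample = NormalSketch()(rng, (3,))

log_pdfs = log_pdf(params, jnp.zeros((2, 3)))
samples = sample(rng, params, 10)
```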

Normalizing Flow Models

Under this definition, a normalizing flow model is just a distribution. To construct one, we supply a transformation and a prior distribution.

transformation = flows.serial(
  flows.MADE(),
  flows.BatchNorm(),
  flows.Reverse(),
  flows.MADE(),
  flows.BatchNorm(),
  flows.Reverse(),
)

prior = Normal()

init_fun = flows.Flow(transformation, prior)

params, log_pdf, sample = init_fun(rng, input_shape)
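The sketch below shows one plausible way to combine the two pieces: the density pushes inputs through the direct function and adds the log-determinant to the prior's log density, and sampling draws from the prior and applies the inverse function. `flow_sketch`, `Standardize`, and `StandardNormal` are hypothetical names for this illustration; the library's `flows.Flow` may differ in its details.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

def Standardize(shift, scale):
  """Hypothetical fixed transformation z = (x - shift) / scale."""
  def init_fun(rng, input_shape):
    def direct_fun(params, inputs):
      log_det = jnp.full(inputs.shape[0], -inputs.shape[1] * jnp.log(scale))
      return (inputs - shift) / scale, log_det

    def inverse_fun(params, inputs):
      log_det = jnp.full(inputs.shape[0], inputs.shape[1] * jnp.log(scale))
      return inputs * scale + shift, log_det

    return (), direct_fun, inverse_fun
  return init_fun

def StandardNormal():
  """Hypothetical prior with the (params, log_pdf, sample) interface."""
  def init_fun(rng, input_shape):
    def log_pdf(params, inputs):
      return norm.logpdf(inputs).sum(axis=1)

    def sample(rng, params, num_samples=1):
      return random.normal(rng, (num_samples,) + tuple(input_shape))

    return (), log_pdf, sample
  return init_fun

def flow_sketch(transformation, prior):
  """Hypothetical combination of a transformation and a prior into a distribution."""
  def init_fun(rng, input_shape):
    transformation_rng, prior_rng = random.split(rng)
    t_params, direct_fun, inverse_fun = transformation(transformation_rng, input_shape)
    p_params, prior_log_pdf, prior_sample = prior(prior_rng, input_shape)

    def log_pdf(params, inputs):
      t_params, p_params = params
      latents, log_det = direct_fun(t_params, inputs)
      return prior_log_pdf(p_params, latents) + log_det

    def sample(rng, params, num_samples=1):
      t_params, p_params = params
      latents = prior_sample(rng, p_params, num_samples)
      outputs, _ = inverse_fun(t_params, latents)
      return outputs

    return (t_params, p_params), log_pdf, sample
  return init_fun

rng = random.PRNGKey(0)
init_fun = flow_sketch(Standardize(2.0, 3.0), StandardNormal())
params, log_pdf, sample = init_fun(rng, (1,))

inputs = jnp.array([[-1.0], [2.0], [4.5]])
log_pdfs = log_pdf(params, inputs)
samples = sample(rng, params, 5)
```

With these choices the flow is a Gaussian with mean 2 and standard deviation 3, so `log_pdfs` can be checked against `norm.logpdf(inputs[:, 0], loc=2.0, scale=3.0)`.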

How do I train a model?

To train our model, we would typically define an appropriate loss function and parameter update step.

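from jax import grad, jit
# Assumption: optimizer from jax.example_libraries (jax.experimental in older JAX releases)
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-4)  # illustrative step size
opt_state = opt_init(params)

def loss(params, inputs):
  return -log_pdf(params, inputs).mean()

@jit
def step(i, opt_state, inputs):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, inputs), opt_state)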

Given these, we can go forward and execute a standard JAX training loop.

import itertools
import numpy.random as npr

batch_size = 32

itercount = itertools.count()
for epoch in range(num_epochs):
  npr.shuffle(X)  # shuffles the NumPy array X in place along its first axis
  for batch_index in range(0, len(X), batch_size):
    opt_state = step(next(itercount), opt_state, X[batch_index:batch_index+batch_size])

optimized_params = get_params(opt_state)

Now that we have our trained model parameters, we can query and sample as usual.

log_pdfs = log_pdf(optimized_params, inputs)

samples = sample(rng, optimized_params, num_samples)

Magic!
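For readers who want to see the whole cycle run without installing the library, here is a self-contained sketch that fits a one-dimensional affine flow by maximum likelihood. Everything in it is an assumption made for illustration: the toy data, the learning rate, the step count, and the use of plain gradient descent in place of the optimizer-based `step` shown earlier.

```python
import jax.numpy as jnp
from jax import grad, jit, random
from jax.scipy.stats import norm

LEARNING_RATE = 0.1  # illustrative value

def log_pdf(params, inputs):
  log_scale, shift = params
  z = (inputs - shift) * jnp.exp(-log_scale)  # direct pass into the base space
  return norm.logpdf(z) - log_scale           # prior log density plus log-determinant

def loss(params, inputs):
  return -log_pdf(params, inputs).mean()

@jit
def step(params, inputs):
  grads = grad(loss)(params, inputs)
  return tuple(p - LEARNING_RATE * g for p, g in zip(params, grads))

def sample(rng, params, num_samples):
  log_scale, shift = params
  # Inverse pass: draw from the prior, then undo the standardization.
  return random.normal(rng, (num_samples,)) * jnp.exp(log_scale) + shift

data_rng, sample_rng = random.split(random.PRNGKey(0))
X = 2.0 + 0.5 * random.normal(data_rng, (1000,))  # toy data

params = (jnp.array(0.0), jnp.array(0.0))  # (log_scale, shift)
initial_loss = loss(params, X)
for _ in range(300):
  params = step(params, X)
final_loss = loss(params, X)

fitted_scale = jnp.exp(params[0])
fitted_shift = params[1]
samples = sample(sample_rng, params, 1000)
```

For this model the maximum-likelihood solution is the sample mean and standard deviation of `X`, which gives a simple way to check that training converged.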

Interested in contributing?

Yay! Check out our contributing guidelines in .github/CONTRIBUTING.md.

Inspiration

This repository is largely modeled after the pytorch-flows repository by Ilya Kostrikov.

The implementations are modeled after the work of the following papers.

Density estimation using Real NVP
Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio
arXiv:1605.08803

Glow: Generative Flow with Invertible 1x1 Convolutions
Diederik P. Kingma, Prafulla Dhariwal
arXiv:1807.03039

Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design
Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel
OpenReview:Hyg74h05tX

Masked Autoregressive Flow for Density Estimation
George Papamakarios, Theo Pavlakou, Iain Murray
arXiv:1705.07057
