# Normalizing Flows in JAX

Implementations of normalizing flows (RealNVP, GLOW, MAF) in the JAX deep learning framework.

## What are normalizing flows?

Normalizing flow models are generative models: they learn the probability distribution of a given dataset. With that distribution we can do a number of useful things, namely query the likelihood of given points and sample new, realistic points.
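Concretely, a flow maps data through an invertible transformation into a simple prior distribution; the standard change-of-variables formula then gives the model's log-density. Here `f` denotes the direct mapping from data into the prior's space and `p_Z` the prior (a convention assumed for this illustration):

```
\log p_X(x) = \log p_Z\bigl(f(x)\bigr) + \log \left| \det \frac{\partial f(x)}{\partial x} \right|
```

Training maximizes this log-likelihood over the dataset, and sampling draws from the prior and runs the draws back through the inverse transformation.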

## How are things structured?

### Transformations

A `transformation` is a parameterized invertible function.

```
init_fun = flows.MADE()

params, direct_fun, inverse_fun = init_fun(rng, input_shape)

# Transform some inputs
transformed_inputs, log_det_direct = direct_fun(params, inputs)

# Reconstruct the original inputs from the transformed ones
reconstructed_inputs, log_det_inverse = inverse_fun(params, transformed_inputs)

# Equal up to numerical precision
assert np.allclose(inputs, reconstructed_inputs)
```

We can construct a larger meta-transformation by composing a sequence of sub-transformations using `flows.serial`. The resulting transformation adheres to the exact same interface and is indistinguishable from any other regular transformation.

```
init_fun = flows.serial(
    flows.BatchNorm(),
    flows.Reverse()
)

params, direct_fun, inverse_fun = init_fun(rng, input_shape)
```
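For intuition, here is a rough sketch of what such a composition could look like. `serial_sketch` is a hypothetical name invented for this example, and the library's actual `flows.serial` may differ in details: the idea is simply to apply each sub-transformation in order while summing the log-determinants, and to invert by applying the inverses in reverse order.

```
from jax import random
import jax.numpy as np

def serial_sketch(*init_funs):
    def init_fun(rng, input_shape):
        all_params, direct_funs, inverse_funs = [], [], []
        for sub_init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            p, d, i = sub_init_fun(layer_rng, input_shape)
            all_params.append(p)
            direct_funs.append(d)
            inverse_funs.append(i)

        def direct_fun(params, inputs):
            # Apply sub-transformations in order, accumulating log-determinants
            log_det = np.zeros(inputs.shape[0])
            for p, d in zip(params, direct_funs):
                inputs, delta = d(p, inputs)
                log_det += delta
            return inputs, log_det

        def inverse_fun(params, inputs):
            # Apply the inverses in reverse order
            log_det = np.zeros(inputs.shape[0])
            for p, i in reversed(list(zip(params, inverse_funs))):
                inputs, delta = i(p, inputs)
                log_det += delta
            return inputs, log_det

        return all_params, direct_fun, inverse_fun
    return init_fun
```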

### Distributions

A `distribution` has a similarly simple interface. It is characterized by a set of parameters, a function for querying the log of the pdf at a given point, and a sampling function.

```
init_fun = Normal()

params, log_pdf, sample = init_fun(rng, input_shape)

log_pdfs = log_pdf(params, inputs)

samples = sample(rng, params, num_samples)
```
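As a rough illustration of that interface, a standard-normal distribution could be implemented along these lines. `NormalSketch` is a hypothetical name for this example, and the library's own `Normal` may differ:

```
from jax import random
import jax.numpy as np

def NormalSketch():
    def init_fun(rng, input_shape):
        dim = input_shape[-1]

        def log_pdf(params, inputs):
            # Independent standard-normal log-density, summed over dimensions
            return -0.5 * (inputs ** 2 + np.log(2 * np.pi)).sum(axis=-1)

        def sample(rng, params, num_samples=1):
            return random.normal(rng, (num_samples, dim))

        # A standard normal has no learnable parameters
        return (), log_pdf, sample
    return init_fun
```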

### Normalizing Flow Models

Under this definition, a normalizing flow model is just a `distribution`. But to construct one, we have to provide a transformation and a prior distribution.

```
transformation = flows.serial(
    flows.BatchNorm(),
    flows.Reverse(),
    flows.BatchNorm(),
    flows.Reverse(),
)

prior = Normal()

init_fun = flows.Flow(transformation, prior)

params, log_pdf, sample = init_fun(rng, input_shape)
```
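To see how the pieces fit together, here is a sketch of how such a wrapper could combine a transformation and a prior via the change-of-variables formula above. `FlowSketch` is a hypothetical name, and this is not necessarily the library's exact implementation:

```
from jax import random

def FlowSketch(transformation, prior):
    def init_fun(rng, input_shape):
        transformation_rng, prior_rng = random.split(rng)
        params, direct_fun, inverse_fun = transformation(transformation_rng, input_shape)
        prior_params, prior_log_pdf, prior_sample = prior(prior_rng, input_shape)

        def log_pdf(params, inputs):
            # Push inputs through the transformation, evaluate them under the
            # prior, and add the log-determinant of the Jacobian
            u, log_det = direct_fun(params, inputs)
            return prior_log_pdf(prior_params, u) + log_det

        def sample(rng, params, num_samples=1):
            # Draw from the prior and map the draws back through the inverse
            prior_samples = prior_sample(rng, prior_params, num_samples)
            return inverse_fun(params, prior_samples)[0]

        return params, log_pdf, sample
    return init_fun
```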

### How do I train a model?

To train our model, we would typically define an appropriate loss function and parameter update step.

```
from jax import grad, jit

def loss(params, inputs):
    return -log_pdf(params, inputs).mean()

@jit
def step(i, opt_state, inputs):
    params = get_params(opt_state)
    gradients = grad(loss)(params, inputs)
    return opt_update(i, gradients, opt_state)
```
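The `step` function above assumes an optimizer triple (`opt_init`, `opt_update`, `get_params`) and an initial `opt_state`. One way to set these up is with the optimizers bundled with JAX; Adam and the learning rate here are just illustrative choices:

```
from jax.example_libraries import optimizers

step_size = 1e-3  # illustrative learning rate
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)
```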

Given these, we can go forward and execute a standard JAX training loop.

```
import itertools
import numpy.random as npr

batch_size = 32
itercount = itertools.count()

for epoch in range(num_epochs):
    npr.shuffle(X)
    for batch_index in range(0, len(X), batch_size):
        opt_state = step(next(itercount), opt_state, X[batch_index:batch_index + batch_size])

optimized_params = get_params(opt_state)
```

Now that we have our trained model parameters, we can query and sample as usual.

```
log_pdfs = log_pdf(optimized_params, inputs)

samples = sample(rng, optimized_params, num_samples)
```

Magic!

## Interested in contributing?

Yay! Check out our contributing guidelines in `.github/CONTRIBUTING.md`.

## Inspiration

This repository is largely modeled after the `pytorch-flows` repository by Ilya Kostrikov.

The implementations are modeled after the work of the following papers:

- *Density estimation using Real NVP*. Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. arXiv:1605.08803
- *Glow: Generative Flow with Invertible 1x1 Convolutions*. Diederik P. Kingma, Prafulla Dhariwal. arXiv:1807.03039
- *Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design*. Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel. OpenReview:Hyg74h05tX
- *Masked Autoregressive Flow for Density Estimation*. George Papamakarios, Theo Pavlakou, Iain Murray. arXiv:1705.07057
