Normalizing Flows for JAX
Normalizing Flows in JAX
Implementations of normalizing flows (RealNVP, GLOW, MAF) in the JAX deep learning framework.
What are normalizing flows?
Normalizing flow models are generative models: they learn an explicit probability distribution from a given dataset. With that distribution we can do a number of interesting things, most notably evaluate the likelihood of given points and sample new, realistic points.
How are things structured?
Transformations
A transformation is a parameterized invertible function.
init_fun = flows.MADE()
params, direct_fun, inverse_fun = init_fun(rng, input_shape)
# Transform some inputs
transformed_inputs, log_det_direct = direct_fun(params, inputs)
# Reconstruct the original inputs by inverting the transformed ones
reconstructed_inputs, log_det_inverse = inverse_fun(params, transformed_inputs)
# Compare with a tolerance, since floating-point round-off can break exact equality
assert np.allclose(inputs, reconstructed_inputs, atol=1e-5)
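To make the interface concrete, here is a minimal sketch of a hand-written transformation that follows the same `init_fun -> (params, direct_fun, inverse_fun)` pattern. The name `affine_sketch` and its elementwise scale-and-shift form are assumptions made for illustration; it is not one of the library's transformations.

```python
import jax.numpy as jnp
from jax import random


def affine_sketch():
    """Hypothetical transformation: y = x * exp(log_scale) + shift."""

    def init_fun(rng, input_shape):
        scale_rng, shift_rng = random.split(rng)
        params = (
            0.1 * random.normal(scale_rng, input_shape),  # log_scale
            0.1 * random.normal(shift_rng, input_shape),  # shift
        )

        def direct_fun(params, inputs):
            log_scale, shift = params
            outputs = inputs * jnp.exp(log_scale) + shift
            # One log|det Jacobian| value per row of the batch
            log_det = jnp.full((inputs.shape[0],), log_scale.sum())
            return outputs, log_det

        def inverse_fun(params, inputs):
            log_scale, shift = params
            outputs = (inputs - shift) * jnp.exp(-log_scale)
            log_det = jnp.full((inputs.shape[0],), -log_scale.sum())
            return outputs, log_det

        return params, direct_fun, inverse_fun

    return init_fun


rng = random.PRNGKey(0)
inputs = random.normal(rng, (4, 3))

params, direct_fun, inverse_fun = affine_sketch()(rng, (3,))
transformed, log_det_direct = direct_fun(params, inputs)
reconstructed, log_det_inverse = inverse_fun(params, transformed)
```

The inverse is applied to the transformed inputs, and the two log-determinants cancel, which is a quick sanity check for any new transformation.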
We can construct a larger meta-transformation by composing a sequence of sub-transformations using flows.serial. The resulting transformation adheres to the exact same interface and is indistinguishable from any other regular transformation.
init_fun = flows.serial(
    flows.MADE(),
    flows.BatchNorm(),
    flows.Reverse(),
)
params, direct_fun, inverse_fun = init_fun(rng, input_shape)
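One way such a composition could work is sketched below: the direct pass applies each sub-transformation in order and sums the log-determinants, while the inverse pass walks the list backwards. `serial_sketch`, `shift_sketch` and `scale_sketch` are hypothetical names written for this example; this is not the library's `flows.serial` implementation.

```python
import jax.numpy as jnp
from jax import random


def shift_sketch():
    """Hypothetical transformation: adds a learned offset (log-determinant 0)."""
    def init_fun(rng, input_shape):
        params = random.normal(rng, input_shape)
        def direct_fun(params, inputs):
            return inputs + params, jnp.zeros(inputs.shape[0])
        def inverse_fun(params, inputs):
            return inputs - params, jnp.zeros(inputs.shape[0])
        return params, direct_fun, inverse_fun
    return init_fun


def scale_sketch():
    """Hypothetical transformation: multiplies by exp(log_scale)."""
    def init_fun(rng, input_shape):
        params = 0.1 * random.normal(rng, input_shape)
        def direct_fun(params, inputs):
            return inputs * jnp.exp(params), jnp.full((inputs.shape[0],), params.sum())
        def inverse_fun(params, inputs):
            return inputs * jnp.exp(-params), jnp.full((inputs.shape[0],), -params.sum())
        return params, direct_fun, inverse_fun
    return init_fun


def serial_sketch(*init_funs):
    """Compose transformations behind the same three-function interface."""
    def init_fun(rng, input_shape):
        rngs = random.split(rng, len(init_funs))
        all_params, direct_funs, inverse_funs = [], [], []
        for sub_rng, sub_init in zip(rngs, init_funs):
            sub_params, sub_direct, sub_inverse = sub_init(sub_rng, input_shape)
            all_params.append(sub_params)
            direct_funs.append(sub_direct)
            inverse_funs.append(sub_inverse)

        def direct_fun(params, inputs):
            total_log_det = jnp.zeros(inputs.shape[0])
            for sub_params, fun in zip(params, direct_funs):
                inputs, log_det = fun(sub_params, inputs)
                total_log_det = total_log_det + log_det
            return inputs, total_log_det

        def inverse_fun(params, inputs):
            total_log_det = jnp.zeros(inputs.shape[0])
            # Undo the sub-transformations in reverse order
            for sub_params, fun in zip(reversed(params), reversed(inverse_funs)):
                inputs, log_det = fun(sub_params, inputs)
                total_log_det = total_log_det + log_det
            return inputs, total_log_det

        return all_params, direct_fun, inverse_fun
    return init_fun


rng = random.PRNGKey(0)
inputs = random.normal(rng, (5, 2))
init_fun = serial_sketch(shift_sketch(), scale_sketch(), shift_sketch())
params, direct_fun, inverse_fun = init_fun(rng, (2,))
outputs, log_det = direct_fun(params, inputs)
recovered, inverse_log_det = inverse_fun(params, outputs)
```

Because the composite returns the same `(params, direct_fun, inverse_fun)` triple, it can itself be passed to another composition.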
Distributions
A distribution has a similarly simple interface. It is characterized by a set of parameters, a function for querying the log of the pdf at a given point, and a sampling function.
init_fun = Normal()
params, log_pdf, sample = init_fun(rng, input_shape)
log_pdfs = log_pdf(params, inputs)
samples = sample(rng, params, num_samples)
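As an illustration of this interface, the sketch below writes a standard normal distribution by hand. `normal_sketch` is a hypothetical stand-in, not the library's `Normal`; it assumes `input_shape` is a tuple and that inputs are batched along the first axis.

```python
import jax.numpy as jnp
from jax import random


def normal_sketch():
    """Hypothetical standard normal with the (params, log_pdf, sample) interface."""

    def init_fun(rng, input_shape):
        params = ()  # a fixed standard normal has nothing to learn

        def log_pdf(params, inputs):
            # Sum of independent N(0, 1) log-densities over the feature axis
            return (-0.5 * inputs ** 2 - 0.5 * jnp.log(2 * jnp.pi)).sum(axis=-1)

        def sample(rng, params, num_samples=1):
            return random.normal(rng, (num_samples,) + tuple(input_shape))

        return params, log_pdf, sample

    return init_fun


rng = random.PRNGKey(0)
params, log_pdf, sample = normal_sketch()(rng, (3,))

log_pdfs = log_pdf(params, jnp.zeros((1, 3)))
samples = sample(rng, params, 10)
```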
Normalizing Flow Models
Under this definition, a normalizing flow model is just a distribution. To construct one, we have to supply a transformation and a prior distribution.
transformation = flows.serial(
    flows.MADE(),
    flows.BatchNorm(),
    flows.Reverse(),
    flows.MADE(),
    flows.BatchNorm(),
    flows.Reverse(),
)
prior = Normal()
init_fun = flows.Flow(transformation, prior)
params, log_pdf, sample = init_fun(rng, input_shape)
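The sketch below shows how a transformation and a prior could be combined into a distribution using the change-of-variables rule: the log-density of a point is the prior log-density of its transformed value plus the log-determinant, and sampling pushes prior samples through the inverse. `flow_sketch`, `affine_sketch` and `normal_sketch` are hypothetical helpers for this example, and the affine parameters are fixed constants so the result can be checked against a closed form; this is not the library's `flows.Flow`.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm


def affine_sketch():
    """Hypothetical transformation: z = x * exp(log_scale) + shift."""
    def init_fun(rng, input_shape):
        # Fixed values (rng unused) so the density can be verified by hand
        params = (0.3 * jnp.ones(input_shape), -1.0 * jnp.ones(input_shape))
        def direct_fun(params, inputs):
            log_scale, shift = params
            log_det = jnp.full((inputs.shape[0],), log_scale.sum())
            return inputs * jnp.exp(log_scale) + shift, log_det
        def inverse_fun(params, inputs):
            log_scale, shift = params
            log_det = jnp.full((inputs.shape[0],), -log_scale.sum())
            return (inputs - shift) * jnp.exp(-log_scale), log_det
        return params, direct_fun, inverse_fun
    return init_fun


def normal_sketch():
    """Hypothetical standard normal prior."""
    def init_fun(rng, input_shape):
        def log_pdf(params, inputs):
            return norm.logpdf(inputs).sum(axis=-1)
        def sample(rng, params, num_samples=1):
            return random.normal(rng, (num_samples,) + tuple(input_shape))
        return (), log_pdf, sample
    return init_fun


def flow_sketch(transformation, prior):
    """Combine a transformation and a prior into a single distribution."""
    def init_fun(rng, input_shape):
        transform_rng, prior_rng = random.split(rng)
        transform_params, direct_fun, inverse_fun = transformation(transform_rng, input_shape)
        prior_params, prior_log_pdf, prior_sample = prior(prior_rng, input_shape)

        def log_pdf(params, inputs):
            transform_params, prior_params = params
            latents, log_det = direct_fun(transform_params, inputs)
            # Change of variables: log p(x) = log p_prior(f(x)) + log|det df/dx|
            return prior_log_pdf(prior_params, latents) + log_det

        def sample(rng, params, num_samples=1):
            transform_params, prior_params = params
            latents = prior_sample(rng, prior_params, num_samples)
            return inverse_fun(transform_params, latents)[0]

        return (transform_params, prior_params), log_pdf, sample
    return init_fun


rng = random.PRNGKey(0)
params, log_pdf, sample = flow_sketch(affine_sketch(), normal_sketch())(rng, (2,))

inputs = jnp.array([[0.0, 1.0], [2.0, -1.0]])
log_pdfs = log_pdf(params, inputs)
samples = sample(rng, params, 1000)
```

With an affine transformation and a standard normal prior, the resulting model is itself a Gaussian, so `log_pdfs` can be compared against `norm.logpdf` with the matching mean and scale.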
How do I train a model?
To train our model, we would typically define an appropriate loss function and parameter update step.
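def loss(params, inputs):
    # Negative mean log-likelihood of the batch
    return -log_pdf(params, inputs).mean()
# get_params and opt_update are assumed to come from a JAX optimizer
# triple (opt_init, opt_update, get_params) created beforehand
@jit
def step(i, opt_state, inputs):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, inputs), opt_state)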
Given these, we can run a standard JAX training loop.
batch_size = 32
itercount = itertools.count()
for epoch in range(num_epochs):
    npr.shuffle(X)
    for batch_index in range(0, len(X), batch_size):
        opt_state = step(next(itercount), opt_state, X[batch_index:batch_index+batch_size])
optimized_params = get_params(opt_state)
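The snippets above leave the optimizer setup implicit. A self-contained sketch of the whole loop on toy data follows, assuming the optimizer comes from `jax.example_libraries.optimizers` (older JAX releases exposed the same module as `jax.experimental.optimizers`). The two-parameter affine model, the synthetic data and the hyperparameters are made up for illustration and stand in for a real flow.

```python
import itertools

import numpy as onp
import jax.numpy as jnp
from jax import grad, jit
from jax.example_libraries import optimizers


def log_pdf(params, inputs):
    """Toy flow: affine map into a standard normal prior (illustrative only)."""
    log_scale, shift = params
    latents = inputs * jnp.exp(log_scale) + shift
    log_prior = (-0.5 * latents ** 2 - 0.5 * jnp.log(2 * jnp.pi)).sum(axis=-1)
    return log_prior + log_scale.sum()


def loss(params, inputs):
    return -log_pdf(params, inputs).mean()


# The optimizer triple supplies opt_update and get_params used by step()
opt_init, opt_update, get_params = optimizers.adam(step_size=5e-2)
opt_state = opt_init((jnp.zeros(2), jnp.zeros(2)))


@jit
def step(i, opt_state, inputs):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, inputs), opt_state)


# Synthetic dataset: 512 points drawn around 3.0 with standard deviation 0.5
npr = onp.random.RandomState(0)
X = (3.0 + 0.5 * npr.randn(512, 2)).astype("float32")
initial_loss = float(loss(get_params(opt_state), X))

batch_size, num_epochs = 32, 30
itercount = itertools.count()
for epoch in range(num_epochs):
    npr.shuffle(X)  # shuffles rows in place, so X must be a writable NumPy array
    for batch_index in range(0, len(X), batch_size):
        batch = X[batch_index:batch_index + batch_size]
        opt_state = step(next(itercount), opt_state, batch)

optimized_params = get_params(opt_state)
final_loss = float(loss(optimized_params, X))
```

Comparing `final_loss` with `initial_loss` gives a quick check that the negative log-likelihood went down during training.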
Now that we have our trained model parameters, we can query and sample as usual.
log_pdfs = log_pdf(optimized_params, inputs)
samples = sample(rng, optimized_params, num_samples)
Magic!
Interested in contributing?
Yay! Check out our contributing guidelines in .github/CONTRIBUTING.md.
Inspiration
This repository is largely modeled after the pytorch-flows repository by Ilya Kostrikov.
The implementations are modeled after the work of the following papers.
Density estimation using Real NVP
Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio
arXiv:1605.08803
Glow: Generative Flow with Invertible 1x1 Convolutions
Diederik P. Kingma, Prafulla Dhariwal
arXiv:1807.03039
Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design
Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel
OpenReview:Hyg74h05tX
Masked Autoregressive Flow for Density Estimation
George Papamakarios, Theo Pavlakou, Iain Murray
arXiv:1705.07057