NuX - Normalizing Flows using JAX
What is NuX?
NuX is a library for building normalizing flows using JAX.
What are normalizing flows?
Normalizing flows learn a parametric model of an unknown probability density from data. We assume that data points are sampled i.i.d. from an unknown distribution p(x). A normalizing flow learns a parametric approximation of the true data distribution, q(x), using maximum likelihood. The resulting q(x) can be sampled from efficiently and evaluated exactly.
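The change-of-variables idea behind this can be sketched in a few lines of plain JAX. The snippet below is a minimal illustration, not NuX code: it assumes a hypothetical 1-D affine bijection x = s*z + b with a standard normal base distribution p(z), and checks numerically that the resulting q(x) is a proper density.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical 1-D affine bijection x = s*z + b (illustration only).
s, b = 2.0, 1.0

def log_q(x):
    # Change of variables: log q(x) = log p(z) + log|dz/dx|, with z = (x - b)/s
    z = (x - b) / s
    return norm.logpdf(z) - jnp.log(jnp.abs(s))

# Because the transformation is invertible, q integrates to ~1:
xs = jnp.linspace(-12.0, 14.0, 10001)
fx = jnp.exp(log_q(xs))
total = jnp.sum(0.5 * (fx[1:] + fx[:-1]) * (xs[1:] - xs[:-1]))  # trapezoid rule
```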
Why use NuX?
NuX implements many normalizing flow layers behind an easy-to-use interface.
import nux.flows as nux
import jax
from jax import random
import jax.numpy as jnp
key = random.PRNGKey(0)
# Build a dummy dataset
x_train, x_test = jnp.ones((2, 100, 4))
# Build a simple normalizing flow
init_fun = nux.sequential(nux.Coupling(),
                          nux.ActNorm(),
                          nux.UnitGaussianPrior())
# Perform data-dependent initialization
_, flow = init_fun(key, {'x': x_train}, batched=True)
# Run data through the flow
inputs = {'x': x_test}
outputs, _ = flow.apply(flow.params, flow.state, inputs)
z, log_likelihood = outputs['x'], outputs['log_pz'] + outputs['log_det']
# Check the reconstructions
reconst, _ = flow.apply(flow.params, flow.state, {'x': z}, reverse=True)
assert jnp.allclose(x_test, reconst['x'])
What's implemented?
Check out the bijective, injective and surjective transformations that are implemented. Any contributions are welcome!
How does it work?
The modularity of normalizing flows lets us construct complex flows from simple building blocks, and JAX keeps the code simple.
Create complex flows
Flow layers can be imported from nux.flows and chained together sequentially using nux.flows.sequential. For more complex flows, we can split the flow using the probability chain rule with nux.flows.ChainRule and then run multiple flows in parallel using nux.flows.parallel.
Using these basic transformations, we can build up complex flows. For example, a multiscale GLOW normalizing flow can be implemented as follows:
import nux.flows as nux
def multi_scale(flow, existing_flow):
    # This is implemented in nux.flows.compose
    return nux.sequential(flow,
                          nux.Squeeze(),
                          nux.ChainRule(2, factor=True),
                          nux.factored(existing_flow, nux.Identity()),
                          nux.ChainRule(2, factor=False),
                          nux.UnSqueeze())

def GLOWBlock():
    return nux.sequential(nux.ActNorm(),
                          nux.OnebyOneConv(),
                          nux.Coupling(n_channels=512))

def GLOW(num_blocks=4):
    layers = [GLOWBlock() for _ in range(num_blocks)]
    return nux.sequential(*layers)

def MultiscaleGLOW(quantize_bits=3):
    flow = nux.Identity()
    flow = multi_scale(GLOW(), flow)
    flow = multi_scale(GLOW(), flow)
    flow = multi_scale(GLOW(), flow)
    flow = multi_scale(GLOW(), flow)
    flow = nux.sequential(nux.UniformDequantization(scale=2**quantize_bits),
                          nux.Logit(),
                          nux.Squeeze(), # So that the channel dim is divisible by 2
                          flow,
                          nux.Flatten(),
                          nux.AffineGaussianPriorDiagCov(out_dim=128)) # Use a low-dimensional prior for best results!
    flow_init_fun = flow # The result of creating layers is an initializer
    return flow_init_fun
Initialize your flow with data
NuX initializes flows using data to infer the input and output shapes at each flow layer and to help initialize layers like ActNorm. NuX uses dictionaries as the primary data structure to pass data between flow layers.
import jax
import jax.numpy as jnp
flow_init_fun = MultiscaleGLOW()
key = jax.random.PRNGKey(0)
inputs = {'x': jnp.zeros((64, 32, 32, 3))} # Create a dummy dataset
outputs, flow = flow_init_fun(key, inputs, batched=True) # Must specify if the input data is batched or not
Inputs/Outputs
Each flow application expects an input dictionary. The key 'x' should correspond to the data that is passed between flow layers. Furthermore, all elements of the input dictionary are passed to each flow layer. For example, in classification we can pass labels to any layer using:
inputs = {'x': data, 'y': labels}
Every flow layer returns a dictionary of outputs that contains the transformed data under the key 'x' and a log likelihood contribution term: 'log_det' for standard transformations and 'log_pz' for priors. Like the inputs, flow layers can also output other key-value pairs.
Flow data structure
The second value returned by an initializer call is the flow data structure. This data structure contains the name of a layer and dictionaries of the input/output shapes/dims, the parameters, the state and the apply function. The shapes/dims aid auto-batching, while the parameters and state parametrize the flow. The difference between parameters and state is that the parameters are intended to be trained with gradient descent while the state values are not (like the running statistics in batch normalization). jax.tree_util.tree_map and jax.tree_util.tree_multimap are your friends when working with these dictionaries!
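For instance, tree utilities make it easy to inspect or transform every parameter at once. The snippet below is a sketch that uses a hypothetical parameter dictionary; the real structure of flow.params is produced by the initializer and may be nested differently:

```python
import jax
import jax.numpy as jnp

# A hypothetical stand-in for flow.params (illustration only).
params = {'act_norm': {'b': jnp.zeros(3), 'log_s': jnp.zeros(3)},
          'one_by_one_conv': {'W': jnp.eye(3)}}

# Count the total number of trainable scalars in the flow.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

# Apply a function to every parameter at once, e.g. scaling for weight decay.
scaled = jax.tree_util.tree_map(lambda p: 0.9 * p, params)
```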
The apply function is called with the parameters, state, inputs, keyword arguments and a flag that specifies which direction to run the flow:
# Run the flow forwards (x -> z)
outputs, updated_state = flow.apply(flow.params, flow.state, inputs, key=key, reverse=False)
log_px = outputs['log_pz'] + outputs['log_det']
# Run the flow in reverse (z -> x)
reconstr_inputs, _ = flow.apply(flow.params, flow.state, outputs, key=key, reverse=True)
Use Haiku to create deep flow layers
Flow layers, like nux.flows.Coupling, can use a neural network to introduce complex non-linearities. These neural networks must be constructed using Haiku. Each flow layer has a sensible default network, but any Haiku network can be used. For example, we can construct a transformation for image coupling layers as follows:
import jax
import jax.numpy as jnp
import haiku as hk

class SimpleConv(hk.Module):
    def __init__(self, out_shape, n_hidden_channels, name=None):
        super().__init__(name=name)
        _, _, out_channels = out_shape
        self.out_channels = out_channels
        self.n_hidden_channels = n_hidden_channels
        self.last_channels = 2*out_channels

    def __call__(self, x, **kwargs):
        H, W, C = x.shape # NuX ensures that the input will be unbatched!
        x = hk.Conv2D(output_channels=self.n_hidden_channels,
                      kernel_shape=(3, 3),
                      stride=(1, 1))(x[None])[0]
        x = jax.nn.relu(x)
        x = hk.Conv2D(output_channels=self.n_hidden_channels,
                      kernel_shape=(1, 1),
                      stride=(1, 1))(x[None])[0]
        x = jax.nn.relu(x)
        x = hk.Conv2D(output_channels=self.last_channels,
                      kernel_shape=(3, 3),
                      stride=(1, 1),
                      w_init=hk.initializers.Constant(0),
                      b_init=hk.initializers.Constant(0))(x[None])[0]
        mu, alpha = jnp.split(x, 2, axis=-1)
        alpha = jnp.tanh(alpha)
        return mu, alpha
NuX handles batching internally, so every flow and network is guaranteed to be passed an unbatched input.
Using JAX
NuX is built using JAX, so all of JAX's features, like jax.jit, jax.grad and jax.vmap, can be used on a flow.
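For instance, jax.grad and jax.jit compose directly with a flow's apply function to build a training step. The snippet below is a sketch: it uses a simplified stand-in for flow.apply (an ActNorm-style bijection with a unit Gaussian prior) rather than a real NuX flow, but the calling convention matches the examples above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# A simplified stand-in for flow.apply (illustration only): an ActNorm-style
# bijection z = (x - b)*exp(-log_s) with a unit Gaussian prior over z.
def apply_fun(params, state, inputs, **kwargs):
    z = (inputs['x'] - params['b']) * jnp.exp(-params['log_s'])
    outputs = {'x': z,
               'log_pz': norm.logpdf(z).sum(axis=-1),
               'log_det': -params['log_s'].sum()}
    return outputs, state

# Negative log likelihood of a batch under the flow.
def nll(params, state, inputs):
    outputs, _ = apply_fun(params, state, inputs)
    log_px = outputs['log_pz'] + outputs['log_det']
    return -log_px.mean()

# jax.jit and jax.grad compose directly with the apply function.
grad_fn = jax.jit(jax.grad(nll))

params = {'b': jnp.zeros(4), 'log_s': jnp.zeros(4)}
x = jnp.ones((16, 4))
grads = grad_fn(params, {}, {'x': x})

# A plain SGD step over the parameter dictionary using tree_map.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```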
Creating custom flow layers
Custom flows are easy to create in NuX. NuX can internally handle batch dimensions, so custom layers can be implemented assuming the input is unbatched:
import jax.numpy as jnp
import jax.nn.initializers as jaxinit
import nux.flows.base as base

@base.auto_batch # Ensure that apply_fun receives unbatched inputs.
def OnebyOneConvDense(W_init=jaxinit.glorot_normal(), name='1x1conv_dense'):

    def apply_fun(params, state, inputs, reverse=False, **kwargs):
        x = inputs['x'] # Unpack the inputs
        W = params['W'] # Unpack the parameters
        height, width, channel = x.shape # auto_batch ensures x is unbatched!

        # Compute the transformation
        if not reverse:
            z = jnp.einsum('ij,hwj->hwi', W, x)
        else:
            W_inv = jnp.linalg.inv(W)
            z = jnp.einsum('ij,hwj->hwi', W_inv, x)

        # Compute the log Jacobian determinant
        log_det = jnp.linalg.slogdet(W)[1]
        log_det *= height*width

        # Return the outputs and update the state if necessary.
        outputs = {'x': z, 'log_det': log_det}
        updated_state = state
        return outputs, updated_state

    def create_params_and_state(key, input_shapes):
        # The data_dependent=False flag below ensures that input_shapes is unbatched.
        height, width, channel = input_shapes['x']

        # Initialize the parameters
        W = W_init(key, (channel, channel))
        params, state = {'W': W}, {}
        return params, state

    # Helper to put everything together.
    return base.initialize(name, apply_fun, create_params_and_state, data_dependent=False)
Under the hood, base.initialize extracts the shapes from the input to the initializer. These shapes are passed to create_params_and_state to generate the parameters and state. The inputs, parameters and state are then passed to apply_fun to compute the outputs. Finally, the shapes/dimensions of the outputs are retrieved and stored.
At runtime, base.auto_batch has access to the unbatched input dimensions for each flow. With this information, it recursively applies jax.vmap to correctly handle nested batching.
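This recursive lifting can be sketched with plain jax.vmap (this is an illustration of the idea, not NuX's actual implementation): an unbatched function is wrapped once per batch dimension.

```python
import jax
import jax.numpy as jnp

# An unbatched function over an (H, W, C) image, like the apply functions
# above expect after auto_batch strips the batch dimensions.
def unbatched_sum(x):
    return x.sum()

# Lift once per batch dimension: (N, H, W, C) -> (N,)
batched = jax.vmap(unbatched_sum)
# And again for doubly-batched inputs: (M, N, H, W, C) -> (M, N)
doubly_batched = jax.vmap(batched)

x = jnp.ones((2, 3, 8, 8, 1))
out = doubly_batched(x)
```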
The data_dependent flag can be set to True in order to pass the batched inputs to create_params_and_state. For ActNorm, this looks like:
@base.auto_batch
def ActNorm(log_s_init=jaxinit.zeros, b_init=jaxinit.zeros, name='act_norm'):
    multiply_by = None # We can store initialization-time constants in this outer scope.

    def apply_fun(params, state, inputs, reverse=False, **kwargs):
        if not reverse:
            z = (inputs['x'] - params['b'])*jnp.exp(-params['log_s'])
        else:
            z = jnp.exp(params['log_s'])*inputs['x'] + params['b']
        log_det = -params['log_s'].sum()*multiply_by
        outputs = {'x': z, 'log_det': log_det}
        return outputs, state

    def create_params_and_state(key, inputs, batch_depth):
        # The shape of x is the unbatched shape of x prepended with batch_depth dimensions.
        x = inputs['x']

        # Keep track of the dimensionality of all but the last axis in case we pass in an image.
        nonlocal multiply_by
        multiply_by = jnp.prod([s for i, s in enumerate(x.shape) if i >= batch_depth and i < len(x.shape) - 1])

        # Create the parameters using the batch of data
        axes = tuple(range(len(x.shape) - 1))
        params = {'b'    : jnp.mean(x, axis=axes),
                  'log_s': jnp.log(jnp.std(x, axis=axes) + 1e-5)}
        state = {}
        return params, state

    return base.initialize(name, apply_fun, create_params_and_state, data_dependent=True)
Testing a custom flow
nux.tests.nf_test.flow_test is a simple function to test the correctness of a flow. It checks unbatched/batched/doubly-batched reconstructions by running a flow forwards and then in reverse, and it checks the log Jacobian determinant against the brute-force solution computed using jax.jacobian.
from jax import random
from nux.tests.nf_test import flow_test

key = random.PRNGKey(0)
init_fun = Flow()
unbatched_inputs = {'x': data}
flow_test(init_fun, unbatched_inputs, key)
Installation
For the moment, NuX only works with Python 3.7. The steps to install are:
pip install nux
pip install git+https://github.com/deepmind/dm-haiku
If you want GPU support for JAX, follow the instructions at https://github.com/google/jax#installation