Normalizing flow implementations in jax.
FlowJax: Normalising Flows in Jax
Documentation
Available here.
Short example
Training a flow can be done in a few lines of code:
from flowjax.flows import BlockNeuralAutoregressiveFlow
from flowjax.train_utils import train_flow
from flowjax.distributions import Normal
from jax import random
import jax.numpy as jnp
data_key, flow_key, train_key = random.split(random.PRNGKey(0), 3)
x = random.uniform(data_key, (10000, 3)) # Toy data
base_dist = Normal(jnp.zeros(x.shape[1]))
flow = BlockNeuralAutoregressiveFlow(flow_key, base_dist)
flow, losses = train_flow(train_key, flow, x, learning_rate=0.05)
# We can now evaluate the log-probability of arbitrary points
flow.log_prob(x)
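The trained flow can also be sampled. A short usage sketch, relying on the sample method that also appears in the FAQ below:
# Draw new samples from the trained flow
samples = flow.sample(random.PRNGKey(1), n=1000)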
The package currently supports the following:
- CouplingFlow (Dinh et al., 2017) and MaskedAutoregressiveFlow (Papamakarios et al., 2017) conditioner architectures
- Common "transformers", such as AffineTransformer and RationalQuadraticSplineTransformer (the latter used in neural spline flows; Durkan et al., 2019)
- BlockNeuralAutoregressiveFlow, as introduced by De Cao et al., 2019
- TriangularSplineFlow, introduced here.
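As a rough sketch, the other flows listed above can be constructed in a similar way to the block neural autoregressive flow in the example; note the constructor arguments below are an assumption made by analogy and may differ between versions:
from flowjax.flows import CouplingFlow
# Hypothetical sketch: arguments assumed by analogy with the example above
flow = CouplingFlow(flow_key, base_dist)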
Installation
pip install flowjax
Warning
This package is new and may have substantial breaking changes between major releases.
TODO
A few limitations / things that could be worth including in the future:
- Add documentation
- Support varied "event" dimensions:
    - i.e. allow x and condition instances to have ndim==0 (scalar), or ndim > 1.
    - Chaining of bijections with varied event ndim could follow numpy-like broadcasting rules.
    - Allow vmap-like transform to define bijections with expanded event dimensions.
- Training script for variational inference
- Define transformers by wrapping a bijection?
Related
We make use of the Equinox package, which facilitates object-oriented programming with Jax.
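For illustration, a flow is an Equinox module, i.e. a pytree, so it works directly with Equinox's filtering utilities. A minimal sketch using standard Equinox functions (not FlowJax-specific API):
import equinox as eqx
import jax
# Split the flow into trainable arrays and static (non-array) structure
params, static = eqx.partition(flow, eqx.is_inexact_array)
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"Number of trainable parameters: {n_params}")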
FAQ
How to avoid training the base distribution?
Provide a filter_spec to train_flow, for example:
import equinox as eqx
import jax.tree_util as jtu
# Mark every inexact (floating point) array in the flow as trainable...
filter_spec = jtu.tree_map(lambda x: eqx.is_inexact_array(x), flow)
# ...then mark the base distribution as frozen
filter_spec = eqx.tree_at(lambda tree: tree.base_dist, filter_spec, replace=False)
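The resulting filter_spec is then passed to train_flow. Assuming the keyword argument is simply named filter_spec (an assumption based on the question above), this looks like:
# Keyword name `filter_spec` is assumed here
flow, losses = train_flow(train_key, flow, x, filter_spec=filter_spec)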
Do I need to scale my variables?
In general, yes: you should consider the form and scale of the target samples. It is often useful to define a bijection to carry out the preprocessing, then transform the trained flow with the inverse to "undo" the preprocessing. For example, to carry out "standard scaling", we could do:
import jax
from flowjax.bijections import Affine, Invert
from flowjax.distributions import Transformed
# Standard scaling as an Affine bijection: x -> (x - mean) / std
preprocess = Affine(-x.mean(axis=0)/x.std(axis=0), 1/x.std(axis=0))
x_processed = jax.vmap(preprocess.transform)(x)
flow, losses = train_flow(train_key, flow, x_processed)
flow = Transformed(flow, Invert(preprocess)) # "undo" the preprocessing
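The transformed flow then operates on the original (unscaled) data: sampling returns unscaled values, and log_prob includes the change-of-variables term for the preprocessing bijection. A minimal usage sketch:
# Samples and densities are now on the original data scale
samples = flow.sample(jax.random.PRNGKey(1), n=1000)
log_probs = flow.log_prob(x)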
Do I need to JIT things?
The methods of distributions and bijections are not jitted by default. For example, if you want to sample several batches after training, it is usually worth using jit:
import equinox as eqx
batch_size = 256
keys = random.split(random.PRNGKey(0), 5)
# Often slow - sample not jitted!
results = []
for batch_key in keys:
    x = flow.sample(batch_key, n=batch_size)
    results.append(x)

# Fast - sample jitted!
results = []
for batch_key in keys:
    x = eqx.filter_jit(flow.sample)(batch_key, n=batch_size)
    results.append(x)
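Equivalently (a small optional variant), you can create the jitted sampler once and reuse it across batches:
# Create the jitted sampler once, then reuse it for each batch
sample_fn = eqx.filter_jit(flow.sample)
results = [sample_fn(batch_key, n=batch_size) for batch_key in keys]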
Authors
flowjax was written by Daniel Ward <danielward27@outlook.com>.