Easy-to-use distributions, bijections and normalizing flows in JAX.
FlowJax: Normalizing Flows in Jax
Documentation
Available here.
Short example
Training a flow can be done in a few lines of code:
from flowjax.flows import BlockNeuralAutoregressiveFlow
from flowjax.train import fit_to_data
from flowjax.distributions import Normal
from jax import random
import jax.numpy as jnp
data_key, flow_key, train_key = random.split(random.PRNGKey(0), 3)
x = random.uniform(data_key, (10000, 3)) # Toy data
base_dist = Normal(jnp.zeros(x.shape[1]))
flow = BlockNeuralAutoregressiveFlow(flow_key, base_dist)
flow, losses = fit_to_data(train_key, flow, x, learning_rate=0.05)
# We can now evaluate the log-probability of arbitrary points
flow.log_prob(x)
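The example above trains by maximum likelihood: fit_to_data adjusts the flow's parameters so that the data receive high log-probability. The sketch below illustrates that idea in plain JAX. It does not use flowjax and does not reproduce fit_to_data (which has its own optimizer, batching and stopping rules). Instead it assumes the simplest possible flow, an elementwise affine map x = loc + exp(log_scale) * z with a standard normal base, and fits it by full-batch gradient descent. The names affine_flow_log_prob and fit_by_max_likelihood are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def affine_flow_log_prob(params, x):
    """Hypothetical elementwise affine flow with a standard normal base.

    The flow maps base noise z to x = loc + exp(log_scale) * z, so by the
    change-of-variables formula:
        log p(x) = log N(z; 0, I) - sum(log_scale),  z = (x - loc) / exp(log_scale)
    """
    loc, log_scale = params
    z = (x - loc) / jnp.exp(log_scale)
    return norm.logpdf(z).sum(axis=-1) - log_scale.sum()


def fit_by_max_likelihood(params, x, learning_rate=0.05, steps=1000):
    """Full-batch gradient descent on the mean negative log-likelihood."""
    loss_fn = lambda p: -affine_flow_log_prob(p, x).mean()
    grad_fn = jax.jit(jax.grad(loss_fn))
    for _ in range(steps):
        grads = grad_fn(params)
        # params is a tuple of arrays, so update each leaf with tree_map
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params


# Same toy data shape as the example above
x = jax.random.uniform(jax.random.PRNGKey(0), (10000, 3))
loc, log_scale = fit_by_max_likelihood((jnp.zeros(3), jnp.zeros(3)), x)
```

For this particular flow the maximum-likelihood solution is known in closed form: loc approaches the per-dimension sample mean and exp(log_scale) approaches the sample standard deviation. This makes it a convenient sanity check. Richer flows replace the affine map with a learned, more expressive bijection.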
The package currently includes:
- Many simple bijections and distributions, implemented as Equinox modules.
- CouplingFlow (Dinh et al., 2017) and MaskedAutoregressiveFlow (Papamakarios et al., 2017) normalizing flow architectures.
  - These can be used with arbitrary bijections as transformers, such as Affine or RationalQuadraticSpline (the latter used in neural spline flows; Durkan et al., 2019).
- BlockNeuralAutoregressiveFlow, as introduced by De Cao et al., 2019.
- TriangularSplineFlow, introduced here.
- Training scripts for fitting by maximum likelihood, variational inference, or using contrastive learning for sequential neural posterior estimation (Greenberg et al., 2019; Durkan et al., 2020).
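To make the coupling idea in the list above concrete, here is a minimal plain-JAX sketch of a single affine coupling layer in the style of Dinh et al. (2017). The first `split` dimensions pass through unchanged and feed a small conditioner network, which outputs a shift and log-scale for an affine transformer applied to the remaining dimensions. This is not flowjax's CouplingFlow. The function names, the one-hidden-layer conditioner and the choice of an affine transformer (rather than, say, a rational quadratic spline) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_conditioner(key, d_in, d_out, width=16):
    """Hypothetical one-hidden-layer MLP producing (shift, log_scale)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "b1": jnp.zeros(width),
        "w2": jax.random.normal(k2, (width, 2 * d_out)) * 0.1,
        "b2": jnp.zeros(2 * d_out),
    }


def conditioner(params, x_a):
    h = jnp.tanh(x_a @ params["w1"] + params["b1"])
    shift, log_scale = jnp.split(h @ params["w2"] + params["b2"], 2, axis=-1)
    return shift, log_scale


def coupling_transform_and_log_det(params, x, split):
    """Forward pass: x_a is copied, x_b gets an affine map conditioned on x_a."""
    x_a, x_b = x[:split], x[split:]
    shift, log_scale = conditioner(params, x_a)
    y_b = x_b * jnp.exp(log_scale) + shift
    # The Jacobian is block triangular, so log|det| is just the sum of log-scales
    return jnp.concatenate([x_a, y_b]), log_scale.sum()


def coupling_inverse(params, y, split):
    """Inverse pass: y_a equals x_a, so the same shift and scale are recovered."""
    y_a, y_b = y[:split], y[split:]
    shift, log_scale = conditioner(params, y_a)
    return jnp.concatenate([y_a, (y_b - shift) * jnp.exp(-log_scale)])


key, x_key = jax.random.split(jax.random.PRNGKey(0))
params = init_conditioner(key, d_in=2, d_out=3)
x = jax.random.normal(x_key, (5,))
y, log_det = coupling_transform_and_log_det(params, x, split=2)
x_rec = coupling_inverse(params, y, split=2)

# Cross-check the analytic log-determinant against a dense autodiff Jacobian
jac = jax.jacobian(lambda v: coupling_transform_and_log_det(params, v, 2)[0])(x)
_, autodiff_log_det = jnp.linalg.slogdet(jac)
```

Both directions cost a single conditioner evaluation, and the log-determinant is cheap. This is what makes coupling layers practical. In a full flow, several such layers are stacked with permutations between them so that every dimension eventually gets transformed.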
Installation
pip install flowjax
Development
A development version can be installed as follows:
git clone https://github.com/danielward27/flowjax.git
cd flowjax
pip install -e ".[dev]"
sudo apt-get install pandoc # Required for building documentation
Warning
This package is new and may have substantial breaking changes between major releases.
TODO
Limitations and possible future additions:
- Add ability to "reshape" bijections.
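One possible reading of the "reshape" item above is a wrapper that lets a bijection defined on flat vectors act on arrays of a fixed shape. The sketch below shows that reading in plain JAX. ReshapedBijection and flat_affine are hypothetical names and are not part of flowjax, and only the forward direction is shown. Because reshaping is a volume-preserving permutation of entries, the wrapped bijection's log-determinant passes through unchanged.

```python
import jax.numpy as jnp


class ReshapedBijection:
    """Hypothetical wrapper: applies a bijection on flat vectors of length
    prod(shape) to arrays of `shape`. The log-determinant is unchanged,
    because reshaping does not change volume."""

    def __init__(self, flat_transform_and_log_det, shape):
        self.flat_transform_and_log_det = flat_transform_and_log_det
        self.shape = tuple(shape)

    def transform_and_log_det(self, x):
        if x.shape != self.shape:
            raise ValueError(f"expected shape {self.shape}, got {x.shape}")
        y_flat, log_det = self.flat_transform_and_log_det(x.reshape(-1))
        return y_flat.reshape(self.shape), log_det


# A hypothetical elementwise affine bijection on length-6 vectors
loc = jnp.arange(6.0)
log_scale = jnp.full(6, 0.5)
flat_affine = lambda v: (loc + jnp.exp(log_scale) * v, log_scale.sum())

bij = ReshapedBijection(flat_affine, (2, 3))
y, log_det = bij.transform_and_log_det(jnp.ones((2, 3)))
```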
Related
We make use of the Equinox package, which facilitates object-oriented programming with Jax.
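The object-oriented style Equinox enables rests on JAX pytrees. A class whose array fields are registered as pytree leaves can be passed straight into jax.grad, jax.jit or jax.vmap. Equinox performs this registration automatically for subclasses of eqx.Module. The sketch below does it by hand with jax.tree_util.register_pytree_node_class, so it needs only JAX. ElementwiseAffine is a hypothetical class for illustration and is not flowjax's Affine bijection.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ElementwiseAffine:
    """Hypothetical bijection stored as a pytree: its arrays are leaves that
    JAX transformations can see, differentiate and rebuild."""

    def __init__(self, loc, log_scale):
        self.loc = loc
        self.log_scale = log_scale

    def transform(self, x):
        return self.loc + jnp.exp(self.log_scale) * x

    def tree_flatten(self):
        # (children, aux_data): the arrays are children, no static metadata
        return (self.loc, self.log_scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


bijection = ElementwiseAffine(jnp.zeros(2), jnp.zeros(2))
# The gradient comes back as an ElementwiseAffine with matching fields
grads = jax.grad(lambda b: b.transform(jnp.ones(2)).sum())(bijection)
```

Returning gradients in the same structure as the model is what lets training code update a flow's parameters without unpacking them manually.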
Authors
flowjax was written by Daniel Ward <danielward27@outlook.com>.
Download files
Source distribution: flowjax-10.0.4.tar.gz (750.2 kB)
Built distribution: flowjax-10.0.4-py3-none-any.whl (43.8 kB)