Distrax

Distrax: Probability distributions in JAX.

Distrax is a lightweight library of probability distributions and bijectors. It acts as a JAX-native reimplementation of a subset of TensorFlow Probability (TFP), with some new features and emphasis on extensibility.

Installation

Distrax can be installed with pip directly from GitHub:

pip install git+https://github.com/deepmind/distrax.git

or from PyPI:

pip install distrax

To run the tests or examples, you will need to install additional requirements.

Design Principles

The general design principles for the DeepMind JAX Ecosystem are addressed in this blog post. Additionally, Distrax places emphasis on the following:

  1. Readability. Distrax implementations are intended to be self-contained and read as close to the underlying math as possible.
  2. Extensibility. We have made it as simple as possible for users to define their own distribution or bijector. This is useful for example in reinforcement learning, where users may wish to define custom behavior for probabilistic agent policies.
  3. Compatibility. Distrax is not intended as a replacement for TFP, and TFP contains many advanced features that we do not intend to replicate. To this end, we have made the APIs for distributions and bijectors as cross-compatible as possible, and provide utilities for transforming between equivalent Distrax and TFP classes.

Features

Distributions

Distributions in Distrax are simple to define and use, particularly if you're used to TFP. Let's compare the two side-by-side:

import distrax
import jax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)
dist_tfp = tfd.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key)

# Both print 1.775
print(dist_distrax.log_prob(samples))
print(dist_tfp.log_prob(samples))

In addition to behaving consistently, Distrax distributions and TFP distributions are cross-compatible. For example:

mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.MultivariateNormalDiag(mu_1, sigma_1)

# Both print 85.237
print(dist_distrax.kl_divergence(dist_tfp))
print(tfd.kl_divergence(dist_distrax, dist_tfp))

Distrax distributions implement the method sample_and_log_prob, which returns samples and their log-probability in one line. For some distributions, this is more efficient than calling sample and log_prob separately:

mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)

samples = dist_distrax.sample(seed=key, sample_shape=())
log_prob = dist_distrax.log_prob(samples)

# A one-line equivalent of the above is:
samples, log_prob = dist_distrax.sample_and_log_prob(seed=key, sample_shape=())

TFP distributions can be passed to Distrax meta-distributions as inputs. For example:

key = jax.random.PRNGKey(1234)

mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.Normal(mu, sigma)

metadist_distrax = distrax.Independent(dist_tfp, reinterpreted_batch_ndims=1)
samples = metadist_distrax.sample(seed=key)
print(metadist_distrax.log_prob(samples))  # Prints 0.38871175

To use Distrax distributions in TFP meta-distributions, Distrax provides the wrapper to_tfp. A wrapped Distrax distribution can be used directly in TFP:

key = jax.random.PRNGKey(1234)

distrax_dist = distrax.Normal(0., 1.)
wrapped_dist = distrax.to_tfp(distrax_dist)
metadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])

samples = metadist_tfp.sample(seed=key)
print(metadist_tfp.log_prob(samples))  # Prints -3.3409896

Bijectors

A "bijector" in Distrax is an invertible function that knows how to compute its Jacobian determinant. Bijectors can be used to create complex distributions by transforming simpler ones. Distrax bijectors are functionally similar to TFP bijectors, with a few API differences. Here is an example comparing the two:

import distrax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

# Same distribution.
distrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), tfb.Tanh())
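
The resulting transformed distribution can then be sampled and evaluated like any other distribution. A minimal illustration (the key below is a new PRNGKey, not part of the original snippet):

import jax

key = jax.random.PRNGKey(0)
dist = distrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())

samples = dist.sample(seed=key, sample_shape=(2,))
print(dist.log_prob(samples))  # Log-density of each sample under the tanh-transformed normal.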

Additionally, Distrax bijectors can be composed and inverted:

bij_distrax = distrax.Tanh()
bij_tfp = tfb.Tanh()

# Same bijector.
inv_bij_distrax = distrax.Inverse(bij_distrax)
inv_bij_tfp = tfb.Invert(bij_tfp)

# These are both the identity bijector.
distrax.Chain([bij_distrax, inv_bij_distrax])
tfb.Chain([bij_tfp, inv_bij_tfp])

All TFP bijectors can be passed to Distrax, and can be freely composed with Distrax bijectors. For example, all of the following will work:

distrax.Inverse(tfb.Tanh())

distrax.Chain([tfb.Tanh(), distrax.Tanh()])

distrax.Transformed(tfd.Normal(loc=0., scale=1.), tfb.Tanh())

Distrax bijectors can also be passed to TFP, but they must first be wrapped with to_tfp:

bij_distrax = distrax.to_tfp(distrax.Tanh())

tfb.Invert(bij_distrax)

tfb.Chain([tfb.Tanh(), bij_distrax])

tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij_distrax)

Distrax also comes with Lambda, a convenient wrapper for turning simple JAX functions into bijectors. Here are a few Lambda examples with their TFP equivalents:

distrax.Lambda(lambda x: x)
# tfb.Identity()

distrax.Lambda(lambda x: 2*x + 3)
# tfb.Chain([tfb.Shift(3), tfb.Scale(2)])

distrax.Lambda(jnp.sinh)
# tfb.Sinh()

distrax.Lambda(lambda x: jnp.sinh(2*x + 3))
# tfb.Chain([tfb.Sinh(), tfb.Shift(3), tfb.Scale(2)])

Unlike TFP, bijectors in Distrax do not take event_ndims as an argument when they compute the Jacobian determinant. Instead, Distrax assumes that the number of event dimensions is statically known to every bijector, and uses Block to lift bijectors to a different number of dimensions. For example:

x = jnp.zeros([2, 3, 4])

# In TFP, `event_ndims` can be passed to the bijector.
bij_tfp = tfb.Tanh()
ld_1 = bij_tfp.forward_log_det_jacobian(x, event_ndims=0)  # Shape = [2, 3, 4]

# Distrax assumes `Tanh` is a scalar bijector by default.
bij_distrax = distrax.Tanh()
ld_2 = bij_distrax.forward_log_det_jacobian(x)  # ld_1 == ld_2

# With `event_ndims=2`, TFP sums the last 2 dimensions of the log det.
ld_3 = bij_tfp.forward_log_det_jacobian(x, event_ndims=2)  # Shape = [2]

# Distrax treats the number of dimensions statically.
bij_distrax = distrax.Block(bij_distrax, ndims=2)
ld_4 = bij_distrax.forward_log_det_jacobian(x)  # ld_3 == ld_4

Distrax bijectors implement the method forward_and_log_det (some bijectors additionally implement inverse_and_log_det), which returns the forward mapping and its log Jacobian determinant in one line. For some bijectors, this is more efficient than calling forward and forward_log_det_jacobian separately. (Analogously, when available, inverse_and_log_det can be more efficient than inverse and inverse_log_det_jacobian.)

x = jnp.zeros([2, 3, 4])
bij_distrax = distrax.Tanh()

y = bij_distrax.forward(x)
ld = bij_distrax.forward_log_det_jacobian(x)

# A one-line equivalent of the above is:
y, ld = bij_distrax.forward_and_log_det(x)

Jitting Distrax

Distrax distributions and bijectors can be passed as arguments to jitted functions. User-defined distributions and bijectors get this property for free by subclassing distrax.Distribution and distrax.Bijector respectively. For example:

mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_0 = distrax.MultivariateNormalDiag(mu_0, sigma_0)

mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_1 = distrax.MultivariateNormalDiag(mu_1, sigma_1)

jitted_kl = jax.jit(lambda d_0, d_1: d_0.kl_divergence(d_1))

# Both print 85.237
print(jitted_kl(dist_0, dist_1))
print(dist_0.kl_divergence(dist_1))

Subclassing Distributions and Bijectors

User-defined distributions can be created by subclassing distrax.Distribution. This can be achieved by implementing only a few methods:

class MyDistribution(distrax.Distribution):

  def __init__(self, ...):
    ...

  def _sample_n(self, key, n):
    samples = ...
    return samples

  def log_prob(self, value):
    log_prob = ...
    return log_prob

  def event_shape(self):
    event_shape = ...
    return event_shape

  def _sample_n_and_log_prob(self, key, n):
    # Optional. Only when more efficient implementation is possible.
    samples, log_prob = ...
    return samples, log_prob
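
As a concrete illustration, here is a minimal sketch of a complete user-defined distribution: a scalar uniform on [low, high). The class and its parameterization are purely illustrative and not part of Distrax:

import distrax
import jax
import jax.numpy as jnp

class SimpleUniform(distrax.Distribution):
  """Uniform distribution on [low, high), implemented from scratch."""

  def __init__(self, low, high):
    self._low = jnp.asarray(low)
    self._high = jnp.asarray(high)

  def _sample_n(self, key, n):
    # Draw `n` independent samples via low + u * (high - low), with u ~ U[0, 1).
    u = jax.random.uniform(key, shape=(n,) + self._low.shape)
    return self._low + u * (self._high - self._low)

  def log_prob(self, value):
    # Constant density 1 / (high - low) inside the support, zero outside.
    inside = (value >= self._low) & (value < self._high)
    return jnp.where(inside, -jnp.log(self._high - self._low), -jnp.inf)

  @property
  def event_shape(self):
    return ()  # Scalar events.

dist = SimpleUniform(low=0., high=2.)
samples = dist.sample(seed=jax.random.PRNGKey(0), sample_shape=(3,))
print(dist.log_prob(samples))  # Each entry is log(1/2), approximately -0.693.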

Similarly, user-defined bijectors can be created by subclassing distrax.Bijector. This can be achieved by implementing only one or two class methods:

class MyBijector(distrax.Bijector):

  def __init__(self, ...):
    super().__init__(...)

  def forward_and_log_det(self, x):
    y = ...
    logdet = ...
    return y, logdet

  def inverse_and_log_det(self, y):
    # Optional. Can be omitted if inverse methods are not needed.
    x = ...
    logdet = ...
    return x, logdet
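
And here is a minimal sketch of a complete user-defined bijector, an elementwise affine map y = scale * x + shift. The class is purely illustrative; event_ndims_in=0 declares it a scalar bijector, mirroring the built-in Tanh above:

import distrax
import jax.numpy as jnp

class ShiftAndScale(distrax.Bijector):
  """Elementwise affine bijector y = scale * x + shift."""

  def __init__(self, shift, scale):
    super().__init__(event_ndims_in=0)  # Acts elementwise on scalar events.
    self._shift = jnp.asarray(shift)
    self._scale = jnp.asarray(scale)

  def forward_and_log_det(self, x):
    y = self._scale * x + self._shift
    logdet = jnp.broadcast_to(jnp.log(jnp.abs(self._scale)), x.shape)
    return y, logdet

  def inverse_and_log_det(self, y):
    x = (y - self._shift) / self._scale
    logdet = jnp.broadcast_to(-jnp.log(jnp.abs(self._scale)), y.shape)
    return x, logdet

bij = ShiftAndScale(shift=3., scale=2.)
y, ld = bij.forward_and_log_det(jnp.zeros([4]))
print(y)   # [3. 3. 3. 3.]
print(ld)  # log(2) for each element, approximately 0.693.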

Examples

The examples directory contains some representative examples of full programs that use Distrax.

hmm.py demonstrates how to use distrax.HMM to combine distributions that model the initial states, transitions, and observation distributions of a Hidden Markov Model, and infer the latent rates and state transitions in a changing noisy signal.

vae.py contains an example implementation of a variational auto-encoder that is trained to model the binarized MNIST dataset as a joint distrax.Bernoulli distribution over the pixels.
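
The core modelling idea can be sketched as follows (a simplified illustration, not the actual vae.py code; the zero logits stand in for a decoder output):

import distrax
import jax
import jax.numpy as jnp

# One logit per pixel of a 28x28 binary image (a placeholder for the decoder output).
logits = jnp.zeros([28, 28])

# Per-pixel Bernoullis, combined into a single joint distribution over the image.
pixel_dist = distrax.Independent(
    distrax.Bernoulli(logits=logits), reinterpreted_batch_ndims=2)

image = pixel_dist.sample(seed=jax.random.PRNGKey(0))
print(pixel_dist.log_prob(image))  # Scalar log-likelihood of the sampled image.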

flow.py illustrates a simple example of modelling MNIST data using distrax.MaskedCoupling layers to implement a normalizing flow, and training the model with gradient descent.

Acknowledgements

We greatly appreciate the ongoing support of the TensorFlow Probability authors in assisting with the design and cross-compatibility of Distrax.

Special thanks to Aleyna Kara and Kevin Murphy for contributing the code upon which the Hidden Markov Model and associated example are based.

Citing Distrax

To cite this repository:

@software{distrax2021github,
  author = {Jake Bruce and David Budden and Matteo Hessel and George Papamakarios and Francisco Ruiz},
  title = {Distrax: Probability distributions in {JAX}},
  url = {http://github.com/deepmind/distrax},
  version = {0.0.1},
  year = {2021},
}
