
Exponential families for JAX


This library provides a set of tools for working with exponential family distributions in the differentiable programming library JAX.

The exponential families are an important class of probability distributions that include the normal, gamma, beta, exponential, Poisson, binomial, and Bernoulli distributions. For an explanation of the fundamental ideas behind this library, see our overview of exponential families.

The main motivation for using EFAX over a library like tensorflow-probability or the basic functions in JAX is that EFAX provides the two most important parametrizations for each exponential family (the natural and expectation parametrizations) and a uniform interface to efficient implementations of the main functions used in machine learning. For example, the most efficient way to compute the cross entropy between X and Y relies on X being in the expectation parametrization and Y in the natural parametrization.
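To see why, write Y's density in the standard exponential-family form q_Y(x) = h(x) exp(η_Y · T(x) - A(η_Y)), where T is the sufficient statistic, A the log-normalizer, h the carrier measure, and μ_X = E_X[T(x)] the expectation parameters of X. This notation is assumed here from common convention and may differ from the overview pdf:

```latex
H(X, Y) = -\mathbb{E}_{x \sim X}\left[\log q_Y(x)\right]
        = -\eta_Y \cdot \mu_X + A(\eta_Y) - \mathbb{E}_{x \sim X}\left[\log h(x)\right]
```

Only μ_X and η_Y enter the first two terms, and the last term depends on X alone, so no conversion between parametrizations is needed when X is stored in its expectation parametrization and Y in its natural parametrization.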

Framework

Representation

EFAX has a single base class for its objects, Parametrization; the type of each subclass encodes the distribution family.

Each parametrization object has a shape, and so it can store any number of distributions. When operating on such objects, NumPy’s broadcasting rules apply. This is unlike SciPy where each distribution is represented by a single object, and so a thousand distributions need a thousand objects.
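A minimal NumPy sketch of what "one object, many distributions" means, using plain arrays rather than EFAX objects: a single array of Bernoulli log-odds with shape (3,) broadcasts against observations arranged with shape (4, 1), giving a (4, 3) grid of log-probabilities. The helper name bernoulli_log_pmf is hypothetical and not part of EFAX.

```python
import numpy as np


def bernoulli_log_pmf(log_odds, x):
    # log p(x) = x * eta - log(1 + e^eta) for x in {0, 1};
    # np.logaddexp(0, eta) computes log(1 + e^eta) without overflow.
    return x * log_odds - np.logaddexp(0.0, log_odds)


# One array holds three distributions, like a parametrization object with shape (3,).
log_odds = np.array([-1.0, 0.0, 1.0])
# Four observations as a column so they broadcast against all three distributions.
observations = np.array([0.0, 1.0, 1.0, 0.0])[:, np.newaxis]
log_probs = bernoulli_log_pmf(log_odds, observations)
print(log_probs.shape)  # (4, 3)
```

In SciPy the same computation would need one distribution object per log-odds value.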

All parametrization objects are dataclasses using tjax.dataclass. These dataclasses are a modification of Python’s dataclasses to support JAX’s “PyTree” type registration.

Each field of a parametrization object stores a parameter over a specified support. Some parameters are marked as “fixed”, which means that they are held constant with respect to the exponential family: they are not among the family's natural or expectation parameters, and each value of a fixed parameter defines a different exponential family. An example of a fixed parameter is the failure number of the negative binomial distribution.
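The role of a fixed parameter can be sketched in plain Python, assuming the common convention that the negative binomial counts x successes (probability p) before the r-th failure. With r fixed, the natural parameter is log p and the sufficient statistic is x, while r appears only in the log-normalizer and the carrier measure. EFAX's own convention may differ; the function names are hypothetical.

```python
import math


def nb_pmf_direct(x, r, p):
    # Probability of x successes before the r-th failure, with success probability p.
    return math.comb(x + r - 1, x) * p**x * (1.0 - p) ** r


def nb_pmf_exp_family(x, r, eta):
    # Exponential-family form with the failure number r held fixed:
    # natural parameter eta = log p, sufficient statistic T(x) = x,
    # log-normalizer A(eta) = -r * log(1 - e^eta), carrier h(x) = C(x + r - 1, x).
    log_normalizer = -r * math.log1p(-math.exp(eta))
    carrier = math.comb(x + r - 1, x)
    return carrier * math.exp(eta * x - log_normalizer)


r, p = 3, 0.4
eta = math.log(p)
for x in range(5):
    print(x, nb_pmf_direct(x, r, p), nb_pmf_exp_family(x, r, eta))
```

Changing r changes A and h themselves, which is why r is a property of the family rather than one of its parameters.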

For example:

@dataclass
class MultivariateNormalNP(NaturalParametrization['MultivariateNormalEP']):
    # precision @ mean, with shape (*s, n).
    mean_times_precision: RealArray = distribution_parameter(VectorSupport())
    # -0.5 * precision, with shape (*s, n, n).
    negative_half_precision: RealArray = distribution_parameter(SymmetricMatrixSupport())

In this case, we see that there are two natural parameters for the multivariate normal distribution. Objects of this type can hold any number of distributions: if such an object x has shape s, then the shape of x.mean_times_precision is (*s, n) and the shape of x.negative_half_precision is (*s, n, n).
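The field shapes can be checked with a NumPy sketch that builds both natural parameters from a batch of means and covariances. It assumes the standard correspondence suggested by the field names (mean_times_precision = Σ⁻¹μ, negative_half_precision = -½Σ⁻¹); the helper mvn_natural_parameters is hypothetical and is not EFAX's conversion code.

```python
import numpy as np


def mvn_natural_parameters(mean, covariance):
    # mean has shape (*s, n); covariance has shape (*s, n, n).
    precision = np.linalg.inv(covariance)  # batched inverse, shape (*s, n, n)
    mean_times_precision = np.einsum("...ij,...j->...i", precision, mean)  # (*s, n)
    negative_half_precision = -0.5 * precision  # (*s, n, n)
    return mean_times_precision, negative_half_precision


# A batch with shape s = (2, 5) of 3-dimensional normal distributions.
rng = np.random.default_rng(0)
s, n = (2, 5), 3
mean = rng.normal(size=(*s, n))
a = rng.normal(size=(*s, n, n))
covariance = a @ np.swapaxes(a, -1, -2) + n * np.eye(n)  # symmetric positive definite
eta1, eta2 = mvn_natural_parameters(mean, covariance)
print(eta1.shape, eta2.shape)  # (2, 5, 3) (2, 5, 3, 3)
```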

Parametrizations

Each exponential family distribution has two special parametrizations: the natural and the expectation parametrization. (These are described in the overview pdf.) Consequently, every distribution is represented by at least two classes, one inheriting from NaturalParametrization and one from ExpectationParametrization.

The motivation for the natural parametrization is combining and scaling independent predictive evidence. In the natural parametrization, these operations correspond to addition and scaling, respectively.
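A one-dimensional normal illustrates this in plain Python, assuming the usual natural parameters (mean / variance, -1 / (2 · variance)). Multiplying two densities (combining evidence) adds their natural parameters, and raising a density to a power (scaling evidence) scales them. The helper names are hypothetical.

```python
def normal_to_natural(mean, variance):
    # Natural parameters of N(mean, variance).
    return mean / variance, -0.5 / variance


def natural_to_normal(eta1, eta2):
    variance = -0.5 / eta2
    return eta1 * variance, variance


# Two independent pieces of evidence about the same quantity.
a = normal_to_natural(1.0, 4.0)
b = normal_to_natural(3.0, 4.0)

# Combining evidence (multiplying densities) adds natural parameters.
combined = natural_to_normal(a[0] + b[0], a[1] + b[1])
print(combined)  # (2.0, 2.0)

# Scaling evidence (raising a density to the power 0.5) scales natural parameters.
tempered = natural_to_normal(0.5 * a[0], 0.5 * a[1])
print(tempered)  # (1.0, 8.0)
```

The combined distribution has the precision-weighted mean and half the variance, as expected for two equally reliable measurements.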

The motivation for the expectation parametrization is combining independent observations into the maximum likelihood distribution that could have produced them. In the expectation parametrization, this distribution is simply the average of the observations' sufficient statistics.
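A plain-Python sketch for the normal distribution, whose sufficient statistics are assumed here to be (x, x²): averaging them over the observations yields the expectation parameters of the maximum likelihood fit, from which the mean and the (maximum likelihood, biased) variance follow directly.

```python
observations = [1.0, 2.0, 4.0, 5.0]

# Sufficient statistics of the normal distribution: T(x) = (x, x**2).
stats = [(x, x * x) for x in observations]

# Expectation parameters of the maximum likelihood distribution: the average sufficient statistics.
mu1 = sum(t[0] for t in stats) / len(stats)
mu2 = sum(t[1] for t in stats) / len(stats)

mean = mu1
variance = mu2 - mu1**2
print(mean, variance)  # 3.0 2.5
```

Because the fit is an average, merging two batches of observations is a count-weighted average of their expectation parameters.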

EFAX provides conversions between the two parametrizations through the NaturalParametrization.to_exp and ExpectationParametrization.to_nat methods.
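For the Bernoulli distribution both conversions have closed forms. This plain-Python sketch (hypothetical function names, not EFAX's implementation) shows that the expectation parameter is the gradient of the log-normalizer A(η) = log(1 + e^η), that is, the sigmoid, and that the reverse conversion is its inverse, the logit.

```python
import math


def bernoulli_to_exp(log_odds):
    # Gradient of A(eta) = log(1 + e^eta): the sigmoid function.
    return 1.0 / (1.0 + math.exp(-log_odds))


def bernoulli_to_nat(probability):
    # Inverse of the sigmoid: the logit function.
    return math.log(probability / (1.0 - probability))


eta = bernoulli_to_nat(0.3)
print(eta)
print(bernoulli_to_exp(eta))
```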

Important methods

EFAX aims to provide the main methods used in machine learning.

Every Parametrization has methods to flatten the parameters into a single array and to unflatten them: flattened and unflattened. Typically, array-valued signals in a machine learning model would be unflattened into a distribution object, operated on, and then flattened before being sent back to the model. Flattening treats symmetric (or Hermitian) matrix-valued parameters specially: it stores only their upper-triangular elements.
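The upper-triangular storage can be sketched with NumPy's triu_indices. This is an illustration of the idea under stated assumptions, not EFAX's flattened/unflattened code: the helper names are hypothetical, the diagonal is included, and only the real symmetric case is handled (a Hermitian matrix would need the mirrored entries conjugated).

```python
import numpy as np


def flatten_symmetric(matrix):
    # Keep only the upper-triangular entries (diagonal included) of the last two axes.
    n = matrix.shape[-1]
    rows, cols = np.triu_indices(n)
    return matrix[..., rows, cols]


def unflatten_symmetric(flat, n):
    rows, cols = np.triu_indices(n)
    matrix = np.zeros(flat.shape[:-1] + (n, n), dtype=flat.dtype)
    matrix[..., rows, cols] = flat
    # Mirror the upper triangle into the lower triangle.
    matrix[..., cols, rows] = flat
    return matrix


m = np.array([[2.0, 0.5, 0.1],
              [0.5, 3.0, 0.2],
              [0.1, 0.2, 4.0]])
flat = flatten_symmetric(m)
print(flat)  # 6 stored values instead of 9
```

For an n × n parameter this stores n(n + 1)/2 values, and leading batch axes pass through unchanged.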

Every NaturalParametrization has methods:

  • sufficient_statistics to produce the sufficient statistics given an observation (used in maximum likelihood estimation),

  • pdf, which is the density or mass function,

  • fisher_information, which is the Fisher information matrix, and

  • entropy, which is the Shannon entropy.
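For a Bernoulli distribution with log-odds η, the four methods above reduce to short formulas. The following plain-Python functions are hypothetical stand-ins that mirror the method names, not EFAX's implementations; they assume the log-normalizer A(η) = log(1 + e^η) and carrier measure h(x) = 1.

```python
import math


def sufficient_statistics(x):
    # T(x) = x for an observation x in {0, 1}.
    return float(x)


def log_normalizer(eta):
    # A(eta) = log(1 + e^eta).
    return math.log1p(math.exp(eta))


def pdf(eta, x):
    # Mass function h(x) * exp(eta * T(x) - A(eta)) with carrier h(x) = 1.
    return math.exp(eta * sufficient_statistics(x) - log_normalizer(eta))


def fisher_information(eta):
    # Second derivative of A: p * (1 - p), where p = sigmoid(eta).
    p = 1.0 / (1.0 + math.exp(-eta))
    return p * (1.0 - p)


def entropy(eta):
    # H = A(eta) - eta * E[T(x)] - E[log h(x)], and log h(x) = 0 here.
    p = 1.0 / (1.0 + math.exp(-eta))
    return log_normalizer(eta) - eta * p


eta = 0.0
print(pdf(eta, 1), fisher_information(eta), entropy(eta))
```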

Every ExpectationParametrization has methods:

  • cross_entropy, an efficient cross entropy equipped with a numerically optimized custom JAX gradient. This is possible because the gradient of the cross entropy is the difference of expectation parameters plus the expected carrier measure.
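A plain-Python check of the key identity for the Bernoulli case, whose carrier measure is constant so that its term vanishes: the derivative of the cross entropy with respect to Y's natural parameter equals μ_Y - μ_X, a difference of expectation parameters. This sketch compares that closed form with a central finite difference; it does not reproduce EFAX's custom gradient, and the function names are hypothetical.

```python
import math


def bernoulli_cross_entropy(mu_x, eta_y):
    # H(X, Y) = -eta_Y * mu_X + A(eta_Y), with A(eta) = log(1 + e^eta).
    return -eta_y * mu_x + math.log1p(math.exp(eta_y))


def analytic_gradient(mu_x, eta_y):
    # dH/d(eta_Y) = A'(eta_Y) - mu_X = mu_Y - mu_X.
    mu_y = 1.0 / (1.0 + math.exp(-eta_y))
    return mu_y - mu_x


mu_x, eta_y, h = 0.4, 0.0, 1e-6
numeric = (bernoulli_cross_entropy(mu_x, eta_y + h)
           - bernoulli_cross_entropy(mu_x, eta_y - h)) / (2 * h)
print(analytic_gradient(mu_x, eta_y), numeric)
```

The closed form needs no evaluation of the log-normalizer's derivative by automatic differentiation, which is what makes a custom gradient attractive.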

Numerical optimization

Because the gradient of the log-normalizer cannot always be inverted analytically, and the expected carrier measure does not always have a closed form, some methods for some distributions require numerical optimization. These are the conversion from expectation parameters to natural ones, the entropy, and the cross entropy.
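The conversion from expectation to natural parameters amounts to solving ∇A(η) = μ for η. This document does not say which numerical method EFAX uses; the sketch below assumes Newton's method as one plausible choice, with the Fisher information (the second derivative of A) as the Newton step's denominator. It is checked on the Bernoulli distribution only because that case has a known closed-form answer; the function names are hypothetical.

```python
import math


def newton_to_nat(mu, grad_a, hess_a, eta0=0.0, tol=1e-12, max_iter=50):
    # Solve grad_a(eta) = mu for eta. hess_a is the Fisher information, which is
    # positive for a minimal exponential family, so each Newton step is defined.
    eta = eta0
    for _ in range(max_iter):
        residual = grad_a(eta) - mu
        if abs(residual) < tol:
            break
        eta -= residual / hess_a(eta)
    return eta


def sigmoid(eta):
    # Gradient of the Bernoulli log-normalizer.
    return 1.0 / (1.0 + math.exp(-eta))


def sigmoid_derivative(eta):
    # Second derivative of the Bernoulli log-normalizer.
    return sigmoid(eta) * (1.0 - sigmoid(eta))


eta = newton_to_nat(0.3, sigmoid, sigmoid_derivative)
print(eta, math.log(0.3 / 0.7))
```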

Distributions

EFAX supports the following distributions:

  • Bernoulli

  • beta

  • chi

  • chi-square

  • complex normal

  • Dirichlet

  • exponential

  • gamma

  • geometric

  • logarithmic

  • multinomial

  • multivariate normal

    • with arbitrary variance

    • with diagonal variance

    • with isotropic variance

    • with unit variance

  • negative binomial

  • normal

  • Poisson

  • Rayleigh

  • von Mises-Fisher

  • Weibull

Usage

Basic usage

A basic use of the two parametrizations:

from jax import numpy as jnp

from efax import BernoulliEP, BernoulliNP

# p holds the expectation parameters of three Bernoulli distributions with probabilities 0.4, 0.5,
# and 0.6.
p = BernoulliEP(jnp.array([0.4, 0.5, 0.6]))

# q holds the natural parameters of three Bernoulli distributions with log-odds 0, which is
# probability 0.5.
q = BernoulliNP(jnp.zeros(3))

print(p.cross_entropy(q))
# [0.6931472 0.6931472 0.6931472]

# p2 and q2 hold the expectation and natural parameters of three Bernoulli distributions with probability 0.3.
p2 = BernoulliEP(0.3 * jnp.ones(3))
q2 = p2.to_nat()

print(p.cross_entropy(q2))
# [0.6955941  0.78032386 0.86505365]
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation with probability
# 0.4 better than the other observations.
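The printed values can be checked without EFAX by applying the definition of cross entropy, -Σ P(x) log Q(x) over x ∈ {0, 1}, directly to the probabilities. This plain-Python check (hypothetical function name) uses no parametrization at all, so it serves as an independent confirmation.

```python
import math


def cross_entropy_direct(p, q):
    # -sum over x in {0, 1} of P(x) * log Q(x), computed from probabilities alone.
    return -(p * math.log(q) + (1.0 - p) * math.log(1.0 - q))


for prob in (0.4, 0.5, 0.6):
    print(prob, cross_entropy_direct(prob, 0.5), cross_entropy_direct(prob, 0.3))
```

The second column matches log 2 ≈ 0.6931472 for every row, and the third column agrees with the EFAX output above to about seven digits (EFAX prints float32 values).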

Optimization

Using the cross entropy to iteratively optimize a prediction is simple:

from jax import grad, jit, lax
from jax import numpy as jnp

from efax import BernoulliEP, BernoulliNP


def cross_entropy_loss(p, q):
    return p.cross_entropy(q)


gce = jit(grad(cross_entropy_loss, 1))


def body_fun(q):
    return BernoulliNP(q.log_odds - gce(some_p, q).log_odds * 1e-4)


def cond_fun(q):
    return jnp.sum(gce(some_p, q).log_odds ** 2) > 1e-7


# some_p are expectation parameters of a Bernoulli distribution corresponding
# to probability 0.4.
some_p = BernoulliEP(jnp.array(0.4))

# some_q are natural parameters of a Bernoulli distribution corresponding to
# log-odds 0, which is probability 0.5.
some_q = BernoulliNP(jnp.array(0.0))

# Optimize the predictive distribution iteratively.
print(lax.while_loop(cond_fun, body_fun, some_q))
# Outputs the natural parameters that correspond to 0.4.

# Compare with the true value.
print(some_p.to_nat())

Contribution guidelines

Contributions are welcome!

It’s not hard to add a new distribution. The steps are:

  • Create an issue for the new distribution.

  • Solve for or research the equations needed to fill the blanks in the overview pdf, and put them in the issue. I’ll add them to the pdf for you.

  • Implement the natural and expectation parametrizations. This can either be done directly like in the Bernoulli distribution, or as a transformation of an existing exponential family like the Rayleigh distribution. If the conversion from the expectation to the natural parametrization has no analytical solution, then there’s a mixin that implements a numerical solution, which was used in the Dirichlet distribution.

  • Add the new distribution to the tests by adding it to create_info.
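The idea behind deriving one family from another can be sketched in plain Python for the Rayleigh case: if X is Rayleigh with scale σ, then Y = X² is exponential with rate 1/(2σ²), so the Rayleigh density follows from the exponential one by a change of variables. This illustrates the mathematics only; EFAX's actual transformation machinery is not shown here, and the function names are hypothetical.

```python
import math


def exponential_pdf(y, rate):
    return rate * math.exp(-rate * y)


def rayleigh_pdf_via_transform(x, sigma):
    # Y = X**2 is exponential with rate 1 / (2 * sigma**2).
    # Change of variables: f_X(x) = f_Y(x**2) * |dy/dx| = f_Y(x**2) * 2x.
    return exponential_pdf(x * x, 1.0 / (2.0 * sigma**2)) * 2.0 * x


def rayleigh_pdf_direct(x, sigma):
    return x / sigma**2 * math.exp(-x * x / (2.0 * sigma**2))


for x in (0.5, 1.0, 2.0):
    print(rayleigh_pdf_via_transform(x, 1.5), rayleigh_pdf_direct(x, 1.5))
```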

Implementations should follow PEP 8. The tests can be run with the command “pytest .”. There are a few tools to clean and check the source:

  • isort .

  • pylint efax

  • flake8 efax

  • mypy efax


Download files


Source Distribution

efax-1.2.12.tar.gz (27.0 kB)

Built Distribution

efax-1.2.12-py3-none-any.whl (43.5 kB, Python 3)
