EFAX: Exponential Families in JAX
This library provides a set of tools for working with exponential family distributions in the differentiable programming library JAX.
The exponential families are an important class of probability distributions that include the normal, gamma, beta, exponential, Poisson, binomial, and Bernoulli distributions. For an explanation of the fundamental ideas behind this library, see our overview on exponential families.
The main motivation for using EFAX over a library like tensorflow-probability or the basic functions in JAX is that EFAX provides the two most important parametrizations for each exponential family—the natural and expectation parametrizations—and a uniform interface to efficient implementations of the main functions used in machine learning. An example of why this matters is that the most efficient way to implement cross entropy between X and Y relies on X being in the expectation parametrization and Y in the natural parametrization.
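To see why the parametrization matters, here is an illustrative sketch in plain NumPy (not EFAX's API): for a Bernoulli distribution, which has no carrier measure, the cross entropy H(P, Q) reduces to −μ_P · η_Q + ψ(η_Q), where μ_P is P's expectation parameter (the probability), η_Q is Q's natural parameter (the log-odds), and ψ(η) = log(1 + e^η) is the log-normalizer.

```python
import numpy as np

def log_normalizer(eta):
    """Bernoulli log-normalizer: psi(eta) = log(1 + exp(eta))."""
    return np.logaddexp(0.0, eta)

def cross_entropy(mu_p, eta_q):
    """H(P, Q) = -mu_P * eta_Q + psi(eta_Q) for a Bernoulli (no carrier measure)."""
    return -mu_p * eta_q + log_normalizer(eta_q)

# Three Bernoullis with probabilities 0.4, 0.5, 0.6 against a uniform
# Bernoulli (log-odds 0): each cross entropy is log 2 ~ 0.6931.
print(cross_entropy(np.array([0.4, 0.5, 0.6]), np.zeros(3)))
```

With P in expectation form and Q in natural form, no conversion is needed; any other pairing would first require inverting one of the parametrizations.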
Framework
Representation
EFAX has a single base class for its objects: Distribution, whose type encodes the distribution family.
Each parametrization object has a shape, and so it can store any number of distributions. Operations on these objects are vectorized. This is unlike SciPy where each distribution is represented by a single object, and so a thousand distributions need a thousand objects, and corresponding calls to functions that operate on them.
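As an illustrative sketch in plain NumPy (not EFAX's API), a single array can hold one parameter of many distributions, and a single vectorized expression operates on all of them at once:

```python
import numpy as np

# One array holds the probabilities of 1000 Bernoulli distributions.
probability = np.linspace(0.01, 0.99, 1000)

# One vectorized expression computes all 1000 entropies -- no
# per-distribution objects and no Python-level loop.
entropy = -(probability * np.log(probability)
            + (1.0 - probability) * np.log1p(-probability))

print(entropy.shape)  # (1000,)
```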
All parametrization objects are dataclasses using tjax.dataclass. These dataclasses are
a modification of Python’s dataclasses to support JAX’s “PyTree” type registration.
Each of the fields of a parametrization object stores a parameter over a specified support. Some parameters are marked as “fixed”, which means that they are fixed with respect to the exponential family. An example of a fixed parameter is the failure number of the negative binomial distribution.
For example:
@dataclass
class MultivariateNormalNP(NaturalParametrization['MultivariateNormalEP']):
    mean_times_precision: RealArray = distribution_parameter(VectorSupport())
    negative_half_precision: RealArray = distribution_parameter(SymmetricMatrixSupport())
In this case, we see that there are two natural parameters for the multivariate normal distribution.
Objects of this type can hold any number of distributions: if such an object x has shape s, then the shape of x.mean_times_precision is (*s, n) and the shape of x.negative_half_precision is (*s, n, n).
Parametrizations
Each exponential family distribution has two special parametrizations: the natural and the
expectation parametrization. (These are described in the overview pdf.)
Consequently, every distribution has at least two base classes, one inheriting from
NaturalParametrization and one from ExpectationParametrization.
The motivation for the natural parametrization is combining and scaling independent predictive evidence. In the natural parametrization, these operations correspond to scaling and addition.
The motivation for the expectation parametrization is combining independent observations into the maximum likelihood distribution that could have produced them. In the expectation parametrization, this is an expected value.
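The first motivation can be sketched in plain NumPy (illustrative only, not EFAX's API): a univariate normal N(mean, variance) has natural parameters (mean/variance, −1/(2·variance)), and Bayesian evidence combination is simply addition in this parametrization. The numbers below match the diagonal-normal example later in this document.

```python
import numpy as np

def to_natural(mean, variance):
    """Natural parameters of N(mean, variance)."""
    return np.array([mean / variance, -0.5 / variance])

def to_mean_variance(eta):
    """Invert: variance = -1/(2 * eta[1]), mean = eta[0] * variance."""
    variance = -0.5 / eta[1]
    return eta[0] * variance, variance

# Prior N(0, 10) combined with likelihood N(1.1, 3): add natural parameters.
posterior_mean, posterior_variance = to_mean_variance(
    to_natural(0.0, 10.0) + to_natural(1.1, 3.0))
print(posterior_mean, posterior_variance)  # ~0.8462, ~2.3077
```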
EFAX provides conversions between the two parametrizations through the
NaturalParametrization.to_exp and ExpectationParametrization.to_nat methods.
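For the Bernoulli family these conversions have a simple closed form, sketched here in plain NumPy (EFAX's to_exp and to_nat generalize this to every family): the natural parameter is the log-odds and the expectation parameter is the probability.

```python
import numpy as np

def to_exp(log_odds):
    """Natural -> expectation: the gradient of the log-normalizer (a sigmoid)."""
    return 1.0 / (1.0 + np.exp(-log_odds))

def to_nat(probability):
    """Expectation -> natural: the log-odds (a logit)."""
    return np.log(probability) - np.log1p(-probability)

print(to_nat(0.5))          # 0.0
print(to_exp(to_nat(0.3)))  # 0.3
```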
Some distributions also provide additional convenience parametrizations. For example, the normal
distribution offers NormalVP (variance parametrization) and NormalDP (deviation
parametrization), and the multivariate normal provides MultivariateNormalVP and
MultivariateDiagonalNormalVP. These are not exponential-family parametrizations but are
useful for constructing and interpreting distributions.
Important methods
EFAX aims to provide the main methods used in machine learning.
Every Distribution has:
- shape and ndim, which support broadcasting, and
- indexing via [], which slices all parameter arrays simultaneously.
Every NaturalParametrization has methods:
- to_exp, to convert itself to expectation parameters,
- sufficient_statistics, to produce the sufficient statistics given an observation (used in maximum likelihood estimation),
- log_normalizer, the log partition function,
- carrier_measure, the base measure,
- pdf and log_pdf, which are the density or mass function and its logarithm,
- fisher_information_diagonal and fisher_information_trace, which return the diagonal and trace of the Fisher information matrix stored as distribution objects,
- apply_fisher_information, which applies the Fisher information matrix to a vector of expectation parameters efficiently in a single VJP pass,
- jeffreys_prior_density, which returns the square root of the Fisher information determinant,
- characteristic_function, which evaluates the characteristic function of the sufficient statistics via analytic continuation of the log-normalizer, and
- kl_divergence, which is the KL divergence.
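The relationships among several of these methods can be sketched for the Bernoulli family in plain NumPy (hypothetical helpers, not EFAX's implementation): with sufficient statistic T(x) = x, no carrier measure, and log-normalizer ψ(η) = log(1 + e^η), the log density is η·T(x) − ψ(η).

```python
import numpy as np

def log_normalizer(eta):
    """psi(eta) = log(1 + exp(eta)) for the Bernoulli family."""
    return np.logaddexp(0.0, eta)

def log_pdf(eta, x):
    """log p(x) = eta * T(x) - psi(eta), with T(x) = x and no carrier measure."""
    return eta * x - log_normalizer(eta)

eta = 0.0  # log-odds 0, i.e. probability 0.5
print(np.exp(log_pdf(eta, 1)))  # 0.5
print(np.exp(log_pdf(eta, 0)))  # 0.5
```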
Every ExpectationParametrization has methods:
- to_nat, to convert itself to natural parameters, and
- kl_divergence, which is the KL divergence.
Some parametrizations inherit from these interfaces:
- HasConjugatePrior can produce and recover the conjugate prior,
- HasGeneralizedConjugatePrior extends that with per-dimension pseudo-observation counts,
- Multidimensional distributions have an integer number of dimensions, and
- Samplable distributions support sampling.
Some parametrizations inherit from these public mixins:
- HasEntropy is a distribution with an entropy method,
- HasEntropyEP is an expectation parametrization with analytically tractable entropy and cross_entropy, and
- HasEntropyNP is a natural parametrization with analytically tractable entropy via the paired expectation parametrization.
Some parametrizations inherit from these private mixins:
- ExpToNat implements the conversion from expectation to natural parameters when no analytical solution is possible. It uses Newton's method with a Jacobian to invert the gradient of the log-normalizer.
- TransformedNaturalParametrization produces a natural parametrization by relating it to an existing natural parametrization, and similarly for TransformedExpectationParametrization.
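The ExpToNat idea can be sketched in one dimension with plain NumPy (illustrative only, not EFAX's implementation). For the Poisson family, ψ(η) = e^η, so ψ′(η) = e^η is the map from natural to expectation parameters; Newton's method inverts it, and the known analytic inverse η = log μ lets us check the result:

```python
import numpy as np

def grad_log_normalizer(eta):
    """For a Poisson, psi(eta) = exp(eta), so psi'(eta) = exp(eta) = mu."""
    return np.exp(eta)

def exp_to_nat(mu, iterations=20):
    """Invert psi' by Newton's method: solve psi'(eta) - mu = 0 for eta."""
    eta = 0.0
    for _ in range(iterations):
        residual = grad_log_normalizer(eta) - mu
        jacobian = np.exp(eta)  # psi''(eta), the 1x1 "Jacobian" here
        eta = eta - residual / jacobian
    return eta

print(exp_to_nat(4.0), np.log(4.0))  # both ~1.3863
```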
Joint distributions
JointDistribution, JointDistributionE, and JointDistributionN compose
multiple independent distributions into a single object. JointDistributionE holds
expectation parametrizations and implements HasEntropyEP; JointDistributionN
holds natural parametrizations. They support the same to_nat / to_exp
conversions as simple distributions.
Structure utilities
EFAX provides three classes that capture the static metadata of a distribution tree—its types, parameter names, and dimension information—without requiring a live instance. They form an inheritance hierarchy:
- Assembler
  Stores a post-order traversal of a Distribution tree (types, paths, dimensions) so that distributions can be reconstructed from raw parameter data without passing type information alongside arrays. Key methods:
  - assemble(params) — rebuild a Distribution from a {path: array} mapping,
  - coerce_from_distribution(q) — reinterpret q's numeric values under this Assembler's types,
  - domain_support() — enumerate each leaf distribution's parameter constraints,
  - generate_random(xp, rng, shape, safety) — draw a random distribution with valid parameters, and
  - to_nat() / to_exp() — return a copy whose types are all in natural or expectation form.
- Estimator (extends Assembler)
  Adds maximum likelihood estimation by recording which parameters are fixed (held constant across observations) and which are free. Because the MLE for every exponential family equals the mean of the sufficient statistics, estimation reduces to a single call:
  - sufficient_statistics(x) — compute the sufficient statistics of observation x, with fixed parameters supplied automatically.
  Create one with Estimator.from_type(type_p, **fixed), Estimator.from_expectation(p), or Estimator.from_natural(p).
- Flattener (extends Estimator)
  Adds encoding and decoding between a Distribution and an array of shape (*distribution.shape, k), making distributions compatible with neural networks and numerical optimizers. Fixed parameters are excluded from the encoded array and reinserted automatically on decode. Key methods:
  - Flattener.flatten(p, mapped_to_plane=True) — encode p into a (Flattener, array) pair,
  - unflatten(array) — decode the array back into a distribution, and
  - final_dimension_size() — the size k of the last axis of the encoded array.
  The mapped_to_plane flag controls whether constrained parameters (e.g., those on a simplex or restricted to the positive reals) are bijectively mapped to all of ℝⁿ. Set it to True when passing to a neural network (to prevent invalid outputs), and False when the raw magnitudes matter—for example, when differencing expectation parameters or computing Jacobians.
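For example (a sketch in plain NumPy, not EFAX's actual transform), a positive-real parameter such as a variance can be mapped bijectively onto all of ℝ with log and back with exp, so any unconstrained output from a neural network decodes to a valid parameter:

```python
import numpy as np

def to_plane(positive_value):
    """Bijection from the positive reals onto all of R."""
    return np.log(positive_value)

def from_plane(real_value):
    """Inverse bijection: every real value decodes to a valid positive parameter."""
    return np.exp(real_value)

variance = np.array([3.0, 1.0])
assert np.allclose(from_plane(to_plane(variance)), variance)  # round trip
print(from_plane(np.array([-100.0, 100.0])))  # always positive, never invalid
```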
Distributions
EFAX supports the following distributions:
- normal:
  - univariate real:
    - with unit variance
    - with arbitrary parameters
  - univariate complex:
    - with unit variance and zero pseudo-variance
    - with arbitrary parameters
  - multivariate real:
    - with unit variance
    - with fixed variance
    - with isotropic variance
    - with diagonal variance
    - with arbitrary parameters
  - multivariate complex:
    - with unit variance and zero pseudo-variance
    - circularly symmetric
  - softplus-transformed:
    - with unit variance
    - with arbitrary parameters
  - log-normal (exponential-transformed):
    - with unit variance
    - with arbitrary parameters
- on a finite set:
  - Bernoulli
  - categorical
- on the nonnegative integers:
  - geometric
  - logarithmic
  - negative binomial
  - Poisson
- on the positive reals:
  - Rayleigh
  - Weibull
  - chi
  - chi-square
  - exponential
  - gamma
  - inverse Gaussian
  - inverse gamma
- on the simplex:
  - beta
  - Dirichlet
  - generalized Dirichlet
- on the n-sphere:
  - von Mises-Fisher
  - complex von Mises
- on positive-definite matrices:
  - Wishart
Usage
Basic usage
A basic use of the two parametrizations:
"""Cross-entropy.
This example is based on section 1.4.1 from exponential_families.pdf, entitled Information
theoretic statistics.
"""
import jax.numpy as jnp
from tjax import print_generic
from efax import BernoulliEP, BernoulliNP
# p is the expectation parameters of three Bernoulli distributions having
# probabilities 0.4, 0.5, and 0.6.
p = BernoulliEP(jnp.asarray([0.4, 0.5, 0.6]))
# q is the natural parameters of three Bernoulli distributions having log-odds
# 0, which is probability 0.5.
q = BernoulliNP(jnp.zeros(3))
print_generic(p.cross_entropy(q))
# Jax Array (3,) float32
# └── 0.6931 │ 0.6931 │ 0.6931
# q2 is the natural parameters of Bernoulli distributions having a probability
# of 0.3.
p2 = BernoulliEP(0.3 * jnp.ones(3))
q2 = p2.to_nat()
# A Bernoulli distribution with probability 0.3 predicts a Bernoulli observation
# with probability 0.4 better than the other observations.
print_generic(p.cross_entropy(q2))
# Jax Array (3,) float32
# └── 0.6956 │ 0.7803 │ 0.8651
Evidence combination:
"""Bayesian evidence combination.
This example is based on section 1.2.1 from exponential_families.pdf, entitled Bayesian
evidence combination.
Suppose you have a prior, and a set of likelihoods, and you want to combine all
of the evidence into one distribution.
"""
from operator import add
import jax.numpy as jnp
from tjax import print_generic
from efax import MultivariateDiagonalNormalVP, parameter_map
prior = MultivariateDiagonalNormalVP(mean=jnp.zeros(2),
variance=10 * jnp.ones(2))
likelihood = MultivariateDiagonalNormalVP(mean=jnp.asarray([1.1, -2.2]),
variance=jnp.asarray([3.0, 1.0]))
# Convert to the natural parametrization.
prior_np = prior.to_nat()
likelihood_np = likelihood.to_nat()
# Sum. We use parameter_map to ensure that we don't accidentally add "fixed"
# parameters, e.g., the failure count of a negative binomial distribution.
posterior_np = parameter_map(add, prior_np, likelihood_np)
# Convert to the source parametrization.
posterior = posterior_np.to_variance_parametrization()
print_generic({"prior": prior,
"likelihood": likelihood,
"posterior": posterior})
# dict
# ├── likelihood=MultivariateDiagonalNormalVP[dataclass]
# │ ├── mean=Jax Array (2,) float32
# │ │ └── 1.1000 │ -2.2000
# │ └── variance=Jax Array (2,) float32
# │ └── 3.0000 │ 1.0000
# ├── posterior=MultivariateDiagonalNormalVP[dataclass]
# │ ├── mean=Jax Array (2,) float32
# │ │ └── 0.8462 │ -2.0000
# │ └── variance=Jax Array (2,) float32
# │ └── 2.3077 │ 0.9091
# └── prior=MultivariateDiagonalNormalVP[dataclass]
# ├── mean=Jax Array (2,) float32
# │ └── 0.0000 │ 0.0000
# └── variance=Jax Array (2,) float32
# └── 10.0000 │ 10.0000
Optimization
Using the cross entropy to iteratively optimize a prediction is simple:
"""Optimization.
This example illustrates how this library fits in a typical machine learning
context. Suppose we have an unknown target value, and a loss function based on
the cross-entropy between the target value and a predictive distribution. We
will optimize the predictive distribution by a small fraction of its cotangent.
"""
import jax.numpy as jnp
from jax import grad, lax
from tjax import JaxBooleanArray, JaxRealArray, jit, print_generic
from efax import BernoulliEP, BernoulliNP, parameter_dot_product, parameter_map
def cross_entropy_loss(p: BernoulliEP, q: BernoulliNP) -> JaxRealArray:
return jnp.sum(p.cross_entropy(q))
gradient_cross_entropy = jit(grad(cross_entropy_loss, 1))
def apply(x: JaxRealArray, x_bar: JaxRealArray) -> JaxRealArray:
return x - 1e-4 * x_bar
def body_fun(q: BernoulliNP) -> BernoulliNP:
q_bar = gradient_cross_entropy(target_distribution, q)
return parameter_map(apply, q, q_bar)
def cond_fun(q: BernoulliNP) -> JaxBooleanArray:
q_bar = gradient_cross_entropy(target_distribution, q)
total = jnp.sum(parameter_dot_product(q_bar, q_bar))
return total > 1e-6 # noqa: PLR2004
# The target_distribution is represented as the expectation parameters of a
# Bernoulli distribution corresponding to probabilities 0.3, 0.4, and 0.7.
target_distribution = BernoulliEP(jnp.asarray([0.3, 0.4, 0.7]))
# The initial predictive distribution is represented as the natural parameters
# of a Bernoulli distribution corresponding to log-odds 0, which is probability
# 0.5.
initial_predictive_distribution = BernoulliNP(jnp.zeros(3))
# Optimize the predictive distribution iteratively.
predictive_distribution = lax.while_loop(cond_fun, body_fun,
initial_predictive_distribution)
# Compare the optimized predictive distribution with the target value in the
# same natural parametrization.
print_generic({"predictive_distribution": predictive_distribution,
"target_distribution": target_distribution.to_nat()})
# dict
# ├── predictive_distribution=BernoulliNP[dataclass]
# │ └── log_odds=Jax Array (3,) float32
# │ └── -0.8440 │ -0.4047 │ 0.8440
# └── target_distribution=BernoulliNP[dataclass]
# └── log_odds=Jax Array (3,) float32
# └── -0.8473 │ -0.4055 │ 0.8473
# Do the same in the expectation parametrization.
print_generic({"predictive_distribution": predictive_distribution.to_exp(),
"target_distribution": target_distribution})
# dict
# ├── predictive_distribution=BernoulliEP[dataclass]
# │ └── probability=Jax Array (3,) float32
# │ └── 0.3007 │ 0.4002 │ 0.6993
# └── target_distribution=BernoulliEP[dataclass]
# └── probability=Jax Array (3,) float32
# └── 0.3000 │ 0.4000 │ 0.7000
Maximum likelihood estimation
Maximum likelihood estimation often uses the conjugate prior, which can require exotic conjugate prior distributions to have been implemented. It is simpler to use the expectation parametrization instead.
"""Maximum likelihood estimation.
This example is based on section 1.3.2 from exponential_families.pdf, entitled Maximum
likelihood estimation.
Suppose you have some samples from a distribution family with unknown
parameters, and you want to estimate the maximum likelihood parameters of the
distribution.
"""
import jax.numpy as jnp
import jax.random as jr
from tjax import print_generic
from efax import DirichletEP, DirichletNP, Estimator, parameter_mean
# Consider a Dirichlet distribution with a given alpha.
alpha = jnp.asarray([2.0, 3.0, 4.0])
source_distribution = DirichletNP(alpha_minus_one=alpha - 1.0)
# Let's sample from it.
n_samples = 10000
key_a = jr.key(123)
samples = source_distribution.sample(key_a, (n_samples,))
# Now, let's find the maximum likelihood Dirichlet distribution that fits it.
# First, convert the samples to their sufficient statistics.
estimator = Estimator.from_type(DirichletEP)
ss = estimator.sufficient_statistics(samples)
# ss has type DirichletEP. This is similar to the conjugate prior of the
# Dirichlet distribution.
# Take the mean over the first axis.
ss_mean = parameter_mean(ss, axis=0) # ss_mean also has type DirichletEP.
# Convert this back to the natural parametrization.
estimated_distribution = ss_mean.to_nat()
print_generic({"estimated_distribution": estimated_distribution,
"source_distribution": source_distribution})
# dict
# ├── estimated_distribution=DirichletNP[dataclass]
# │ └── alpha_minus_one=Jax Array (3,) float32
# │ └── 0.9797 │ 1.9539 │ 2.9763
# └── source_distribution=DirichletNP[dataclass]
# └── alpha_minus_one=Jax Array (3,) float32
# └── 1.0000 │ 2.0000 │ 3.0000
Contribution guidelines
Contributions are welcome! I'm open to new features, design ideas, and new distributions.
It's not hard to add a new distribution; it's usually only around one hundred lines of code. The steps are:
1. Create an issue for the new distribution.
2. Implement the natural and expectation parametrizations, either:
   - directly, as in the Bernoulli distribution, or
   - as a transformation of an existing exponential family, as in the Rayleigh distribution.
3. Implement the conversion from the expectation to the natural parametrization. If this has no analytical solution, there's a mixin that implements a numerical solution. This can be seen in the Dirichlet distribution.
4. Add the new distribution to the tests by adding it to create_info.
The implementation should be consistent with the surrounding style, be type annotated, and pass the linters below.
The tests can be run using pytest -n auto. Specific distributions can be run with
pytest -n auto --distribution=Gamma where the names match the class names in
create_info.
There are a few tools to clean and check the source:
- uv run ruff check
- uv run ruff format
- uv run ty check
File details
Details for the file efax-2.2.0.tar.gz.
File metadata
- Download URL: efax-2.2.0.tar.gz
- Upload date:
- Size: 51.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 467f502658a8d2d5dd34570c8f8431b1fea8b85ead1596dbbafaa86effaf463e |
| MD5 | 332ac1e55ff18eb7096c1cdffd42d991 |
| BLAKE2b-256 | 51f5f88072dae34a23335ac294dfb27e31a3768c95d11dde75883488fdb2ab49 |
File details
Details for the file efax-2.2.0-py3-none-any.whl.
File metadata
- Download URL: efax-2.2.0-py3-none-any.whl
- Upload date:
- Size: 95.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8e57a8a1a9e54aea216f480455c7233071bdd72cf7c381fae8d479e4cccaa48e |
| MD5 | 98354e14559d0fff0176cf34e50cb8e5 |
| BLAKE2b-256 | ad36a078250b45e752a09b0e3c2103ea483755527821225209d19d86b74a5c4e |