
Bayesian layers for NumPyro and Jax

Project description


BLayers

The missing layers package for Bayesian inference.

BLayers is in beta, so errors are possible! We invite you to contribute on GitHub.

Write code immediately

pip install blayers

Dependencies: numpyro, jax, and optax.

Concept

Easily build Bayesian models from parts, abstract away the boilerplate, and tweak priors as you wish.

Inspired by Keras and TensorFlow Probability, but built specifically for NumPyro + JAX.

BLayers provides tools to

  • Quickly build Bayesian models from layers which encapsulate useful model parts
  • Fit models either using Variational Inference (VI) or your sampling method of choice without having to rewrite models
  • Write pure Numpyro to integrate with all of Numpyro's super powerful tools
  • Add more complex layers (model parts) as you wish
  • Fit models in a greater variety of ways with less code

The starting point

The simplest non-trivial (and most important!) Bayesian regression model form is the adaptive prior,

lmbda ~ HalfNormal(1)
beta  ~ Normal(0, lmbda)
y     ~ Normal(beta * x, 1)
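To build intuition, here is a plain-Python sketch of drawing once from this generative model (standard library only, scalar x for simplicity; this is illustrative, not BLayers code). A half-normal draw is just the absolute value of a normal draw:

```python
import random

random.seed(0)

# lmbda ~ HalfNormal(1): absolute value of a Normal(0, 1) draw
lmbda = abs(random.gauss(0.0, 1.0))

# beta ~ Normal(0, lmbda): the prior scale is itself random ("adaptive")
beta = random.gauss(0.0, lmbda)

# y ~ Normal(beta * x, 1) for a single scalar observation x
x = 2.0
y = random.gauss(beta * x, 1.0)
```

Because lmbda is learned rather than fixed, the model adapts the amount of shrinkage on beta to the data.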

BLayers encapsulates a generative model structure like this in a BLayer. The fundamental building block is the AdaptiveLayer.

from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp
def model(x, y):
    mu = AdaptiveLayer()('mu', x)
    return gaussian_link_exp(mu, y)

All AdaptiveLayer does is write Numpyro for you under the hood. This model is exactly equivalent to writing the following, just with far less code.

import jax.numpy as jnp
from numpyro import distributions, sample

def model(x, y):
    # Adaptive layer does all of this
    input_shape = x.shape[1]
    # adaptive prior
    lmbda = sample(
        name="lmbda",
        fn=distributions.HalfNormal(1.),
    )
    # beta coefficients for regression
    beta = sample(
        name="beta",
        fn=distributions.Normal(loc=0., scale=lmbda),
        sample_shape=(input_shape,),
    )
    mu = jnp.einsum('ij,j->i', x, beta)

    # the link function does this
    sigma = sample(name='sigma', fn=distributions.Exponential(1.))
    return sample('obs', distributions.Normal(mu, sigma), obs=y)
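For clarity, the `jnp.einsum('ij,j->i', x, beta)` line is just a matrix–vector product: output element i is the dot product of row i of x with beta. The same computation in plain Python:

```python
# 'ij,j->i' contracts over j: each mu[i] is the dot product of x[i] with beta
x = [[1.0, 2.0], [3.0, 4.0]]
beta = [10.0, 100.0]
mu = [sum(xij * bj for xij, bj in zip(row, beta)) for row in x]
# mu == [210.0, 430.0]
```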

Mixing it up

The AdaptiveLayer is also fully parameterizable via arguments to the class, so let's say you wanted to change the model from

lmbda ~ HalfNormal(1)
beta  ~ Normal(0, lmbda)
y     ~ Normal(beta * x, 1)

to

lmbda ~ Exponential(1.)
beta  ~ LogNormal(0, lmbda)
y     ~ Normal(beta * x, 1)

you can just do this directly via arguments

from numpyro import distributions
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp
def model(x, y):
    mu = AdaptiveLayer(
        lmbda_dist=distributions.Exponential,
        prior_dist=distributions.LogNormal,
        lmbda_kwargs={'rate': 1.},
        prior_kwargs={'loc': 0.}
    )('mu', x)
    return gaussian_link_exp(mu, y)

"Factories"

Since Numpyro traces sample sites and no parameters are recorded on the class, you can freely re-use a layer instance with a particular generative model structure.

from numpyro import distributions
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp

my_lognormal_layer = AdaptiveLayer(
    lmbda_dist=distributions.Exponential,
    prior_dist=distributions.LogNormal,
    lmbda_kwargs={'rate': 1.},
    prior_kwargs={'loc': 0.}
)

def model(x, y):
    mu = my_lognormal_layer('mu1', x) + my_lognormal_layer('mu2', x**2)
    return gaussian_link_exp(mu, y)

Layers

The full set of layers included with BLayers:

  • AdaptiveLayer — Adaptive prior layer.
  • FixedPriorLayer — Fixed prior over coefficients (e.g., Normal or Laplace).
  • InterceptLayer — Intercept-only layer (bias term).
  • EmbeddingLayer — Bayesian embeddings for sparse categorical features.
  • RandomEffectsLayer — Classical random-effects.
  • FMLayer — Factorization Machine (order 2).
  • FM3Layer — Factorization Machine (order 3).
  • LowRankInteractionLayer — Low-rank interaction between two feature sets.
  • RandomWalkLayer — Random walk prior over coefficients (e.g., Gaussian walk).
  • InteractionLayer — All pairwise interactions between two feature sets.

Links

We provide link helpers in links.py to reduce Numpyro boilerplate. Available links:

  • logit_link — Bernoulli link for logistic regression.
  • poission_link — Poisson link with rate y_hat.
  • gaussian_link_exp — Gaussian link with Exp distributed homoskedastic sigma.
  • lognormal_link_exp — LogNormal link with Exp distributed homoskedastic sigma.
  • negative_binomial_link — Uses sigma ~ Exponential(rate) and y ~ NegativeBinomial2(mean=y_hat, concentration=sigma).
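Conceptually, each link wraps the linear predictor y_hat in an observation distribution. For example, logit_link corresponds to a Bernoulli likelihood whose success probability is the logistic transform of y_hat; that transform is just the following (plain-Python sketch of the math, not the library code):

```python
import math

def sigmoid(z: float) -> float:
    """Map a real-valued linear predictor to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# A linear predictor of 0 corresponds to probability 0.5
p = sigmoid(0.0)
```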

Batched loss

The default Numpyro way to fit batched VI models is to use plate, which can be confusing. Instead, BLayers provides Batched_Trace_ELBO, which lets you batch in VI without using plate. Just drop your model in.

import optax
from jax import random
from numpyro.infer import SVI, autoguide

from blayers.infer import Batched_Trace_ELBO, svi_run_batched

rng_key = random.PRNGKey(0)
guide = autoguide.AutoNormal(model)
svi = SVI(model, guide, optax.adam(1e-3), loss=Batched_Trace_ELBO())

svi_result = svi_run_batched(
    svi,
    rng_key,
    num_steps,
    batch_size=1000,
    **model_data,
)

⚠️⚠️⚠️ numpyro.plate + Batched_Trace_ELBO do not mix. ⚠️⚠️⚠️

Batched_Trace_ELBO is known to have issues when your model uses numpyro.plate. If your model needs plates, either:

  1. Batch via plate and use the standard Trace_ELBO, or
  2. Remove plates and use Batched_Trace_ELBO + svi_run_batched.

Batched_Trace_ELBO will warn you if your model has plates.

Reparameterizing

To fit MCMC models well it is crucial to reparameterize. BLayers helps you do this, automatically reparameterizing the following distributions, which Numpyro refers to as LocScale distributions.

LocScaleDist = (
    dist.Normal
    | dist.LogNormal
    | dist.StudentT
    | dist.Cauchy
    | dist.Laplace
    | dist.Gumbel
)

BLayers then reparameterizes these distributions automatically, and you fit with Numpyro's built-in MCMC methods.

from jax import random
from numpyro.infer import MCMC, NUTS

from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp
from blayers.sampling import autoreparam

data = {...}
rng_key = random.PRNGKey(0)

@autoreparam
def model(x, y):
    mu = AdaptiveLayer()('mu', x)
    return gaussian_link_exp(mu, y)

kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(
    rng_key,
    **data,
)
