
Bayesian layers for NumPyro and Jax



BLayers

The missing layers package for Bayesian inference.

BLayers is in beta, errors are possible! We invite you to contribute on GitHub.

Write code immediately

pip install blayers

Dependencies: numpyro, jax, and optax.

Concept


Easily build Bayesian models from parts, abstract away the boilerplate, and tweak priors as you wish.

Inspired by Keras and TensorFlow Probability, but built specifically for Numpyro + Jax.

BLayers provides tools to

  • Quickly build Bayesian models from layers which encapsulate useful model parts
  • Fit models either using Variational Inference (VI) or your sampling method of choice without having to rewrite models
  • Write pure Numpyro to integrate with all of Numpyro's super powerful tools
  • Add more complex layers (model parts) as you wish
  • Fit models in a greater variety of ways with less code

The starting point

The simplest non-trivial (and most important!) Bayesian regression model form is the adaptive prior,

scale ~ HalfNormal(1)
beta  ~ Normal(0, scale)
y     ~ Normal(beta * x, 1)
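To make the generative story concrete, here is a minimal forward simulation of the adaptive prior in plain JAX. It is a sketch, not part of BLayers: `simulate_adaptive_prior` is a hypothetical name, and the observation noise is fixed at 1 to match the formula above.

```python
import jax
import jax.numpy as jnp

def simulate_adaptive_prior(key, x, noise_scale=1.0):
    """Hypothetical helper: draw (scale, beta, y) once from the adaptive prior."""
    k_scale, k_beta, k_y = jax.random.split(key, 3)
    # scale ~ HalfNormal(1): absolute value of a standard normal draw
    scale = jnp.abs(jax.random.normal(k_scale))
    # beta ~ Normal(0, scale): one coefficient per column of x, all sharing `scale`
    beta = scale * jax.random.normal(k_beta, (x.shape[1],))
    # y ~ Normal(x @ beta, noise_scale)
    mu = x @ beta
    y = mu + noise_scale * jax.random.normal(k_y, mu.shape)
    return scale, beta, y

x = jax.random.normal(jax.random.PRNGKey(1), (100, 3))
scale, beta, y = simulate_adaptive_prior(jax.random.PRNGKey(0), x)
```

Because every coefficient shares one `scale`, a small draw of `scale` shrinks all coefficients toward zero together. Learning that shared shrinkage from the data is what makes the prior adaptive.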

BLayers encapsulates a generative model structure like this in a BLayer. The fundamental building block is the AdaptiveLayer.

from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

def model(x, y):
    mu = AdaptiveLayer()('mu', x)
    return gaussian_link(mu, y)

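Under the hood, AdaptiveLayer just writes Numpyro for you. The model above is equivalent to writing the following by hand, with far less code. Note that gaussian_link does not fix the noise at 1 as in the formula above: it learns sigma with an Exponential(1) prior.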

import jax.numpy as jnp
from numpyro import distributions, sample

def model(x, y):
    # Adaptive layer does all of this
    input_shape = x.shape[1]
    # adaptive prior
    scale = sample(
        name="scale",
        fn=distributions.HalfNormal(1.),
    )
    # beta coefficients for regression
    beta = sample(
        name="beta",
        fn=distributions.Normal(loc=0., scale=scale),
        sample_shape=(input_shape,),
    )
    mu = jnp.einsum('ij,j->i', x, beta)

    # the link function does this
    sigma = sample(name='sigma', fn=distributions.Exponential(1.))
    return sample('obs', distributions.Normal(mu, sigma), obs=y)

Mixing it up

AdaptiveLayer is also fully parameterizable through its constructor arguments. Say you want to change the model from

scale ~ HalfNormal(1)
beta  ~ Normal(0, scale)
y     ~ Normal(beta * x, 1)

to

scale ~ Exponential(1.)
beta  ~ LogNormal(0, scale)
y     ~ Normal(beta * x, 1)

you can do this directly with arguments:

from numpyro import distributions
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

def model(x, y):
    mu = AdaptiveLayer(
        scale_dist=distributions.Exponential,
        coef_dist=distributions.LogNormal,
        scale_kwargs={'rate': 1.},
        coef_kwargs={'loc': 0.}
    )('mu', x)
    return gaussian_link(mu, y)

"Factories"

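Because Numpyro identifies random variables by sample-site name and the layer object stores no parameters, you can freely re-use a layer configured with a particular generative model structure. Just give each call a unique name, as with 'mu1' and 'mu2' below.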

from numpyro import distributions
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

my_lognormal_layer = AdaptiveLayer(
    scale_dist=distributions.Exponential,
    coef_dist=distributions.LogNormal,
    scale_kwargs={'rate': 1.},
    coef_kwargs={'loc': 0.}
)

def model(x, y):
    mu = my_lognormal_layer('mu1', x) + my_lognormal_layer('mu2', x**2)
    return gaussian_link(mu, y)

Layers

The full set of layers included with BLayers:

  • AdaptiveLayer — Adaptive prior layer: scale ~ HalfNormal(1), beta ~ Normal(0, scale).
  • FixedPriorLayer — Fixed prior over coefficients (e.g., Normal or Laplace), no hierarchical scale.
  • InterceptLayer — Intercept-only layer (bias term).
  • EmbeddingLayer — Bayesian embeddings for sparse categorical features.
  • RandomEffectsLayer — Classical random-effects (embedding with output dim 1).
  • FMLayer — Factorization Machine (order 2) for pairwise interaction terms.
  • FM3Layer — Factorization Machine (order 3).
  • LowRankInteractionLayer — Low-rank interaction between two feature sets.
  • InteractionLayer — All pairwise interactions between two feature sets.
  • BilinearLayer — Bilinear interaction: x^T W z.
  • LowRankBilinearLayer — Low-rank bilinear interaction.
  • RandomWalkLayer — Gaussian random walk prior over an ordered index (e.g., time).
  • HorseshoeLayer — Horseshoe prior for sparse regression; global-local shrinkage via HalfCauchy.
  • SpikeAndSlabLayer — Spike-and-slab prior; z ~ Beta(0.5, 0.5) inclusion weights times a configurable slab.
  • AttentionLayer — Multi-head self-attention over the feature dimension with FT-Transformer tokenisation (Gorishniy et al. 2021). head_dim is per-head so total embedding dim is head_dim * num_heads — adding heads increases capacity.

All layer prior kwargs are validated at construction time — bad kwargs raise TypeError immediately.

Links

We provide link helpers in links.py to reduce Numpyro boilerplate. Available links:

  • gaussian_link — Gaussian likelihood with configurable sigma prior (see below).
  • lognormal_link — LogNormal likelihood with configurable sigma prior.
  • logit_link — Bernoulli link for logistic regression.
  • poisson_link — Poisson link with log-rate input.
  • negative_binomial_link — NegativeBinomial2 for overdispersed counts; learned concentration via Exponential.
  • ordinal_link — Cumulative logit / proportional odds for ordinal outcomes.
  • zip_link — Zero-inflated Poisson for count data with excess zeros.
  • beta_link — Beta regression for proportions strictly in (0, 1).

gaussian_link and lognormal_link

Both links are built on a common base and support three scale modes:

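from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

# Default: sigma ~ Exp(1) learned from data
def model(x, y=None):
    mu = AdaptiveLayer()("mu", x)
    return gaussian_link(mu, y)

# Fixed known scale (e.g. from XGBoost quantile regression)
def model_fixed_scale(x, pred_std, y=None):
    mu = AdaptiveLayer()("mu", x)
    return gaussian_link(mu, y, scale=pred_std)

# Learned scale from a layer; softplus is applied internally for stable gradients
def model_learned_scale(x, y=None):
    mu = AdaptiveLayer()("mu", x)
    raw = AdaptiveLayer()("log_sigma", x)
    return gaussian_link(mu, y, untransformed_scale=raw)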

Swap the sigma prior via functools.partial:

from functools import partial
import numpyro.distributions as dists
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

# HalfNormal prior instead of Exponential
hn_gaussian = partial(gaussian_link, sigma_dist=dists.HalfNormal, sigma_kwargs={"scale": 1.0})

def model(x, y=None):
    mu = AdaptiveLayer()("mu", x)
    return hn_gaussian(mu, y)

Splines

Non-linear transformations via B-splines. Compute the knot vector once from training data with make_knots, build the basis (design) matrix with bspline_basis, then pass it to any layer.

from blayers.splines import make_knots, bspline_basis
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

knots = make_knots(x_train, num_knots=10)   # clamped knot vector from data quantiles

def model(x, y=None):
    B = bspline_basis(x, knots)             # (n, num_basis) design matrix
    f = AdaptiveLayer()("f", B)
    return gaussian_link(f, y)
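To see what these two steps compute, here is a from-scratch sketch using the Cox-de Boor recursion with a clamped cubic knot vector built from data quantiles. `make_knots_sketch` and `bspline_basis_sketch` are hypothetical stand-ins; BLayers' `make_knots` and `bspline_basis` may place knots or handle boundaries differently.

```python
import jax.numpy as jnp

def make_knots_sketch(x, num_knots=10, degree=3):
    """Hypothetical: clamped knot vector from evenly spaced data quantiles."""
    x = jnp.asarray(x, dtype=jnp.float32)
    base = jnp.quantile(x, jnp.linspace(0.0, 1.0, num_knots))  # includes min and max
    # Repeat each boundary knot `degree` more times so the basis is clamped
    lo = jnp.full((degree,), base[0])
    hi = jnp.full((degree,), base[-1])
    return jnp.concatenate([lo, base, hi])

def bspline_basis_sketch(x, knots, degree=3):
    """Hypothetical: (n, len(knots) - degree - 1) design matrix via Cox-de Boor."""
    t = knots
    # Clamp points to the knot range so every row falls in some span
    x = jnp.clip(jnp.asarray(x, dtype=jnp.float32), t[0], t[-1])[:, None]
    # Degree 0: indicator of each half-open span [t_i, t_{i+1})
    B = ((x >= t[:-1]) & (x < t[1:])).astype(jnp.float32)
    # Close the last non-empty span on the right so x == t[-1] is covered
    last = len(t) - degree - 2
    B = B.at[:, last].set(jnp.where(x[:, 0] >= t[-1], 1.0, B[:, last]))
    for k in range(1, degree + 1):
        # B_{i,k} = w_left * B_{i,k-1} + w_right * B_{i+1,k-1}
        left_den = t[k:-1] - t[: -k - 1]
        right_den = t[k + 1 :] - t[1:-k]
        # Repeated knots give 0/0; by convention those terms are 0
        w_left = jnp.where(left_den > 0, (x - t[: -k - 1]) / jnp.where(left_den > 0, left_den, 1.0), 0.0)
        w_right = jnp.where(right_den > 0, (t[k + 1 :] - x) / jnp.where(right_den > 0, right_den, 1.0), 0.0)
        B = w_left * B[:, :-1] + w_right * B[:, 1:]
    return B

x_train = jnp.linspace(0.0, 1.0, 200)
knots = make_knots_sketch(x_train, num_knots=10)
B = bspline_basis_sketch(x_train, knots)  # shape (200, 12); rows sum to 1 up to rounding
```

Each row of the basis is non-negative and sums to one, so a layer over `B` fits a smooth piecewise-cubic function whose flexibility is set by the number of knots.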

Additive models are straightforward:

knots1 = make_knots(x1_train, num_knots=10)
knots2 = make_knots(x2_train, num_knots=10)

def model(x1, x2, y=None):
    f1 = AdaptiveLayer()("f1", bspline_basis(x1, knots1))
    f2 = AdaptiveLayer()("f2", bspline_basis(x2, knots2))
    return gaussian_link(f1 + f2, y)

fit() helpers

fit() handles the guide, ELBO, batching, and learning-rate schedule. The same model runs unchanged under VI, MCMC, or SVGD.

from blayers.fit import fit
from blayers.decorators import autoreshape
from blayers.layers import AdaptiveLayer, InterceptLayer
from blayers.links import gaussian_link

@autoreshape
def model(x, y=None):
    mu = AdaptiveLayer()("beta", x)
    intercept = InterceptLayer()("intercept")
    return gaussian_link(mu + intercept, y)

# Variational Inference (default)
result = fit(model, y=y, num_steps=1000, batch_size=256, lr=0.01, x=X)

# MCMC
result = fit(model, y=y, method="mcmc", num_mcmc_samples=1000, num_warmup=500, x=X)

# SVGD
result = fit(model, y=y, method="svgd", num_steps=1000, num_particles=20, x=X)

result.predict() returns a Predictions object with .mean, .std, and .samples. result.summary() returns posterior stats per latent variable.

preds = result.predict(x=X, num_samples=500)
summary = result.summary(x=X)

Keyword arguments that are JAX arrays are treated as data (batched during training). Non-array kwargs are bound as constants.
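The array-versus-constant rule fits in a few lines. This is a hypothetical sketch of the split (`split_model_kwargs` is not a BLayers function) that uses `jax.Array` to decide what counts as data:

```python
import jax
import jax.numpy as jnp

def split_model_kwargs(**kwargs):
    """Hypothetical: separate JAX arrays (data to batch) from constants."""
    data = {k: v for k, v in kwargs.items() if isinstance(v, jax.Array)}
    constants = {k: v for k, v in kwargs.items() if not isinstance(v, jax.Array)}
    return data, constants

data, constants = split_model_kwargs(x=jnp.ones((100, 3)), num_groups=4, prior_scale=0.5)
# data holds only "x"; constants == {"num_groups": 4, "prior_scale": 0.5}
```

With this check a plain NumPy array is not a `jax.Array`, so in this sketch it would be bound as a constant rather than batched.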

Batched loss

The default Numpyro way to fit mini-batched VI models is numpyro.plate with subsample_size, which I find confusing. Instead, BLayers provides Batched_Trace_ELBO, which lets you batch in VI without using plate. Just drop your model in.

import jax
import optax
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDiagonalNormal
from blayers.vi_infer import Batched_Trace_ELBO, svi_run_batched

# model_fn: any model without plates, e.g. model(x, y) from above
# model_data: the model's keyword arguments, e.g. {"x": X, "y": y}
rng_key = jax.random.PRNGKey(0)

loss = Batched_Trace_ELBO(num_obs=len(y), batch_size=1000)
guide = AutoDiagonalNormal(model_fn)
svi = SVI(model_fn, guide, optax.adam(0.01), loss=loss)

svi_result = svi_run_batched(
    svi,
    rng_key,
    batch_size=1000,
    num_steps=500,
    **model_data,
)

⚠️⚠️⚠️ numpyro.plate + Batched_Trace_ELBO do not mix. ⚠️⚠️⚠️

Batched_Trace_ELBO is known to have issues when your model uses numpyro.plate. If your model needs plates, either:

  1. Batch via plate and use the standard Trace_ELBO, or
  2. Remove plates and use Batched_Trace_ELBO + svi_run_batched.

Batched_Trace_ELBO will warn if your model has plates.

Reparameterizing

Hierarchical priors such as the adaptive prior tend to produce funnel-shaped posteriors that NUTS explores poorly, so reparameterizing is crucial to fit MCMC models well. BLayers helps you do this via @autoreparam, which automatically applies LocScaleReparam to all LocScale distributions in your model (Normal, LogNormal, StudentT, Cauchy, Laplace, Gumbel).

import jax
from numpyro.infer import MCMC, NUTS
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link
from blayers.decorators import autoreparam

rng_key = jax.random.PRNGKey(0)
data = {...}  # keyword arguments for the model: x and y

@autoreparam
def model(x, y):
    mu = AdaptiveLayer()('mu', x)
    return gaussian_link(mu, y)

kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(
    rng_key,
    **data,
)


Download files


Source Distribution

blayers-0.3.0.tar.gz (45.2 kB)

Uploaded Source

Built Distribution


blayers-0.3.0-py3-none-any.whl (47.3 kB)

Uploaded Python 3

File details

Details for the file blayers-0.3.0.tar.gz.

File metadata

  • Download URL: blayers-0.3.0.tar.gz
  • Size: 45.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for blayers-0.3.0.tar.gz
Algorithm Hash digest
SHA256 6d61cbac370b34050c807b9000c73d31190d7561bdd76b2bdacbf1ed592826cd
MD5 39d368e163c1a5506cf3986a662e269e
BLAKE2b-256 6404d525b626b1480f6dd20f0540f170b3b028ca6920381c660a9ed526f625ee


Provenance

The following attestation bundles were made for blayers-0.3.0.tar.gz:

Publisher: publish.yml on georgeberry/blayers


File details

Details for the file blayers-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: blayers-0.3.0-py3-none-any.whl
  • Size: 47.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for blayers-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 1ed5f67598e7ac1627de37ce66b8f35c184d7376934dd8383e73c18e965cc05a
MD5 74f6cd488284e8a465ed0b9b3222fe91
BLAKE2b-256 7e7d2c4f417ddc595a73d02bd15d0cc8f48e542e090a8df8b0d7cc41bf7918fb


Provenance

The following attestation bundles were made for blayers-0.3.0-py3-none-any.whl:

Publisher: publish.yml on georgeberry/blayers

