Bayesian layers for NumPyro and Jax
BLayers
The missing layers package for Bayesian inference.
BLayers is in beta, errors are possible! We invite you to contribute on GitHub.
Write code immediately
pip install blayers
Dependencies: numpyro, jax, and optax.
Concept
Easily build Bayesian models from parts, abstract away the boilerplate, and tweak priors as you wish.
BLayers draws inspiration from Keras and TensorFlow Probability, but is built specifically for Numpyro + Jax.
BLayers provides tools to
- Quickly build Bayesian models from layers which encapsulate useful model parts
- Fit models either using Variational Inference (VI) or your sampling method of choice without having to rewrite models
- Write pure Numpyro to integrate with all of Numpyro's super powerful tools
- Add more complex layers (model parts) as you wish
- Fit models in a greater variety of ways with less code
The starting point
The simplest non-trivial (and most important!) Bayesian regression model form is the adaptive prior,
lmbda ~ HalfNormal(1)
beta ~ Normal(0, lmbda)
y ~ Normal(beta * x, 1)
BLayers encapsulates a generative model structure like this in a BLayer. The
fundamental building block is the AdaptiveLayer.
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp
def model(x, y):
    mu = AdaptiveLayer()('mu', x)
    return gaussian_link_exp(mu, y)
All AdaptiveLayer is doing is writing Numpyro for you under the hood. This
model is exactly equivalent to writing the following, just using way less code.
import jax.numpy as jnp
from numpyro import distributions, sample

def model(x, y):
    # AdaptiveLayer does all of this
    input_shape = x.shape[1]
    # adaptive prior
    lmbda = sample(
        name="lmbda",
        fn=distributions.HalfNormal(1.),
    )
    # beta coefficients for regression
    beta = sample(
        name="beta",
        fn=distributions.Normal(loc=0., scale=lmbda),
        sample_shape=(input_shape,),
    )
    mu = jnp.einsum('ij,j->i', x, beta)
    # the link function does this
    sigma = sample(name='sigma', fn=distributions.Exponential(1.))
    return sample('obs', distributions.Normal(mu, sigma), obs=y)
Mixing it up
The AdaptiveLayer is also fully parameterizable via arguments to the class, so let's say you wanted to change the model from
lmbda ~ HalfNormal(1)
beta ~ Normal(0, lmbda)
y ~ Normal(beta * x, 1)
to
lmbda ~ Exponential(1.)
beta ~ LogNormal(0, lmbda)
y ~ Normal(beta * x, 1)
you can just do this directly via arguments
from numpyro import distributions
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp

def model(x, y):
    mu = AdaptiveLayer(
        scale_dist=distributions.Exponential,
        prior_dist=distributions.LogNormal,
        scale_kwargs={'rate': 1.},
        prior_kwargs={'loc': 0.},
    )('mu', x)
    return gaussian_link_exp(mu, y)
"Factories"
Since Numpyro traces sample sites and doesn't record any parameters on the class, you can freely re-use a layer instance configured with a particular generative model structure.
from numpyro import distributions
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp
my_lognormal_layer = AdaptiveLayer(
    scale_dist=distributions.Exponential,
    prior_dist=distributions.LogNormal,
    scale_kwargs={'rate': 1.},
    prior_kwargs={'loc': 0.},
)

def model(x, y):
    mu = my_lognormal_layer('mu1', x) + my_lognormal_layer('mu2', x**2)
    return gaussian_link_exp(mu, y)
Layers
The full set of layers included with BLayers:
- AdaptiveLayer — Adaptive prior layer.
- FixedPriorLayer — Fixed prior over coefficients (e.g., Normal or Laplace).
- InterceptLayer — Intercept-only layer (bias term).
- EmbeddingLayer — Bayesian embeddings for sparse categorical features.
- RandomEffectsLayer — Classical random effects.
- FMLayer — Factorization Machine (order 2).
- FM3Layer — Factorization Machine (order 3).
- LowRankInteractionLayer — Low-rank interaction between two feature sets.
- RandomWalkLayer — Random walk prior over coefficients (e.g., Gaussian walk).
- InteractionLayer — All pairwise interactions between two feature sets.
- BilinearLayer — Bilinear interaction: x^T W z.
- LowRankBilinearLayer — Low-rank bilinear interaction.
- HorseshoeLayer — Horseshoe prior for sparse regression.
- AttentionLayer — Multi-head self-attention over the feature dimension with FT-Transformer tokenisation (Gorishniy et al. 2021). head_dim is per-head, so the total embedding dim is head_dim * num_heads; adding heads increases capacity.
All layer prior kwargs are validated at construction time — bad kwargs raise TypeError immediately.
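Layers compose by addition inside a single model function. A minimal sketch (assuming FixedPriorLayer takes the same ('name', x) call signature as AdaptiveLayer, which is an assumption rather than documented API):

from blayers.layers import AdaptiveLayer, FixedPriorLayer, InterceptLayer
from blayers.links import gaussian_link_exp

def model(x_dense, x_extra, y=None):
    # adaptive prior on one feature block, fixed prior on another, plus an intercept
    mu = (
        AdaptiveLayer()('dense', x_dense)
        + FixedPriorLayer()('extra', x_extra)  # call signature assumed to mirror AdaptiveLayer
        + InterceptLayer()('intercept')
    )
    return gaussian_link_exp(mu, y)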
Links
We provide link helpers in links.py to reduce Numpyro boilerplate. Available links:
- gaussian_link — Gaussian link with flexible scale: learned (default), fixed, or from a layer (see below).
- gaussian_link_exp — Gaussian link with Exp-distributed homoskedastic sigma.
- lognormal_link_exp — LogNormal link with Exp-distributed homoskedastic sigma.
- logit_link — Bernoulli link for logistic regression.
- poission_link — Poisson link with rate y_hat.
- negative_binomial_link — Uses sigma ~ Exponential(rate) and y ~ NegativeBinomial2(mean=y_hat, concentration=sigma).
- ordinal_link — Cumulative logit / proportional odds for ordinal outcomes.
- zip_link — Zero-inflated Poisson for count data with excess zeros.
- beta_link — Beta regression for proportions strictly in (0, 1).
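For example, swapping in logit_link turns the same linear predictor into a logistic regression. A sketch, assuming logit_link takes (y_hat, y) in the same order as gaussian_link:

from blayers.layers import AdaptiveLayer, InterceptLayer
from blayers.links import logit_link

def model(x, y=None):
    # y is a 0/1 outcome; the linear predictor becomes the logits
    logits = AdaptiveLayer()('beta', x) + InterceptLayer()('intercept')
    return logit_link(logits, y)  # argument order assumed to match gaussian_link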
gaussian_link scale modes
# Default: sigma ~ Exp(1) learned from data
gaussian_link(mu, y)
# Fixed known scale (e.g. from XGBoost quantile regression)
gaussian_link(mu, y, scale=pred_std)
# Learned scale from a layer — softplus applied internally for stable gradients
raw = AdaptiveLayer()("log_sigma", x)
gaussian_link(mu, y, untransformed_scale=raw)
Splines
Non-linear transformations via B-splines. Build the knot vector once with make_knots, compute the basis matrix with bspline_basis, then pass it to any layer.
from blayers.splines import make_knots, bspline_basis
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link
knots = make_knots(x_train, num_knots=10) # clamped knot vector from data quantiles
def model(x, y=None):
    B = bspline_basis(x, knots)  # (n, num_basis) design matrix
    f = AdaptiveLayer()("f", B)
    return gaussian_link(f, y)
Additive models are straightforward:
def model(x1, x2, y=None):
    f1 = AdaptiveLayer()("f1", bspline_basis(x1, knots1))
    f2 = AdaptiveLayer()("f2", bspline_basis(x2, knots2))
    return gaussian_link(f1 + f2, y)
fit() helpers
fit() handles the guide, ELBO, batching, and LR schedule. The same model runs unchanged under VI, MCMC, or SVGD.
from blayers.fit import fit
from blayers.layers import AdaptiveLayer, InterceptLayer
from blayers.links import gaussian_link
from blayers.sampling import autoreshape

@autoreshape
def model(x, y=None):
    mu = AdaptiveLayer()("beta", x)
    intercept = InterceptLayer()("intercept")
    return gaussian_link(mu + intercept, y)
# Variational Inference (default)
result = fit(model, y=y, num_steps=1000, batch_size=256, lr=0.01, x=X)
# MCMC
result = fit(model, y=y, method="mcmc", num_mcmc_samples=1000, num_warmup=500, x=X)
# SVGD
result = fit(model, y=y, method="svgd", num_steps=1000, num_particles=20, x=X)
result.predict() returns a Predictions object with .mean, .std, and .samples. result.summary() returns posterior stats per latent variable.
preds = result.predict(x=X, num_samples=500)
summary = result.summary(x=X)
Keyword arguments that are JAX arrays are treated as data (batched during training). Non-array kwargs are bound as constants.
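A minimal sketch of that rule; the prior_rate kwarg here is hypothetical, just to show a non-array constant being threaded through to the model:

import jax.numpy as jnp
from numpyro import distributions
from blayers.fit import fit
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

def model(x, y=None, prior_rate=1.0):
    # prior_rate is a plain float, so fit() binds it as a constant
    mu = AdaptiveLayer(
        scale_dist=distributions.Exponential,
        scale_kwargs={'rate': prior_rate},
    )('beta', x)
    return gaussian_link(mu, y)

# x and y are jnp arrays, so they are treated as data and mini-batched;
# prior_rate is not an array, so it is passed through unchanged each step.
result = fit(model, y=y, x=X, prior_rate=0.5, num_steps=1000, batch_size=256)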
Batched loss
The default Numpyro way to fit batched VI models is to use plate, which confuses
me a lot. Instead, BLayers provides Batched_Trace_ELBO which does not require
you to use plate to batch in VI. Just drop your model in.
import optax
from numpyro.infer import SVI
from blayers.infer import Batched_Trace_ELBO, svi_run_batched

# model_fn, guide, schedule, rng_key, num_steps, and model_data are defined
# elsewhere; loss_instance is a Batched_Trace_ELBO instance.
svi = SVI(model_fn, guide, optax.adam(schedule), loss=loss_instance)
svi_result = svi_run_batched(
    svi,
    rng_key,
    num_steps,
    batch_size=1000,
    **model_data,
)
⚠️⚠️⚠️ numpyro.plate + Batched_Trace_ELBO do not mix. ⚠️⚠️⚠️
Batched_Trace_ELBO is known to have issues when your model uses numpyro.plate. If your model needs plates, either:
- Batch via plate and use the standard Trace_ELBO, or
- Remove plates and use Batched_Trace_ELBO + svi_run_batched.
Batched_Trace_ELBO will warn you if your model has plates.
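Putting the pieces together, a minimal end-to-end sketch; the AutoNormal guide and the no-argument Batched_Trace_ELBO constructor are assumptions, so check blayers.infer for the actual signatures:

import optax
from jax import random
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoNormal
from blayers.infer import Batched_Trace_ELBO, svi_run_batched

guide = AutoNormal(model)    # variational guide over the model's latents
loss = Batched_Trace_ELBO()  # constructor args assumed; see blayers.infer
svi = SVI(model, guide, optax.adam(1e-2), loss=loss)

svi_result = svi_run_batched(
    svi,
    random.PRNGKey(0),
    1000,                    # num_steps
    batch_size=1000,
    x=X,                     # array kwargs are sub-sampled into batches
    y=y,
)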
Reparameterizing
To fit MCMC models well it is crucial to reparameterize. BLayers helps you do this, automatically reparameterizing the following distributions, which Numpyro refers to as LocScale distributions.
LocScaleDist = (
dist.Normal
| dist.LogNormal
| dist.StudentT
| dist.Cauchy
| dist.Laplace
| dist.Gumbel
)
Then, reparameterize these distributions automatically and fit with Numpyro's built-in MCMC methods.
from jax import random
from numpyro.infer import MCMC, NUTS
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link_exp
from blayers.sampling import autoreparam

data = {...}
rng_key = random.PRNGKey(0)

@autoreparam
def model(x, y):
    mu = AdaptiveLayer()('mu', x)
    return gaussian_link_exp(mu, y)

kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(
    rng_key,
    **data,
)
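After sampling, Numpyro's standard MCMC accessors work as usual:

mcmc.print_summary()          # r_hat, effective sample size, quantiles per site
samples = mcmc.get_samples()  # dict of posterior draws keyed by site name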