Regularized Stein thinning using JAX
Project description
Kernel-based MCMC post-processing algorithms
Kernax is a small package that implements kernel-based post-processing and subsampling algorithms for MCMC output. It currently provides three algorithms:
- The vanilla Stein thinning algorithm, proposed by M. Riabiz et al. in Optimal thinning of MCMC output.
- The regularized Stein thinning algorithm, proposed by C. Bénard et al. in Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization.
- A greedy maximum mean discrepancy (MMD) subsampling algorithm (see, e.g., Optimal quantisation of probability measures using maximum mean discrepancy).
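To make two of the algorithms above more concrete, here is a minimal JAX sketch of greedy kernel-based selection. It is not Kernax's implementation. The IMQ base kernel with unit lengthscale, the Langevin Stein kernel built on it, and the helper names (`imq`, `stein_kernel`, `stein_gram`, `imq_gram`, `greedy_select`) are illustrative assumptions. Stein thinning, following the greedy rule of Riabiz et al., adds at step t the sample x minimizing k_p(x, x)/2 + Σ_{j<t} k_p(x_j, x), where k_p is the Stein kernel. Greedy MMD subsampling uses the same rule with the base kernel k and an extra term −t·(1/N)Σ_i k(x, y_i) that pulls the selection toward the full sample.

```python
import jax
import jax.numpy as jnp


def imq(x, y, lengthscale=1.0):
    """IMQ base kernel k(x, y) = (1 + |x - y|^2 / lengthscale^2)^(-1/2)."""
    return (1.0 + jnp.sum((x - y) ** 2) / lengthscale**2) ** -0.5


def stein_kernel(x, y, sx, sy):
    """Langevin Stein kernel built on the IMQ kernel; sx, sy are grad log p at x, y."""
    gx = jax.grad(imq, argnums=0)(x, y)  # grad_x k(x, y)
    gy = jax.grad(imq, argnums=1)(x, y)  # grad_y k(x, y)
    # Trace of d/dy (grad_x k), i.e. sum_i d^2 k / (dx_i dy_i)
    mixed = jax.jacfwd(jax.grad(imq, argnums=0), argnums=1)(x, y)
    return jnp.trace(mixed) + gx @ sy + gy @ sx + imq(x, y) * (sx @ sy)


def stein_gram(X, S):
    """N x N matrix of Stein kernel values between all samples."""
    def row(x, sx):
        return jax.vmap(lambda y, sy: stein_kernel(x, y, sx, sy))(X, S)
    return jax.vmap(row)(X, S)


def imq_gram(X):
    """N x N matrix of IMQ kernel values between all samples."""
    return jax.vmap(lambda x: jax.vmap(lambda y: imq(x, y))(X))(X)


def greedy_select(K, b, m):
    """Greedily pick m indices (repeats allowed). Step t (1-based) minimizes
    K[i, i] / 2 + sum_{j < t} K[sel_j, i] - t * b[i]."""
    half_diag = jnp.diag(K) / 2.0
    running = jnp.zeros(K.shape[0])  # sum of the rows K[sel_j] chosen so far
    selected = []
    for t in range(1, m + 1):
        idx = int(jnp.argmin(half_diag + running - t * b))
        selected.append(idx)
        running = running + K[idx]  # K is symmetric, so row == column
    return jnp.array(selected)


# Toy "MCMC output": 200 draws from a 2D standard Gaussian, whose score is -x
X = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
S = -X

# Stein thinning: no linear term; the target enters only through the score in K
ksd_idx = greedy_select(stein_gram(X, S), jnp.zeros(X.shape[0]), m=20)

# Greedy MMD: linear term is the mean kernel similarity to the full sample
K_imq = imq_gram(X)
mmd_idx = greedy_select(K_imq, K_imq.mean(axis=1), m=20)
```

Both Gram matrices take O(N²) memory, so this sketch only suits small samples. Because every step only adds one row to a running sum, each iteration costs O(N) once the Gram matrix is available.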
Documentation
Full documentation is available on Read the Docs.
Quick start
Example usage of Stein thinning on a Gaussian sample:
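```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal
from kernax import SteinThinning
from kernax.utils import median_heuristic

# Draw 1,000 points from a 2D standard Gaussian (stand-in for MCMC output)
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000, 2))

def logprob_fn(x):
    # Log-density of the target distribution
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))

# Score function (gradient of the log-density), evaluated at every sample
score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)

# Kernel lengthscale chosen with the median heuristic
lengthscale = jnp.array([median_heuristic(x)])

# Select 100 points with Stein thinning
stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)
```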
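The quick start sets the kernel lengthscale with `median_heuristic`. Below is a minimal sketch of one common form of the median heuristic: the median Euclidean distance over all pairs of distinct samples. The helper name `median_pairwise_distance` is hypothetical, and `kernax.utils.median_heuristic` may follow a different convention (for example, squared distances or an extra scaling factor).

```python
import jax
import jax.numpy as jnp


def median_pairwise_distance(x):
    """Median Euclidean distance over all pairs of distinct rows of x."""
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]
    sq_dists = jnp.sum(diff**2, axis=-1)
    rows, cols = jnp.triu_indices(n, k=1)  # each unordered pair exactly once
    return jnp.median(jnp.sqrt(sq_dists[rows, cols]))


x = jax.random.normal(jax.random.PRNGKey(0), (1000, 2))
lengthscale = jnp.array([median_pairwise_distance(x)])
```

For a two-dimensional standard Gaussian sample like the one in the quick start, the population value is sqrt(4 ln 2) ≈ 1.67, so the estimate should land close to it. The pairwise-distance matrix needs O(n²) memory, so for long chains it is common to compute it on a random subsample.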
To use the regularized variant, add the following lines to the Stein thinning example above:
```python
from kernax import RegularizedSteinThinning
from kernax.utils import laplace_log_p_softplus

# Log-density values (not the score) at every sample
log_p = jax.vmap(logprob_fn, 0)(x)
# Laplacian correction term computed from the score function
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
```
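The regularized variant takes two extra inputs: the log-density values and a Laplacian correction term. The sketch below shows, under stated assumptions, how these could enter a greedy selection step. We assume the correction is the sum over coordinates of a softplus-smoothed second derivative of log p (a smooth stand-in for its positive part), and that step t minimizes k_p(x, x) + Δ⁺log p(x) + 2 Σ_{j<t} k_p(x_j, x) − λ t log p(x), with λ a hypothetical regularization strength. The helpers (`laplacian_correction`, `regularized_greedy`, and the IMQ-based Stein kernel) are illustrative only and are not the internals of `laplace_log_p_softplus` or `RegularizedSteinThinning`; refer to the paper for the exact objective.

```python
import jax
import jax.numpy as jnp


def imq(x, y, lengthscale=1.0):
    """IMQ base kernel (1 + |x - y|^2 / lengthscale^2)^(-1/2)."""
    return (1.0 + jnp.sum((x - y) ** 2) / lengthscale**2) ** -0.5


def stein_kernel(x, y, sx, sy):
    """Langevin Stein kernel of the IMQ kernel (sx, sy: scores at x, y)."""
    gx = jax.grad(imq, argnums=0)(x, y)
    gy = jax.grad(imq, argnums=1)(x, y)
    mixed = jax.jacfwd(jax.grad(imq, argnums=0), argnums=1)(x, y)
    return jnp.trace(mixed) + gx @ sy + gy @ sx + imq(x, y) * (sx @ sy)


def stein_gram(X, S):
    def row(x, sx):
        return jax.vmap(lambda y, sy: stein_kernel(x, y, sx, sy))(X, S)
    return jax.vmap(row)(X, S)


def laplacian_correction(X, score_fn):
    """Hypothetical stand-in: sum_i softplus(d^2 log p / dx_i^2) at each row of X."""
    # The Jacobian of the score is the Hessian of log p; keep its diagonal.
    diag_hess = jax.vmap(lambda x: jnp.diag(jax.jacfwd(score_fn)(x)))(X)
    return jnp.sum(jax.nn.softplus(diag_hess), axis=-1)


def regularized_greedy(K, log_p, lap, m, lam):
    """Step t (1-based) picks the index minimizing
    K[i, i] + lap[i] + 2 * sum_{j < t} K[sel_j, i] - lam * t * log_p[i]."""
    static = jnp.diag(K) + lap
    running = jnp.zeros(K.shape[0])  # sum of the rows K[sel_j] chosen so far
    selected = []
    for t in range(1, m + 1):
        idx = int(jnp.argmin(static + 2.0 * running - lam * t * log_p))
        selected.append(idx)
        running = running + K[idx]
    return jnp.array(selected)


def logprob_fn(x):
    # Standard Gaussian log-density up to an additive constant; a constant
    # shifts every candidate's objective equally, so the argmin is unchanged.
    return -0.5 * jnp.sum(x**2)


score_fn = jax.grad(logprob_fn)
X = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
S = jax.vmap(score_fn)(X)
K = stein_gram(X, S)
log_p = jax.vmap(logprob_fn)(X)
lap = laplacian_correction(X, score_fn)
indices = regularized_greedy(K, log_p, lap, m=20, lam=1.0)
```

In this sketch, a larger `lam` pushes the selection toward high-density points, and for the standard Gaussian every entry of `lap` equals 2·softplus(−1), since the Hessian of log p is −I.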
Install guide
As a user
A Python wheel is available on PyPI. Install Kernax into your Python environment with:
```bash
pip install kernax
```
As a developer
We recommend using uv. Clone the repository, then run:
```bash
uv sync
```
This creates a virtual environment for developing Kernax. If you’re not familiar with uv, have a look at their Getting started guide.
Paper experiments
This repository implements the regularized Stein thinning algorithm introduced in Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization.
If you use this library, please consider citing:
```bibtex
@article{benard2023kernel,
  title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
  author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
  journal={arXiv preprint arXiv:2301.13528},
  year={2023}
}
```
All numerical experiments presented in the paper can be reproduced using the scripts in the example/ folder.
In particular:
- Figures 1–3: example/mog_randn.py
- Section 4 and Appendix 1:
  - Gaussian mixture: example/mog4_mcmc/ and example/mog4_mcmc_dim/
  - Mixture of banana-shaped distributions: example/mobt2_mcmc/ and example/mobt2_mcmc_dim/
  - Bayesian logistic regression: example/logistic_regression.py
- Supplementary material:
  - Figure 2: example/mog_weight_weights.py
  - Figure 6: example/mog4_mcmc_lambda
Download files
- Source distribution: kernax-0.3.0.tar.gz
- Built distribution: kernax-0.3.0-py3-none-any.whl
File details
Details for the file kernax-0.3.0.tar.gz.
File metadata
- Download URL: kernax-0.3.0.tar.gz
- Upload date:
- Size: 7.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cef1d80dad79ae71e722959ee4ff8a3913398a2ad7a035df2c777f110cd2ac1d |
| MD5 | e7db38a014ed22bba39bd833f429bdc3 |
| BLAKE2b-256 | 5791575995807085942b13845b9437771fd6e62fd927a3d28ec2db69ab8f0169 |
File details
Details for the file kernax-0.3.0-py3-none-any.whl.
File metadata
- Download URL: kernax-0.3.0-py3-none-any.whl
- Upload date:
- Size: 10.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9daa2bf05d55c39d98cae55f61fd73c48a5761dcdd38780f73d74d7cde564700 |
| MD5 | 4e5042fcd885a7f829c1fc3ed1c73855 |
| BLAKE2b-256 | 77726ef75fc73f7045e510bd20228746a023eb97633590ec33a4066030f42b05 |