# Kernax: kernel-based MCMC post-processing algorithms

Regularized Stein thinning using JAX.
Kernax is a small package that implements kernel-based post-processing and subsampling algorithms. Three main algorithms are provided:

- The vanilla Stein thinning algorithm proposed by M. Riabiz et al. in *Optimal thinning of MCMC output*.
- The regularized Stein thinning algorithm proposed by C. Benard et al. in *Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization*.
- The greedy maximum mean discrepancy subsampling algorithm (see, e.g., *Optimal quantisation of probability measures using maximum mean discrepancy*).
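The greedy MMD idea can be illustrated with a short standalone sketch. The function below is illustrative only and is **not** kernax's API: at each step it adds the candidate point whose inclusion most reduces the squared MMD between the subsample and the full sample, using a Gaussian kernel (as in kernel herding, a point may be selected more than once).

```python
import jax
import jax.numpy as jnp

def greedy_mmd_subsample(x, m, lengthscale):
    """Greedy MMD minimization with a Gaussian kernel (illustrative sketch)."""
    # Gram matrix of the Gaussian kernel over all candidate points
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    K = jnp.exp(-d2 / (2.0 * lengthscale ** 2))
    k_bar = K.mean(axis=1)        # average similarity of each point to the full sample
    acc = jnp.zeros(x.shape[0])   # running sum of k(., x_s) over selected points s
    selected = []
    for t in range(1, m + 1):
        # criterion obtained by expanding MMD^2 after adding one point to the subsample
        obj = 0.5 * jnp.diag(K) + acc - t * k_bar
        j = int(jnp.argmin(obj))
        selected.append(j)
        acc = acc + K[:, j]
    return jnp.array(selected)

x = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
idx = greedy_mmd_subsample(x, 20, 1.0)
```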
## Quick start

Here is an example applying the Stein thinning algorithm to a Gaussian sample:
```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

from kernax.utils import median_heuristic
from kernax import SteinThinning

rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000, 2))

def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))

score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)

lengthscale = jnp.array([median_heuristic(x)])
stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)
```
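The `median_heuristic` helper sets the kernel lengthscale from the data; the classic rule is the median of the pairwise Euclidean distances. A minimal sketch of that rule is below (kernax's exact convention, e.g. squared versus plain distances or an extra scaling factor, may differ):

```python
import jax.numpy as jnp

def median_pairwise_distance(x):
    # median of all pairwise Euclidean distances, excluding the diagonal
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    iu = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(jnp.sqrt(d2[iu]))

pts = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
median_pairwise_distance(pts)  # pairwise distances are 5, 10, 5 -> median 5.0
```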
If you want to apply the regularized method, a few more lines are needed:

```python
from kernax import laplace_log_p_softplus
from kernax import RegularizedSteinThinning

log_p = jax.vmap(logprob_fn, 0)(x)  # log-density values (not the scores)
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
```
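`laplace_log_p_softplus` supplies the Laplacian of the log-density used by the regularization term (kernax applies a softplus-based stabilization on top). The raw Laplacian, without that stabilization, can be computed with JAX autodiff as the trace of the Hessian of `log p`; a minimal sketch, not kernax's implementation:

```python
import jax
import jax.numpy as jnp

def laplace_log_p(x, logprob_fn):
    # Laplacian of log p at each sample = trace of the Hessian of log p
    hess_fn = jax.hessian(logprob_fn)
    return jax.vmap(lambda xi: jnp.trace(hess_fn(xi)))(x)

# for a standard normal in d dimensions the Laplacian of log p is -d everywhere
logprob = lambda x: -0.5 * jnp.sum(x ** 2)
vals = laplace_log_p(jnp.ones((3, 2)), logprob)  # each entry is -2.0
```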
## Documentation

Documentation is available at readthedocs.
## Contributing

This code is not meant to be an evolving library. However, feel free to open issues and merge requests.
## Install guide

### As a user

A Python wheel is available on PyPI. You can install kernax in your Python environment as follows:

```shell
pip install kernax
```
### As a developer

We recommend using uv. First clone this repository, then simply run

```shell
uv sync
```

This will create a virtual environment that you can use for developing kernax. If you're not familiar with uv, have a look at its Getting started section.
## Paper experiments

This code implements the regularized Stein thinning algorithm introduced in the paper *Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization*. Please consider citing the paper when using this library:
```bibtex
@article{benard2023kernel,
  title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
  author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
  journal={arXiv preprint arXiv:2301.13528},
  year={2023}
}
```
All the numerical experiments presented in the paper can be reproduced with the scripts made available in the example folder.
In particular:

- Figures 1, 2 & 3 can be reproduced with the script `example/mog_randn.py`.
- Each experiment in Section 4 and Appendix 1 can be reproduced with the scripts gathered in the following folders:
  - Gaussian mixture: `example/mog4_mcmc` and `example/mog4_mcmc_dim`
  - Mixture of banana-shaped distributions: `example/mobt2_mcmc` and `example/mobt2_mcmc_dim`
  - Bayesian logistic regression: `example/logistic_regression.py`
- Two additional scripts are also available to reproduce figures shown in the supplementary material:
  - Figure 2: `example/mog_weight_weights.py`
  - Figure 6: `example/mog4_mcmc_lambda`
## File details

### kernax-0.2.0.tar.gz

- Size: 9.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.11

| Algorithm | Hash digest |
|---|---|
| SHA256 | `6d7d4d934e583dc182cd7faf22c588a79c8a1be322dae09e5dfbecc524eba009` |
| MD5 | `9846a3fb02ee13a0a90e331be541deb7` |
| BLAKE2b-256 | `f87e3f145cbfa856596cdef4b2062e4225ac925ff5a0795f74139bd35ba3e339` |
### kernax-0.2.0-py3-none-any.whl

- Size: 13.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.11

| Algorithm | Hash digest |
|---|---|
| SHA256 | `cda76b8bed52ff213d2a825b820b5cb1aba4f91e49d5937f04ed07e6aefc2d9c` |
| MD5 | `a8da0ab7e4947d9640f8414895c6c32a` |
| BLAKE2b-256 | `0eeafa4329eb00f8e0014bc4d01e1d026fc5f760f0d4212dc8152a4036ebd2c7` |