
Regularized Stein thinning using JAX


Kernel-based MCMC post-processing algorithms

Kernax is a small package that implements kernel-based post-processing and subsampling algorithms for MCMC output, including Stein thinning and its regularized variant, both demonstrated below.

Documentation

Full documentation is available on Read the Docs.

Quick start

Example usage of Stein thinning on a Gaussian sample:

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal
from kernax.utils import median_heuristic
from kernax import SteinThinning

rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000,2))

def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))
score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)

lengthscale = jnp.array([median_heuristic(x)])
stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)
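The lengthscale above is chosen with the median heuristic, i.e. (roughly) the median of pairwise distances between samples. A minimal stdlib sketch of the idea follows; note that this is illustrative only, and Kernax's `median_heuristic` may differ in details (e.g. using squared distances or a scaling factor):

```python
import math
from itertools import combinations
from statistics import median

def median_heuristic_sketch(points):
    """Median of pairwise Euclidean distances -- a common way to pick
    a kernel lengthscale that matches the scale of the data."""
    return median(math.dist(a, b) for a, b in combinations(points, 2))

# Three collinear points: pairwise distances are 1.0, 3.0, and 2.0.
lengthscale = median_heuristic_sketch([(0.0, 0.0), (1.0, 0.0), (3.0, 0.0)])
# -> 2.0
```

The returned `indices` from `stein_fn` index into `x`, so the thinned sample is simply `x[indices]`.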

To use the regularized variant, add a few lines:

from kernax.utils import laplace_log_p_softplus
from kernax import RegularizedSteinThinning

log_p = jax.vmap(logprob_fn, 0)(x)
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
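Note that both variants consume the score function ∇ log p, so any normalizing constant of the target density drops out of the gradient and only an unnormalized log-density is needed. A small stdlib sketch (with hypothetical helper names, not part of Kernax) checking this with central finite differences on an unnormalized Gaussian log-density:

```python
def log_p_unnorm(x, log_Z=0.0):
    # Unnormalized standard Gaussian log-density, shifted by an
    # arbitrary constant log_Z (a stand-in for the normalizing constant).
    return -0.5 * sum(xi * xi for xi in x) - log_Z

def numeric_score(f, x, h=1e-5):
    # Central finite-difference approximation of the gradient of f at x.
    g = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        g.append((f(xp) - f(xm)) / (2 * h))
    return g

x = [0.5, -1.2]
s0 = numeric_score(lambda v: log_p_unnorm(v, log_Z=0.0), x)
s1 = numeric_score(lambda v: log_p_unnorm(v, log_Z=123.4), x)
# The constant log_Z drops out: both approximate the analytic score -x.
```

This is why `score_fn = jax.grad(logprob_fn)` works equally well when `logprob_fn` is only known up to an additive constant.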

Install guide

As a user

A Python wheel is available on PyPI. Install Kernax into your Python environment with:

pip install kernax

As a developer

We recommend using uv. Clone the repository, then run:

uv sync

This creates a virtual environment for developing Kernax. If you’re not familiar with uv, have a look at their Getting started guide.

Paper experiments

This repository implements the regularized Stein thinning algorithm introduced in Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization.

If you use this library, please consider citing:

@article{benard2023kernel,
  title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
  author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
  journal={arXiv preprint arXiv:2301.13528},
  year={2023}
}

All numerical experiments presented in the paper can be reproduced using the scripts in the example/ folder.

In particular:

  • Figures 1–3: example/mog_randn.py
  • Section 4 and Appendix 1:
    • Gaussian mixture: example/mog4_mcmc/ and example/mog4_mcmc_dim/
    • Mixture of banana-shaped distributions: example/mobt2_mcmc/ and example/mobt2_mcmc_dim/
    • Bayesian logistic regression: example/logistic_regression.py
  • Supplementary material:
    • Figure 2: example/mog_weight_weights.py
    • Figure 6: example/mog4_mcmc_lambda

