Regularized Stein thinning using JAX

Kernax: kernel-based MCMC post-processing algorithms

Small package that implements kernel-based post-processing and subsampling algorithms. Three main algorithms are provided, including Stein thinning and regularized Stein thinning (see the quick start).

Quick start

Here's an example of usage of the Stein thinning algorithm applied to a Gaussian sample:

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal
from kernax.utils import median_heuristic
from kernax import SteinThinning

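# Draw 1000 samples from a standard 2D Gaussian
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000, 2))

# Log-density of the target and its gradient (the score function)
def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))
score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)

# Kernel lengthscale from the median heuristic, then build the thinning operator
lengthscale = jnp.array([median_heuristic(x)])
stein_fn = SteinThinning(x, score_values, lengthscale)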
indices = stein_fn(100)
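
For intuition about what a Stein thinning call like the one above computes, here is a minimal sketch of greedy kernel Stein discrepancy thinning written directly in JAX. It is not kernax's implementation: the inverse multiquadric base kernel, the pairwise-distance median used for the lengthscale, and all function names (`imq_kernel`, `stein_kernel`, `stein_gram`, `greedy_thinning`) are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def imq_kernel(x, y, ell):
    # Inverse multiquadric base kernel (an assumed choice for this sketch).
    return (1.0 + jnp.sum((x - y) ** 2) / ell**2) ** (-0.5)

def stein_kernel(x, y, sx, sy, ell):
    # Langevin Stein kernel obtained from the base kernel by autodiff.
    k = lambda a, b: imq_kernel(a, b, ell)
    grad_x = jax.grad(k, argnums=0)(x, y)
    grad_y = jax.grad(k, argnums=1)(x, y)
    # Trace of the mixed second derivative: sum_i d^2 k / (dx_i dy_i).
    mixed = jnp.trace(jax.jacfwd(jax.grad(k, argnums=1), argnums=0)(x, y))
    return mixed + grad_x @ sy + grad_y @ sx + k(x, y) * (sx @ sy)

def stein_gram(x, scores, ell):
    # Full n x n matrix of Stein kernel evaluations.
    row = jax.vmap(stein_kernel, in_axes=(None, 0, None, 0, None))
    return jax.vmap(row, in_axes=(0, None, 0, None, None))(x, x, scores, scores, ell)

def greedy_thinning(gram, m):
    # At each step, pick the point minimizing k_p(x, x) / 2 + sum_j k_p(x_j, x),
    # where the sum runs over the points selected so far (repeats are allowed).
    diag = jnp.diag(gram)
    running = jnp.zeros(gram.shape[0])
    selected = []
    for _ in range(m):
        i = int(jnp.argmin(0.5 * diag + running))
        selected.append(i)
        running = running + gram[:, i]
    return jnp.array(selected)

n, m = 200, 20
x = jax.random.normal(jax.random.PRNGKey(0), (n, 2))
scores = -x  # score of a standard Gaussian: grad log p(x) = -x

# Lengthscale: median of pairwise distances (one common form of the median heuristic).
sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
ell = jnp.sqrt(jnp.median(sq_dists[jnp.triu_indices(n, k=1)]))

gram = stein_gram(x, scores, ell)
indices = greedy_thinning(gram, m)
```

The sketch stores the full Gram matrix, which costs O(n^2) memory and is only reasonable for small samples; it is meant to show the selection rule, not to replace the package.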

If you want to apply the regularized method, you need a few more lines:

from kernax import laplace_log_p_softplus
from kernax import RegularizedSteinThinning

# Log-density values evaluated at each sample
log_p = jax.vmap(logprob_fn, 0)(x)
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
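
The regularized method needs a Laplacian-type term of the log-density in addition to the scores. The sketch below shows one way such a quantity could be computed in JAX, assuming it is the sum over coordinates of the softplus of the second derivatives of log p. Whether this matches what `laplace_log_p_softplus` does internally is an assumption; the function name `laplace_softplus_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp

def laplace_softplus_sketch(x, score_fn):
    # Hypothetical helper: for each sample, take the diagonal of the Hessian of
    # log p (the Jacobian of the score), apply softplus entrywise, and sum.
    def one_sample(xi):
        hess_diag = jnp.diag(jax.jacfwd(score_fn)(xi))
        return jnp.sum(jax.nn.softplus(hess_diag))
    return jax.vmap(one_sample)(x)

# Standard 2D Gaussian: score(x) = -x, so every Hessian diagonal entry is -1.
gaussian_score = lambda xi: -xi
samples = jax.random.normal(jax.random.PRNGKey(1), (5, 2))
vals = laplace_softplus_sketch(samples, gaussian_score)
```

For this Gaussian target the result is the same constant at every sample, since the Hessian of log p does not depend on x.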

Documentation

Documentation is available at readthedocs.

Contributing

This code is not meant to be an evolving library. However, feel free to create issues and merge requests.

Install guide

As a user

A Python wheel is available on PyPI. You can install kernax in your Python environment as follows:

pip install kernax

As a developer

We recommend using uv. First clone this repository, then simply run

uv sync

This will create a virtual environment that you can use for developing kernax. If you're not familiar with uv, have a look at their Getting started section.

Paper experiments

This code implements the regularized Stein thinning algorithm introduced in the paper Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization.

Please consider citing the paper when using this library:

@article{benard2023kernel,
  title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
  author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
  journal={arXiv preprint arXiv:2301.13528},
  year={2023}
}

All the numerical experiments presented in the paper can be reproduced with the scripts made available in the example folder.

In particular:

  • Figures 1, 2 & 3 can be reproduced with the script example/mog_randn.py

  • Each experiment in Section 4 and Appendix 1 can be reproduced with the scripts gathered in the following folders:

    • Gaussian mixture: example/mog4_mcmc and example/mog4_mcmc_dim
    • Mixture of banana-shaped distributions: example/mobt2_mcmc and example/mobt2_mcmc_dim
    • Bayesian logistic regression: example/logistic_regression.py
  • Two additional scripts are also available to reproduce figures shown in the supplementary material:

    • Figure 2: example/mog_weight_weights.py
    • Figure 6: example/mog4_mcmc_lambda

Download files

Source Distribution

kernax-0.2.0.tar.gz (9.1 kB)

Uploaded Source

Built Distribution

kernax-0.2.0-py3-none-any.whl (13.4 kB)

Uploaded Python 3

File details

Details for the file kernax-0.2.0.tar.gz.

File metadata

  • Download URL: kernax-0.2.0.tar.gz
  • Upload date:
  • Size: 9.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.11

File hashes

Hashes for kernax-0.2.0.tar.gz
Algorithm Hash digest
SHA256 6d7d4d934e583dc182cd7faf22c588a79c8a1be322dae09e5dfbecc524eba009
MD5 9846a3fb02ee13a0a90e331be541deb7
BLAKE2b-256 f87e3f145cbfa856596cdef4b2062e4225ac925ff5a0795f74139bd35ba3e339

See more details on using hashes here.

File details

Details for the file kernax-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: kernax-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.11

File hashes

Hashes for kernax-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 cda76b8bed52ff213d2a825b820b5cb1aba4f91e49d5937f04ed07e6aefc2d9c
MD5 a8da0ab7e4947d9640f8414895c6c32a
BLAKE2b-256 0eeafa4329eb00f8e0014bc4d01e1d026fc5f760f0d4212dc8152a4036ebd2c7
