SGMCMC samplers in JAX
SGMCMCJax
SGMCMCJax is a lightweight library of stochastic gradient Markov chain Monte Carlo (SGMCMC) algorithms. The aim is to include both standard samplers (SGLD, SGHMC) and state-of-the-art samplers while requiring only JAX to run.
The target audience for this library is researchers and practitioners: simply plug in your JAX model and easily obtain samples.
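For orientation, SGLD, the simplest of these samplers, discretises the Langevin diffusion and replaces the full-data gradient of the log posterior with an unbiased minibatch estimate. One common parameterisation is written out below, with step size $h$ (assumed here to play the role of `dt` in the example), $N$ data points and a minibatch $S_t$ of size $n$. Papers differ on whether the drift is scaled by $h$ or $h/2$, so treat this as an illustration rather than the exact form used inside SGMCMCJax.

$$
\hat{g}(\theta) = \nabla \log p(\theta) + \frac{N}{n} \sum_{i \in S_t} \nabla \log p(x_i \mid \theta),
\qquad
\theta_{t+1} = \theta_t + h\,\hat{g}(\theta_t) + \sqrt{2h}\,\xi_t, \quad \xi_t \sim \mathcal{N}(0, I).
$$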
Example usage
We show the basic usage with the following example of estimating the mean of a D-dimensional Gaussian from data using a Gaussian prior.
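import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler

# define model in JAX
def loglikelihood(theta, x):
    # Gaussian log-likelihood with unit covariance (up to a constant)
    return -0.5*jnp.dot(x-theta, x-theta)

def logprior(theta):
    # Gaussian prior N(0, 100 I) (up to a constant)
    return -0.5*jnp.dot(theta, theta)*0.01

# generate dataset (use separate PRNG keys for the data and the sampler)
N, D = 10_000, 100
key = random.PRNGKey(0)
data_key, sampler_key = random.split(key)
X_data = random.normal(data_key, shape=(N, D))

# build sampler: step size dt, minibatches of 10% of the data
batch_size = int(0.1*N)
dt = 1e-5
my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size)

# run sampler for Nsamples iterations, starting from theta = 0
Nsamples = 10_000
samples = my_sampler(sampler_key, Nsamples, jnp.zeros(D))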
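Because this model is conjugate, its posterior is available in closed form, which makes it a convenient sanity check for the sampler. With the unit-variance likelihood and the N(0, 100 I) prior above, the posterior is Gaussian with precision N + 0.01 in every dimension and mean (x_1 + ... + x_N)/(N + 0.01). The sketch below computes it with plain JAX; the helper `exact_posterior` is hypothetical and not part of SGMCMCJax.

```python
import jax.numpy as jnp
from jax import random

def exact_posterior(X, prior_precision=0.01):
    """Closed-form posterior for the example model.

    Likelihood: x_i ~ N(theta, I); prior: theta ~ N(0, I / prior_precision).
    The posterior is N(mean, std**2 * I) with precision N + prior_precision.
    """
    N = X.shape[0]
    precision = N + prior_precision      # identical in every dimension
    mean = X.sum(axis=0) / precision     # prior mean is 0, so only the data term remains
    std = 1.0 / jnp.sqrt(precision)
    return mean, std

# A synthetic dataset shaped like the example's, so this snippet runs standalone
N, D = 10_000, 100
X_data = random.normal(random.PRNGKey(0), shape=(N, D))
post_mean, post_std = exact_posterior(X_data)
```

To check the sampler, call `exact_posterior` on the example's own `X_data` and, assuming `samples` holds one row per iteration, compare `samples[1000:].mean(axis=0)` with `post_mean` after discarding a burn-in (the length 1000 is an arbitrary choice). SGLD with a fixed `dt` is only approximately correct, so expect close rather than exact agreement.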
Ask a question or open an issue
Please open issues on the GitHub issue tracker, or ask a question in the Discussions section on GitHub.
Samplers
The library includes several SGMCMC algorithms; their pros and cons are briefly discussed in the documentation.
The current list of samplers is:
- SGLD
- SGLD-CV
- SVRG-Langevin
- SGHMC
- SGHMC-CV
- SVRG-SGHMC
- pSGLD
- SGLDAdam
- BAOAB
- SGNHT
- SGNHT-CV
- BADODAB
- BADODAB-CV
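Several samplers in this list come in a plain and a control-variate ("-CV") version, which differ only in how the log-posterior gradient is estimated. The sketch below illustrates the difference for SGLD in plain JAX, reusing the model from the example. It is not SGMCMCJax's internal code: the function names (`sgld_gradient`, `cv_gradient`, `sgld_step`) are hypothetical, the Langevin step assumes the parameterisation theta + dt·grad + sqrt(2·dt)·noise, and the fixed centring point `theta_hat` is assumed to come from an optimiser's estimate of the posterior mode before sampling.

```python
import jax
import jax.numpy as jnp
from jax import random

# Model from the example: unit-variance Gaussian likelihood, N(0, 100 I) prior
def loglikelihood(theta, x):
    return -0.5 * jnp.dot(x - theta, x - theta)

def logprior(theta):
    return -0.5 * jnp.dot(theta, theta) * 0.01

# Per-datapoint likelihood gradients for a batch of rows of X
batch_grad_lik = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))

def sgld_gradient(theta, X, idx):
    """Plain SGLD: unbiased minibatch estimate of the log-posterior gradient."""
    N, n = X.shape[0], idx.shape[0]
    lik = batch_grad_lik(theta, X[idx]).sum(axis=0)
    return jax.grad(logprior)(theta) + (N / n) * lik

def cv_gradient(theta, theta_hat, full_grad_hat, X, idx):
    """Control-variate estimate: exact full-data likelihood gradient at theta_hat
    (precomputed as full_grad_hat) plus a rescaled minibatch correction."""
    N, n = X.shape[0], idx.shape[0]
    diff = batch_grad_lik(theta, X[idx]) - batch_grad_lik(theta_hat, X[idx])
    return jax.grad(logprior)(theta) + full_grad_hat + (N / n) * diff.sum(axis=0)

def sgld_step(key, theta, grad, dt):
    """Langevin step (assumed parameterisation): theta + dt*grad + sqrt(2*dt)*noise."""
    noise = random.normal(key, shape=theta.shape)
    return theta + dt * grad + jnp.sqrt(2.0 * dt) * noise

# Usage: a few SGLD-CV steps on a small synthetic dataset
key, data_key = random.split(random.PRNGKey(0))
X = random.normal(data_key, shape=(1_000, 5))
theta_hat = X.mean(axis=0)  # stand-in for an optimiser's estimate of the mode
full_grad_hat = batch_grad_lik(theta_hat, X).sum(axis=0)  # one full pass, done once
theta = theta_hat
for _ in range(20):
    key, idx_key, noise_key = random.split(key, 3)
    idx = random.choice(idx_key, X.shape[0], shape=(100,), replace=False)
    g = cv_gradient(theta, theta_hat, full_grad_hat, X, idx)
    theta = sgld_step(noise_key, theta, g, dt=1e-4)
```

Both estimators are unbiased for the full-data gradient. The control variate pays for one full pass over the data at `theta_hat` up front, in exchange for lower-variance minibatch gradients while the chain stays near `theta_hat`.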
Installation
Create a virtual environment and either install a stable version using pip or install the development version.
Stable version
To install the latest stable version run:
pip install sgmcmcjax
Development version
To install the development version run:
git clone https://github.com/jeremiecoullon/SGMCMCJax.git
cd SGMCMCJax
python -m pip install -e .
Then install the development requirements and run the tests:
pip install -r requirements-dev.txt
make
To run code style checks: make lint
Citing SGMCMCJax
Please use the following bibtex reference to cite this repository:
@article{Coullon2022,
doi = {10.21105/joss.04113},
url = {https://doi.org/10.21105/joss.04113},
year = {2022},
publisher = {The Open Journal},
volume = {7},
number = {72},
pages = {4113},
author = {Jeremie Coullon and Christopher Nemeth},
title = {SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms},
journal = {Journal of Open Source Software}
}
File details
Details for the file SGMCMCJax-0.2.13.tar.gz.
File metadata
- Download URL: SGMCMCJax-0.2.13.tar.gz
- Size: 21.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | 633eba94d160014055557bd7cbc53c2f454644e6d123eb64a16e2fbc7c353920
MD5 | 3f67967fa44e01ee61524faed426ea12
BLAKE2b-256 | bef56d4c545db7d2672354245ce03d5d9c518f2262bb00edc1e61cd4e9154a99
Details for the file SGMCMCJax-0.2.13-py3-none-any.whl.
File metadata
- Download URL: SGMCMCJax-0.2.13-py3-none-any.whl
- Size: 27.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5729a36cab4388ae955eef7e780759845f7bafd75367cdb31694ec51d03a283a
MD5 | 6a7535e2d278564a086be5de2a8ae9d3
BLAKE2b-256 | b1aaefa8777af80fe753b540705863934459b56d0bc434e048c4ce3a74fd1a46
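The SHA256 digests listed for the 0.2.13 files can be used to check a download before installing it. A minimal sketch using only Python's standard `hashlib` module; the helper names `sha256_of` and `verify` are hypothetical. pip can also enforce this itself through its hash-checking mode (`--require-hashes`).

```python
import hashlib
from pathlib import Path

# SHA256 digests as listed for the SGMCMCJax 0.2.13 files
EXPECTED = {
    "SGMCMCJax-0.2.13.tar.gz": "633eba94d160014055557bd7cbc53c2f454644e6d123eb64a16e2fbc7c353920",
    "SGMCMCJax-0.2.13-py3-none-any.whl": "5729a36cab4388ae955eef7e780759845f7bafd75367cdb31694ec51d03a283a",
}

def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify(path):
    """True if the file's digest matches the listed one (KeyError for unknown names)."""
    return sha256_of(path) == EXPECTED[Path(path).name]
```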