
SGMCMC samplers in JAX


SGMCMCJax


SGMCMCJax is a lightweight library of stochastic gradient Markov chain Monte Carlo (SGMCMC) algorithms. The aim is to include both standard samplers (SGLD, SGHMC) and state-of-the-art samplers, while requiring only JAX to run.

The target audience for this library is researchers and practitioners: simply plug in your JAX model and easily obtain samples.


Example usage

We show the basic usage with the following example of estimating the mean of a D-dimensional Gaussian from data using a Gaussian prior.

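import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler


# define model in JAX
# log-likelihood of a single data point: x ~ N(theta, I)
def loglikelihood(theta, x):
    return -0.5*jnp.dot(x-theta, x-theta)

# log-prior: theta ~ N(0, 100 I), i.e. prior precision 0.01
def logprior(theta):
    return -0.5*jnp.dot(theta, theta)*0.01

# generate dataset
N, D = 10_000, 100
key = random.PRNGKey(0)
# use separate keys for the data and the sampler so their random draws are independent
data_key, sampler_key = random.split(key)
X_data = random.normal(data_key, shape=(N, D))

# build sampler: step size dt, minibatches of 10% of the data
batch_size = int(0.1*N)
dt = 1e-5
my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size)

# run sampler, starting from theta = 0
Nsamples = 10_000
samples = my_sampler(sampler_key, Nsamples, jnp.zeros(D))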
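Because both the likelihood and the prior in this example are Gaussian, the posterior is available in closed form, which gives a convenient sanity check for the samples. With x_i ~ N(theta, I) and a prior precision of 0.01, completing the square in the log-posterior gives theta | X ~ N(sum_i x_i / (N + 0.01), I / (N + 0.01)). A minimal sketch of that computation follows; the helper name `gaussian_posterior` is hypothetical and not part of SGMCMCJax.

```python
import jax.numpy as jnp


def gaussian_posterior(X_data, prior_precision=0.01):
    """Exact posterior for x_i ~ N(theta, I) with prior theta ~ N(0, I / prior_precision).

    Hypothetical helper for checking sampler output; not part of SGMCMCJax.
    Returns the posterior mean (shape (D,)) and the common per-coordinate std.
    """
    N = X_data.shape[0]
    # Posterior precision: N unit-precision likelihood terms plus the prior precision
    post_precision = N + prior_precision
    post_mean = jnp.sum(X_data, axis=0) / post_precision
    post_std = 1.0 / jnp.sqrt(post_precision)
    return post_mean, post_std
```

For example, `gaussian_posterior(X_data)` can be compared with `samples.mean(axis=0)` after discarding an initial burn-in period, assuming `samples` is an array of shape (Nsamples, D).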

Ask a question or open an issue

Please open issues on the GitHub issue tracker, or ask a question in the Discussions section on GitHub.

Samplers

The library includes several SGMCMC algorithms; their pros and cons are briefly discussed in the documentation.

The current list of samplers is:

  • SGLD
  • SGLD-CV
  • SVRG-Langevin
  • SGHMC
  • SGHMC-CV
  • SVRG-SGHMC
  • pSGLD
  • SGLDAdam
  • BAOAB
  • SGNHT
  • SGNHT-CV
  • BADODAB
  • BADODAB-CV
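All of these samplers rely on a stochastic estimate of the gradient of the log-posterior computed from a minibatch. As an illustration, here is a plain SGLD loop written in pure JAX for the Gaussian model from the example, using the unbiased estimate g = grad log p(theta) + (N / n) * sum over the minibatch of grad log p(x_i | theta), and the update theta <- theta + dt * g + sqrt(2 * dt) * xi with xi ~ N(0, I). This is a sketch, not SGMCMCJax's implementation: the helper names (`estimate_grad_logpost`, `sgld_step`, `run_sgld`), drawing minibatches without replacement, and this step-size convention are assumptions. In practice, use `build_sgld_sampler` as shown in the example.

```python
import jax
import jax.numpy as jnp
from jax import random


def loglikelihood(theta, x):
    # log-likelihood of one data point: x ~ N(theta, I)
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    # log-prior: theta ~ N(0, 100 I)
    return -0.5 * jnp.dot(theta, theta) * 0.01


def estimate_grad_logpost(theta, X_batch, N):
    # Hypothetical helper: unbiased minibatch estimate of the log-posterior gradient
    n = X_batch.shape[0]
    # per-data-point gradients, vectorised over the minibatch axis
    grad_lik = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))(theta, X_batch)
    return jax.grad(logprior)(theta) + (N / n) * jnp.sum(grad_lik, axis=0)


def sgld_step(key, theta, X_data, batch_size, dt):
    # Hypothetical helper: one SGLD update (assumed convention: drift dt, noise variance 2*dt)
    key_batch, key_noise = random.split(key)
    N = X_data.shape[0]
    # draw a minibatch of indices without replacement (an assumption)
    idx = random.choice(key_batch, N, shape=(batch_size,), replace=False)
    grad = estimate_grad_logpost(theta, X_data[idx], N)
    noise = random.normal(key_noise, shape=theta.shape)
    return theta + dt * grad + jnp.sqrt(2.0 * dt) * noise


def run_sgld(key, theta0, X_data, batch_size, dt, num_samples):
    # Hypothetical helper: run the chain with lax.scan, keeping every iterate
    def body(theta, step_key):
        theta = sgld_step(step_key, theta, X_data, batch_size, dt)
        return theta, theta

    keys = random.split(key, num_samples)
    _, samples = jax.lax.scan(body, theta0, keys)
    return samples  # shape (num_samples, D)
```

For example, `run_sgld(random.PRNGKey(1), jnp.zeros(D), X_data, batch_size, dt, Nsamples)` runs a chain analogous to `my_sampler` in the example. Using `jax.lax.scan` keeps the whole loop inside a single compiled computation rather than a Python loop.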

Installation

Create a virtual environment and either install a stable version using pip or install the development version.

Stable version

To install the latest stable version run:

pip install sgmcmcjax

Development version

To install the development version run:

git clone https://github.com/jeremiecoullon/SGMCMCJax.git
cd SGMCMCJax
python -m pip install -e .

Then install the development requirements and run the tests:

pip install -r requirements-dev.txt
make

To run code style checks: make lint

Citing SGMCMCJax

Please use the following bibtex reference to cite this repository:

@article{Coullon2022,
  doi = {10.21105/joss.04113},
  url = {https://doi.org/10.21105/joss.04113},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {72},
  pages = {4113},
  author = {Jeremie Coullon and Christopher Nemeth},
  title = {SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms},
  journal = {Journal of Open Source Software}
}

