
SGMCMC samplers in JAX


SGMCMCJax


SGMCMCJax is a lightweight library of stochastic gradient Markov chain Monte Carlo (SGMCMC) algorithms. The aim is to include both standard samplers (SGLD, SGHMC) and state-of-the-art samplers.
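For readers new to SGMCMC, here is a minimal sketch of what a single SGLD step does, written from scratch in JAX rather than taken from SGMCMCJax; `estimate_grad_logpost` and `sgld_step` are hypothetical helpers, not the library's API. It assumes the update θ ← θ + dt·ĝ(θ) + √(2·dt)·ξ, where ĝ is an unbiased minibatch estimate of the log-posterior gradient and ξ is standard Gaussian noise; this is one common parameterization, and the library's own step-size convention may differ.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import random


# Toy model of the same form as the Gaussian example in this README:
# log-likelihood of one data point and log-prior, up to additive constants.
def loglikelihood(theta, x):
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    return -0.5 * jnp.dot(theta, theta) * 0.01


def estimate_grad_logpost(theta, x_batch, N):
    """Unbiased minibatch estimate of the log-posterior gradient:
    grad log p(theta) + (N / n) * sum_i grad log p(x_i | theta)."""
    n = x_batch.shape[0]
    grads = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))(theta, x_batch)
    return jax.grad(logprior)(theta) + (N / n) * jnp.sum(grads, axis=0)


@partial(jax.jit, static_argnums=(4,))
def sgld_step(key, theta, X, dt, batch_size):
    """One SGLD step (hypothetical helper, not the SGMCMCJax API)."""
    key_batch, key_noise = random.split(key)
    # Sample a minibatch of row indices without replacement
    idx = random.choice(key_batch, X.shape[0], shape=(batch_size,), replace=False)
    grad = estimate_grad_logpost(theta, X[idx], X.shape[0])
    noise = random.normal(key_noise, shape=theta.shape)
    # Euler-Maruyama step of the Langevin diffusion, using the noisy gradient
    return theta + dt * grad + jnp.sqrt(2 * dt) * noise


# Usage: synthetic data with true mean 1 in each of 5 dimensions
key = random.PRNGKey(0)
key, data_key = random.split(key)
X = random.normal(data_key, shape=(1_000, 5)) + 1.0

theta = jnp.zeros(5)
chain = []
for _ in range(2_000):
    key, step_key = random.split(key)
    theta = sgld_step(step_key, theta, X, 1e-4, 100)
    chain.append(theta)
samples = jnp.stack(chain)  # shape (2000, 5)
```

After discarding the first few hundred iterations as burn-in, `samples.mean(axis=0)` should sit close to the data mean, which for this model is essentially the posterior mean.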

The library is aimed at researchers and practitioners: plug in your JAX model and obtain samples.

Note that this library is still in its early stages, so expect the API to change a bit.

Example usage

We show the basic usage with the following example of estimating the mean of a D-dimensional Gaussian from data using a Gaussian prior.

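```python
import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler


# define model in JAX: log-likelihood of a single data point x given theta
def loglikelihood(theta, x):
    return -0.5*jnp.dot(x-theta, x-theta)

# Gaussian prior on theta with precision 0.01 (variance 100)
def logprior(theta):
    return -0.5*jnp.dot(theta, theta)*0.01

# generate dataset (separate PRNG keys for the data and the sampler)
N, D = 10_000, 100
key = random.PRNGKey(0)
data_key, sampler_key = random.split(key)
X_data = random.normal(data_key, shape=(N, D))

# build sampler: step size dt, minibatches of 10% of the data
batch_size = int(0.1*N)
dt = 1e-5
my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size)

# run sampler, starting from theta = 0
Nsamples = 10_000
samples = my_sampler(sampler_key, Nsamples, jnp.zeros(D))
```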
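This example has a closed-form posterior, which makes it a convenient check on the sampler. Completing the square in the log-posterior, log p(θ | X) = −½ Σᵢ ‖xᵢ − θ‖² − ½·0.01·‖θ‖² + const = −½ (N + 0.01) ‖θ‖² + θᵀ Σᵢ xᵢ + const, so θ | X ~ N(Σᵢ xᵢ / (N + 0.01), I / (N + 0.01)). The sketch below computes these quantities; `gaussian_posterior` is a hypothetical helper, not part of SGMCMCJax.

```python
import jax.numpy as jnp
from jax import random


def gaussian_posterior(X, prior_precision=0.01):
    """Closed-form posterior for x_i ~ N(theta, I), theta ~ N(0, I / prior_precision).

    Returns the posterior mean (shape (D,)) and the posterior standard
    deviation, which is the same for every coordinate."""
    N = X.shape[0]
    post_precision = N + prior_precision  # N likelihood terms plus the prior term
    post_mean = jnp.sum(X, axis=0) / post_precision
    post_std = 1.0 / jnp.sqrt(post_precision)
    return post_mean, post_std


# Same data-generating setup as the example (illustrative, with its own key)
N, D = 10_000, 100
X_data = random.normal(random.PRNGKey(0), shape=(N, D))
post_mean, post_std = gaussian_posterior(X_data)
```

Assuming `samples` from the example is an array of shape `(Nsamples, D)`, its mean over the post-burn-in iterations can be compared with `gaussian_posterior(X_data)[0]` computed on the same `X_data`. SGLD with a fixed step size is only approximately exact, so expect close rather than exact agreement; each coordinate's posterior standard deviation is about 0.01.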

Samplers

The library includes several SGMCMC algorithms; their pros and cons are briefly discussed in the documentation.

The current list of samplers is:

  • SGLD
  • SGLD-CV
  • SVRG-Langevin
  • SGHMC
  • SGHMC-CV
  • SVRG-SGHMC
  • pSGLD
  • SGLDAdam
  • BAOAB
  • SGNHT
  • BADODAB

Installation

Create a virtual environment before installing the package:

```bash
git clone https://github.com/jeremiecoullon/SGMCMCJax.git
cd SGMCMCJax
python setup.py develop
```

Then run the tests with `python setup.py test`.

Download files

Source distribution: SGMCMCJax-0.1.0.tar.gz (8.5 kB)

Built distribution: SGMCMCJax-0.1.0-py3-none-any.whl (9.8 kB, Python 3)
