
SGMCMC samplers in JAX

Project description

SGMCMCJax


SGMCMCJax is a lightweight library of stochastic gradient Markov chain Monte Carlo (SGMCMC) algorithms. The aim is to include both standard samplers (SGLD, SGHMC) and state-of-the-art samplers.
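
As background for readers new to SGMCMC, here is a single SGLD step written directly in JAX. An unbiased minibatch estimate of the log-posterior gradient (the minibatch sum scaled by N/n) drives a Langevin update with injected Gaussian noise. This is an illustrative sketch under stated assumptions, not the library's implementation. The helper names (`per_example_grad`, `stochastic_grad`, `sgld_step`, `run_sgld`) are hypothetical, and the model is the Gaussian-mean model from the example usage section. The step uses the convention theta <- theta + dt * grad + sqrt(2 * dt) * noise, which may not match how the library scales `dt`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import random


def loglikelihood(theta, x):
    # Log-density of one data point x ~ N(theta, I), up to a constant.
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    # Log-density of theta ~ N(0, 100 I), i.e. precision 0.01, up to a constant.
    return -0.5 * jnp.dot(theta, theta) * 0.01


# Gradient of the log-likelihood for each row of a minibatch (hypothetical helper).
per_example_grad = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))


def stochastic_grad(theta, x_batch, N):
    # Unbiased estimate of the full-data log-posterior gradient:
    # prior gradient + (N / n) * sum of the minibatch likelihood gradients.
    n = x_batch.shape[0]
    return jax.grad(logprior)(theta) + (N / n) * per_example_grad(theta, x_batch).sum(axis=0)


@partial(jax.jit, static_argnames=("batch_size",))
def sgld_step(key, theta, X, batch_size, dt):
    key_batch, key_noise = random.split(key)
    N = X.shape[0]
    # Draw a minibatch of row indices without replacement.
    idx = random.choice(key_batch, N, shape=(batch_size,), replace=False)
    grad = stochastic_grad(theta, X[idx], N)
    noise = random.normal(key_noise, shape=theta.shape)
    # Euler-Maruyama step of the Langevin diffusion, using the noisy gradient.
    return theta + dt * grad + jnp.sqrt(2.0 * dt) * noise


def run_sgld(key, theta0, X, batch_size, dt, n_steps):
    # Scan over one fresh PRNG key per step and keep every iterate.
    def body(theta, k):
        theta = sgld_step(k, theta, X, batch_size, dt)
        return theta, theta

    _, samples = jax.lax.scan(body, theta0, random.split(key, n_steps))
    return samples
```

Calling `run_sgld(key, theta0, X, batch_size, dt, n_steps)` returns an array of shape `(n_steps, D)`. The only difference from full-batch Langevin dynamics is the minibatch gradient, which is what makes each step cost O(n) rather than O(N).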

The target audience for this library is researchers and practitioners: simply plug in your JAX model and easily obtain samples.

Note that this library is still in its early stages, so expect some changes to the API.

Example usage

The following example shows basic usage: estimating the mean of a D-dimensional Gaussian from data, with a Gaussian prior on the mean.

import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler


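# define model in JAX
# log-likelihood of a single data point x: Gaussian with mean theta and identity covariance (up to a constant)
def loglikelihood(theta, x):
    return -0.5*jnp.dot(x-theta, x-theta)

# log-prior: Gaussian N(0, 100*I), i.e. precision 0.01 (up to a constant)
def logprior(theta):
    return -0.5*jnp.dot(theta, theta)*0.01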

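# generate dataset
N, D = 10_000, 100
key = random.PRNGKey(0)
# split the key so the data and the sampler use independent random streams
data_key, sampler_key = random.split(key)
X_data = random.normal(data_key, shape=(N, D))

# build sampler: step size dt, data passed as a tuple of arrays, minibatch size batch_size
batch_size = int(0.1*N)
dt = 1e-5
my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size)

# run sampler for Nsamples iterations, starting from theta = 0
Nsamples = 10_000
samples = my_sampler(sampler_key, Nsamples, jnp.zeros(D))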
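
Because the model in this example is conjugate, its exact posterior is available in closed form, which makes it a convenient check on the sampler's output. With x_i ~ N(theta, I) and prior theta ~ N(0, 100 I) (precision 0.01), the log-posterior is -(N + 0.01)/2 * |theta|^2 + theta . sum_i x_i + const, so the posterior is Gaussian with mean sum_i x_i / (N + 0.01) and covariance I / (N + 0.01). The helpers below (`gaussian_posterior`, `max_mean_error`) are hypothetical names introduced for illustration, and the comparison assumes the sampler returns `samples` as an array of shape `(Nsamples, D)`.

```python
import jax.numpy as jnp


def gaussian_posterior(X, prior_precision=0.01):
    """Exact posterior for the example model (hypothetical helper).

    Likelihood: x_i ~ N(theta, I); prior: theta ~ N(0, I / prior_precision).
    The posterior is N(mean, var * I) with precision N + prior_precision.
    """
    N = X.shape[0]
    post_precision = N + prior_precision
    mean = X.sum(axis=0) / post_precision
    var = 1.0 / post_precision
    return mean, var


def max_mean_error(samples, X, burn_in=0, prior_precision=0.01):
    """Largest per-coordinate gap between the sample mean and the exact posterior mean."""
    mean, _ = gaussian_posterior(X, prior_precision)
    return jnp.max(jnp.abs(samples[burn_in:].mean(axis=0) - mean))
```

For example, `max_mean_error(samples, X_data, burn_in=1000)` reports how far the post-burn-in sample mean lies from the exact posterior mean; the burn-in length here is an arbitrary choice.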

Samplers

The library includes several SGMCMC algorithms; their pros and cons are briefly discussed in the documentation.

The current list of samplers is:

  • SGLD
  • SGLD-CV
  • SVRG-Langevin
  • SGHMC
  • SGHMC-CV
  • SVRG-SGHMC
  • pSGLD
  • SGLDAdam
  • BAOAB
  • SGNHT
  • SGNHT-CV
  • BADODAB
  • BADODAB-CV
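
Several samplers in this list carry a -CV suffix (SGLD-CV, SGHMC-CV, SGNHT-CV, BADODAB-CV), which denotes a control-variate gradient estimator. The idea is to fix a centring point theta_hat (typically an estimate of the posterior mode), compute the full-data likelihood gradient there once, and correct it with minibatch gradient differences; when theta stays near theta_hat this has much lower variance than the plain minibatch estimate. The following is a generic sketch of that estimator in JAX, reusing the Gaussian model from the example usage. It is not the library's code, and `per_example_grad`, `full_data_grad_lik` and `cv_grad_estimate` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import random


def loglikelihood(theta, x):
    # Log-density of one data point x ~ N(theta, I), up to a constant.
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    # Log-density of theta ~ N(0, 100 I), up to a constant.
    return -0.5 * jnp.dot(theta, theta) * 0.01


# Per-example likelihood gradients, vectorised over the data axis.
per_example_grad = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))


def full_data_grad_lik(theta, X):
    # Sum of likelihood gradients over the whole dataset (O(N) cost).
    return per_example_grad(theta, X).sum(axis=0)


def cv_grad_estimate(key, theta, theta_hat, full_grad_hat, X, batch_size):
    """Control-variate estimate of the log-posterior gradient at theta.

    theta_hat: fixed centring point, e.g. an estimate of the posterior mode.
    full_grad_hat: full_data_grad_lik(theta_hat, X), computed once up front.
    """
    N = X.shape[0]
    idx = random.choice(key, N, shape=(batch_size,), replace=False)
    # Minibatch correction: gradient differences between theta and theta_hat.
    diff = per_example_grad(theta, X[idx]) - per_example_grad(theta_hat, X[idx])
    return jax.grad(logprior)(theta) + full_grad_hat + (N / batch_size) * diff.sum(axis=0)
```

For this particular Gaussian model each per-example difference equals theta_hat - theta regardless of the data point, so the estimate is exact for any minibatch; for other models it remains unbiased but noisy, with noise that shrinks as theta approaches theta_hat.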

Installation

Create a virtual environment, then install either the stable version with pip or the development version from source.

Stable version

To install the latest stable version run:

pip install sgmcmcjax

Development version

To install the development version run:

git clone https://github.com/jeremiecoullon/SGMCMCJax.git
cd SGMCMCJax
python -m pip install -e .

Then install the development requirements and run the tests:

pip install -r requirements-dev.txt
make



Download files


Source Distribution

SGMCMCJax-0.2.11.tar.gz (17.6 kB)

Uploaded Source

Built Distribution

SGMCMCJax-0.2.11-py3-none-any.whl (31.1 kB)

Uploaded Python 3

File details

Details for the file SGMCMCJax-0.2.11.tar.gz.

File metadata

  • Download URL: SGMCMCJax-0.2.11.tar.gz
  • Upload date:
  • Size: 17.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.8.2

File hashes

Hashes for SGMCMCJax-0.2.11.tar.gz

  • SHA256: 324fed707c2a96618e8ba02d5071bd21c9b7d962493bc5794e8391fe245eb394
  • MD5: 9ac1d6f0840ce19b739a087807810461
  • BLAKE2b-256: 33e355d3f2b1bba1d4ad21c4e401c8d1e955521ffc614037db3f5edc772b9a8e


File details

Details for the file SGMCMCJax-0.2.11-py3-none-any.whl.

File metadata

  • Download URL: SGMCMCJax-0.2.11-py3-none-any.whl
  • Upload date:
  • Size: 31.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.8.2

File hashes

Hashes for SGMCMCJax-0.2.11-py3-none-any.whl

  • SHA256: bd1e61e2e56ec5abe64fdf19f49945cd54eb1e741b97dbff0f407a28f2920669
  • MD5: 031a141e50f36006e978b23e70fe065d
  • BLAKE2b-256: b1ab4ae88fa8f2a9563d94b0203629406434bc81261fae17fe2197eccaaf75ff

