
Stochastic Gradient Monte Carlo in Jax


Modular Stochastic Gradient MCMC for Jax



Introduction

JaxSGMC brings Stochastic Gradient Markov chain Monte Carlo (SGMCMC) samplers to JAX. Inspired by optax, JaxSGMC is built on a modular concept to increase reusability and accelerate research on new SGMCMC solvers. Additionally, JaxSGMC aims to promote probabilistic machine learning by removing obstacles to switching from stochastic optimizers to SGMCMC samplers.

Quickstart with solvers from alias.py

To get started quickly using SGMCMC samplers, JaxSGMC provides some popular pre-built samplers in alias.py:

Features

Modular SGMCMC solvers

JaxSGMC aims to increase reusability of SGMCMC components via a toolbox of helper functions and a modular concept:

In the simplest case of employing a pre-built sampler from alias.py, the user only needs to provide the computational model, consisting of functions for the prior and the likelihood. Schedulers allow sampler properties to change over the course of the training. Advanced users may build custom samplers from the provided components.
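The separation between computational model, scheduler and solver can be illustrated with a small, self-contained sketch in plain JAX. It does not use the jax_sgmc API: the helpers `make_potential`, `polynomial_schedule`, `sgld_step` and `run_chain` are hypothetical names chosen for this example, and the solver is a basic stochastic gradient Langevin dynamics (SGLD) update with a polynomially decaying step size, assumed here only for illustration.

```python
import jax
import jax.numpy as jnp


# --- Computational model (user-supplied) -----------------------------------
# Hypothetical toy model: infer the mean mu of unit-variance Gaussian data.
def log_prior(mu):
    # Broad Gaussian prior N(0, 10^2) on mu (up to a constant).
    return -0.5 * jnp.sum(mu ** 2) / 10.0 ** 2


def log_likelihood(mu, x):
    # Log-likelihood of a single observation x (unit variance, up to a constant).
    return -0.5 * jnp.sum((x - mu) ** 2)


# --- Reusable components (hypothetical names, not the jax_sgmc API) --------
def make_potential(log_prior, log_likelihood, n_total):
    # U(mu) = -log p(mu) - N/n * sum_i log p(x_i | mu), estimated on a minibatch.
    def potential(mu, batch):
        lik = jax.vmap(log_likelihood, in_axes=(None, 0))(mu, batch)
        return -log_prior(mu) - n_total / batch.shape[0] * jnp.sum(lik)
    return potential


def polynomial_schedule(a=1e-3, b=1.0, gamma=0.55):
    # Scheduler: step size eps_t = a * (b + t) ** (-gamma).
    return lambda t: a * (b + t) ** (-gamma)


def sgld_step(potential, schedule):
    # Solver: SGLD update mu <- mu - eps * grad U(mu) + sqrt(2 eps) * N(0, 1).
    grad_u = jax.grad(potential)

    def step(mu, t, key, batch):
        eps = schedule(t)
        noise = jax.random.normal(key, jnp.shape(mu))
        return mu - eps * grad_u(mu, batch) + jnp.sqrt(2.0 * eps) * noise

    return step


def run_chain(step, key, mu0, data, n_steps, batch_size):
    # Draw a random minibatch (with replacement) at every step and record mu.
    def body(mu, inputs):
        t, k = inputs
        k_batch, k_noise = jax.random.split(k)
        idx = jax.random.randint(k_batch, (batch_size,), 0, data.shape[0])
        mu = step(mu, t, k_noise, data[idx])
        return mu, mu

    keys = jax.random.split(key, n_steps)
    _, samples = jax.lax.scan(body, mu0, (jnp.arange(n_steps), keys))
    return samples


# Assemble the sampler from the components and run it on synthetic data.
data = 2.0 + jax.random.normal(jax.random.PRNGKey(1), (1000,))
potential = make_potential(log_prior, log_likelihood, n_total=data.shape[0])
step = sgld_step(potential, polynomial_schedule())
samples = run_chain(step, jax.random.PRNGKey(0), jnp.zeros(()), data,
                    n_steps=2000, batch_size=32)
```

Because the model, the schedule and the update rule are separate functions, a different step-size schedule or update rule can be swapped in without touching the rest, which is the kind of reuse the modular concept aims for.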

Data Input / Output under jit

JaxSGMC provides a toolbox to pass reference data to the computation and save collected samples from the Markov chain.

By combining different data loader and collector classes with general wrappers, data can be read from and samples saved to different formats via the mechanisms of JAX's Host-Callback module. Because the data stays in host memory and only the requested batches are transferred, it is also possible to access datasets larger than the device memory.

Saving Data:

  • HDF5
  • NumPy .npz

Loading Data:

  • HDF5
  • NumPy arrays
  • TensorFlow datasets
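The host-callback mechanism described above can be sketched in a few lines of plain JAX: the dataset lives in a NumPy array on the host, and a jitted function requests one minibatch at a time through a callback. This is an illustration only, not JaxSGMC's data loader; it uses `jax.pure_callback` from newer JAX releases rather than the older host_callback module, and `HOST_DATA`, `BATCH_SIZE`, `_load_batch` and `batch_mean` are hypothetical names.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Host-side dataset: kept in NumPy (host) memory, not transferred to the device.
HOST_DATA = np.arange(10_000, dtype=np.float32).reshape(5_000, 2)
BATCH_SIZE = 4


def _load_batch(step):
    # Runs on the host: pick the next BATCH_SIZE rows, wrapping around at the end.
    idx = (int(step) * BATCH_SIZE + np.arange(BATCH_SIZE)) % HOST_DATA.shape[0]
    return HOST_DATA[idx]


@jax.jit
def batch_mean(step):
    # Inside jit, request a minibatch from the host via a callback; the result
    # shape and dtype must be declared up front so JAX can trace the function.
    batch = jax.pure_callback(
        _load_batch,
        jax.ShapeDtypeStruct((BATCH_SIZE, 2), jnp.float32),
        step,
    )
    return jnp.mean(batch, axis=0)


print(batch_mean(0))  # column means of rows 0-3
```

Only `BATCH_SIZE` rows cross the host-device boundary per call, which is why this pattern scales to datasets that do not fit in device memory.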

Computing the stochastic potential

Stochastic Gradient MCMC requires evaluating a potential function on a batch of data. JaxSGMC can compute this potential from a likelihood that accepts only a single observation, batching it automatically with sequential, parallel or vectorized execution. Moreover, JaxSGMC supports passing a model state between evaluations of the likelihood function; this state is saved alongside the corresponding samples, which speeds up postprocessing.
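As a rough illustration of how a single-observation likelihood can be lifted to a minibatch potential, the sketch below uses the standard minibatch estimate U(theta) = -log p(theta) - N/n * sum_i log p(x_i | theta), where N is the dataset size and n the batch size. It is plain JAX, not JaxSGMC's implementation; `batch_log_likelihood` and `stochastic_potential` are hypothetical names, and only the sequential (`jax.lax.map`) and vectorized (`jax.vmap`) strategies are shown, since parallel execution across devices needs more setup.

```python
import jax
import jax.numpy as jnp


def log_prior(theta):
    # Standard normal prior (up to a constant).
    return -0.5 * jnp.sum(theta ** 2)


def log_likelihood(theta, x):
    # Likelihood of a SINGLE observation x: x ~ N(theta, 1), up to a constant.
    return -0.5 * jnp.sum((x - theta) ** 2)


def batch_log_likelihood(theta, batch, strategy="vmap"):
    # Lift the single-observation likelihood to a whole minibatch.
    single = lambda x: log_likelihood(theta, x)
    if strategy == "vmap":
        return jax.vmap(single)(batch)      # vectorized execution
    if strategy == "map":
        return jax.lax.map(single, batch)   # sequential execution
    raise ValueError(f"unknown strategy: {strategy}")


def stochastic_potential(theta, batch, n_total, strategy="vmap"):
    # U(theta) = -log p(theta) - N/n * sum_{i in batch} log p(x_i | theta)
    lik = batch_log_likelihood(theta, batch, strategy)
    return -log_prior(theta) - n_total / batch.shape[0] * jnp.sum(lik)


theta = jnp.array(0.5)
batch = jnp.array([0.0, 1.0])
u_vmap = stochastic_potential(theta, batch, n_total=10)
u_map = stochastic_potential(theta, batch, n_total=10, strategy="map")
grad_u = jax.grad(stochastic_potential)(theta, batch, 10)
```

Both strategies produce the same value; they differ only in memory use and speed, which is why the choice of execution mode is worth exposing to the user.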

Installation

Basic Setup

JaxSGMC can be installed via pip:

pip install jax-sgmc --upgrade

The above command installs the CPU version of JAX. To run JaxSGMC on a GPU, the GPU version of JAX has to be installed. Further information can be found in the JAX Installation Instructions.
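To check which backend JAX actually picked up after installation (for example, whether the GPU build found a device), one can query JAX directly. This is a generic JAX check, not part of JaxSGMC.

```python
import jax

# Platform of the default backend: "cpu" for the CPU-only install,
# "gpu" when a GPU build of JAX finds a usable device.
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)
```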

Additional Packages

Some parts of JaxSGMC require additional packages:

  • Data loading with TensorFlow:
    pip install jax-sgmc[tensorflow] --upgrade
    
  • Saving samples in the HDF5 format:
    pip install jax-sgmc[hdf5] --upgrade
    

Installation from Source

For development purposes, JaxSGMC can be installed from source in editable mode:

git clone git@github.com:tummfm/jax-sgmc.git
cd jax-sgmc
pip install -e .[test,docs]

This command additionally installs the requirements to run the tests:

pytest tests

and to build the documentation (e.g., as HTML):

make -C docs html

Contributing

Contributions are always welcome! Please open a pull request to discuss the code additions.

Citation

If you use JaxSGMC in your own work, please consider citing:

@article{jaxsgmc2024,
title = {JaxSGMC: Modular stochastic gradient MCMC in JAX},
journal = {SoftwareX},
volume = {26},
pages = {101722},
year = {2024},
issn = {2352-7110},
doi = {https://doi.org/10.1016/j.softx.2024.101722},
url = {https://www.sciencedirect.com/science/article/pii/S2352711024000931},
author = {Stephan Thaler and Paul Fuchs and Ana Cukarska and Julija Zavadlav},
}



Download files

Download the file for your platform.

Source Distribution

jax_sgmc-0.1.5.tar.gz (70.5 kB)

Built Distribution


jax_sgmc-0.1.5-py3-none-any.whl (69.0 kB)

File details

Details for the file jax_sgmc-0.1.5.tar.gz.

File metadata

  • Download URL: jax_sgmc-0.1.5.tar.gz
  • Size: 70.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for jax_sgmc-0.1.5.tar.gz
Algorithm Hash digest
SHA256 7084b222eee20ba4d5fbbd64783fc807da6b31bde276762978c1960ebf38d965
MD5 9eaaf910eac80c0aceb53a870d459a76
BLAKE2b-256 e563673f9ed1111cbd9bd08a307b5c43f85abbae46ac6a94a37a3d0b48a82373


File details

Details for the file jax_sgmc-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: jax_sgmc-0.1.5-py3-none-any.whl
  • Size: 69.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for jax_sgmc-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 8beb1ba1ad63b7890b1e73231f4a2ddc88ec636f4add3ef23299b4211e675a95
MD5 e928a5f01e3ada11ca175749586f8582
BLAKE2b-256 f7f758f2fcfbf90f203853adca0c5082d6494259fb7a45a0c19334352cdbbcc1
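To check a downloaded release file against the SHA256 digests listed above, a small helper using Python's standard `hashlib` is enough. The names `EXPECTED_SHA256`, `sha256_of` and `verify_download` are hypothetical; the digests are the ones published for release 0.1.5.

```python
import hashlib
import os

# SHA256 digests as listed for release 0.1.5.
EXPECTED_SHA256 = {
    "jax_sgmc-0.1.5.tar.gz":
        "7084b222eee20ba4d5fbbd64783fc807da6b31bde276762978c1960ebf38d965",
    "jax_sgmc-0.1.5-py3-none-any.whl":
        "8beb1ba1ad63b7890b1e73231f4a2ddc88ec636f4add3ef23299b4211e675a95",
}


def sha256_of(path, chunk_size=1 << 16):
    # Hash the file in chunks so large archives need not fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    # Compare the file's digest with the published one for its file name.
    expected = EXPECTED_SHA256[os.path.basename(path)]
    return sha256_of(path) == expected
```

Alternatively, pip can enforce the digest itself in hash-checking mode: list `jax-sgmc==0.1.5 --hash=sha256:<digest>` in a requirements file and install with `pip install --require-hashes -r requirements.txt`.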

