Loopy belief propagation for factor graphs on discrete variables in JAX
PGMax
PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.
- General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
- LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system. A minimal usage sketch follows below.
See our companion paper for more details.
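To give a flavor of the API, here is a minimal sketch that builds a small grid of binary variables with pairwise factors and runs max-product LBP. It is adapted from the Ising model tutorial linked under Getting started; the module and class names used below (vgroup.NDVarArray, fgroup.PairwiseFactorGroup, infer.BP, etc.) reflect a recent PGMax release and may have drifted in other versions, so treat the notebooks as the authoritative reference.

import jax
import numpy as np
from pgmax import fgraph, fgroup, infer, vgroup

# A 10x10 grid of binary variables.
variables = vgroup.NDVarArray(num_states=2, shape=(10, 10))
fg = fgraph.FactorGraph(variable_groups=variables)

# Couple each variable to its right and bottom neighbors with pairwise factors.
variables_for_factors = []
for ii in range(10):
  for jj in range(10):
    if ii + 1 < 10:
      variables_for_factors.append([variables[ii, jj], variables[ii + 1, jj]])
    if jj + 1 < 10:
      variables_for_factors.append([variables[ii, jj], variables[ii, jj + 1]])

fg.add_factors(
    fgroup.PairwiseFactorGroup(
        variables_for_factors=variables_for_factors,
        log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
    )
)

# Run max-product LBP (temperature=0.0) and decode the MAP states.
bp = infer.BP(fg.bp_state, temperature=0.0)
evidence = jax.device_put(np.random.gumbel(size=(10, 10, 2)))
bp_arrays = bp.init(evidence_updates={variables: evidence})
bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
beliefs = bp.get_beliefs(bp_arrays)
map_states = infer.decode_map_states(beliefs)

Because the LBP loop is a pure JAX function, the same run can be wrapped in jax.vmap to process a batch of evidence arrays, or in jax.grad to differentiate a loss computed from the beliefs with respect to the evidence or log potentials, as done in the end-to-end training notebook listed below.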
PGMax is under active development. APIs may change without notice, so expect rough edges!
Installation | Getting started
Installation
Install from PyPI
pip install pgmax
Install latest version from GitHub
pip install git+https://github.com/deepmind/PGMax.git
Developer
While you can install PGMax in your standard Python environment, we strongly recommend using a Python virtual environment to manage your dependencies. This helps avoid version conflicts and generally makes the installation process easier.
git clone https://github.com/deepmind/PGMax.git
cd PGMax
python3 -m venv pgmax_env
source pgmax_env/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
python3 setup.py develop
Install on GPU
By default, the above commands install JAX for CPU. If you have access to a GPU, follow the official JAX installation instructions to install JAX for GPU.
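For example, at the time of writing a CUDA-enabled JAX build can typically be installed with the extra shown below; the supported CUDA versions change between JAX releases, so defer to the official instructions for the exact command on your system.
pip install --upgrade "jax[cuda12]"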
Getting Started
Here are a few self-contained Colab notebooks to help you get started with PGMax:
- Tutorial on basic PGMax usage
- LBP inference on Ising model
- Implementing max-product LBP for Recursive Cortical Networks
- End-to-end differentiable LBP for gradient-based PGM training
- 2D binary deconvolution
- Alternative inference with Smooth Dual LP-MAP
Citing PGMax
Please consider citing our companion paper
@article{zhou2022pgmax,
author = {Zhou, Guangyao and Dedieu, Antoine and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
journal = {arXiv preprint arXiv:2202.04110},
year = {2022}
}
and using the DeepMind JAX Ecosystem citation if you use PGMax in your work.
Note
This is not an officially supported Google product.
File details
Details for the file pgmax-0.6.1.tar.gz.
File metadata
- Download URL: pgmax-0.6.1.tar.gz
- Upload date:
- Size: 51.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 034b676d3a1073c9aaf0c5be0a27309e93df8e651c6b12fd2eb7d4b286e96fe6
MD5 | 30926e215fc242295ed5dedda770aadb
BLAKE2b-256 | c209249683576c7b775e48116bb6eb5fe72437ea5dcbc6ccf4cfdcab8be7bf7a
File details
Details for the file pgmax-0.6.1-py3-none-any.whl.
File metadata
- Download URL: pgmax-0.6.1-py3-none-any.whl
- Upload date:
- Size: 77.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | c5f781bdfb1ad861905b23ce3bfe40edf23a1afece5baff5e123292397b30bb0
MD5 | e5ea69372e9fe2a329c35e2e9d3804e0
BLAKE2b-256 | 61c0e8825f2fc90427a0ac39f496d414d2806782cd9ad4a47fe711f15c8b5663