
Pyro PPL on Numpy



NumPyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/CPU.



What is NumPyro?

NumPyro is a small probabilistic programming library built on JAX. It essentially provides a NumPy backend for Pyro, with some minor changes to the inference API and syntax. Since we use JAX, we get autograd and JIT compilation to GPU / CPU for free. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.

NumPyro is designed to be lightweight and focuses on providing a flexible substrate that users can build on:

  • Pyro Primitives: NumPyro programs can contain regular Python and NumPy code, in addition to Pyro primitives like sample and param. Model code should look very similar to Pyro code, apart from some minor differences between the PyTorch and NumPy APIs. See Examples.
  • Inference algorithms: NumPyro currently supports Hamiltonian Monte Carlo (HMC), including an implementation of the No-U-Turn Sampler (NUTS). One of the motivations for NumPyro was to speed up HMC by JIT compiling the Verlet integration step, which includes multiple gradient computations. With JAX, we can compose jit and grad to compile the entire integration step into an XLA-optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree-building stage in NUTS (this is possible using Iterative NUTS). There is also a basic Variational Inference implementation for reparameterized distributions.
  • Distributions: The numpyro.distributions module provides distribution classes, constraints and bijective transforms. The distribution classes wrap samplers implemented to work with JAX's functional pseudo-random number generator. The design of the distributions module largely follows PyTorch: a major subset of the API is implemented, covering most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in torch.distributions. Beyond the distributions themselves, constraints and transforms are very useful when working with distributions that have bounded support.
  • Effect handlers: As in Pyro, primitives like sample and param can be given side effects by effect handlers from the numpyro.handlers module, and these handlers can easily be extended to implement custom inference algorithms and inference utilities.
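The Verlet (leapfrog) integration step mentioned under Inference algorithms is what HMC repeats many times per iteration. Each step needs gradient evaluations, which is why compiling it pays off. The following is a minimal sketch in plain Python with a hand-written gradient for a standard normal target, U(q) = q²/2. It illustrates the integrator only and is not NumPyro's implementation. The names `leapfrog`, `grad_potential` and `hamiltonian` are hypothetical. In NumPyro the gradient would come from JAX's `grad`, and the step would be compiled with `jit`.

```python
def grad_potential(q):
    # Gradient of the potential U(q) = q**2 / 2 (standard normal target): dU/dq = q
    return q


def hamiltonian(q, p):
    # Potential energy plus kinetic energy (unit mass)
    return 0.5 * q * q + 0.5 * p * p


def leapfrog(q, p, step_size, num_steps, grad_u=grad_potential):
    """Leapfrog integration: half momentum step, then alternating full position
    and full momentum steps, finishing with a final half momentum step."""
    p = p - 0.5 * step_size * grad_u(q)
    for i in range(num_steps):
        q = q + step_size * p
        # Consecutive momentum half-steps are merged into one full step
        if i < num_steps - 1:
            p = p - step_size * grad_u(q)
    p = p - 0.5 * step_size * grad_u(q)
    return q, p
```

Leapfrog is used in HMC because it is time-reversible and nearly conserves the Hamiltonian. For example, `leapfrog(1.0, 0.5, 0.1, 20)` keeps `hamiltonian` close to its starting value of 0.625. Running the result forward again with negated momentum returns to the starting point.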
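Variational inference for reparameterized distributions relies on writing a sample as a deterministic function of the parameters and parameter-free noise. For example, z = μ + σ·ε with ε ~ Normal(0, 1). Gradients of a Monte Carlo objective can then flow through μ and σ. The following is a minimal sketch in plain Python. It assumes a Normal variational family and the toy objective E[z²], whose exact gradient with respect to μ is 2μ. The function `reparam_grad_mu` is hypothetical, and its derivative is written by hand where NumPyro would rely on JAX autodiff.

```python
import random


def reparam_grad_mu(mu, sigma, num_samples, seed=0):
    """Monte Carlo estimate of d/dmu E[z**2] using z = mu + sigma * eps."""
    rng = random.Random(seed)  # explicit generator rather than global random state
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)  # noise does not depend on mu or sigma
        z = mu + sigma * eps       # reparameterized sample
        total += 2.0 * z           # d(z**2)/dmu = 2 * z * dz/dmu = 2 * z
    return total / num_samples
```

With `mu=1.5` and a large `num_samples`, the estimate should land close to the exact value 3.0.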
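The constraints and bijective transforms described under Distributions let inference work in an unconstrained space. A positive quantity y can be written as y = exp(x) for real x, with the log absolute Jacobian determinant added to the log density (change of variables). The following plain-Python sketch shows that bookkeeping for an Exponential(rate) density. The `ExpTransform` class loosely mirrors the call / `inv` / `log_abs_det_jacobian` style of PyTorch-like transforms, but this version is a hypothetical simplification on Python floats, not the numpyro.distributions interface.

```python
import math


class ExpTransform:
    """Bijection from the real line to (0, inf): y = exp(x)."""

    def __call__(self, x):
        return math.exp(x)

    def inv(self, y):
        return math.log(y)

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = exp(x), so log|dy/dx| = x
        return x


def exponential_log_prob(y, rate=1.0):
    # Log density of Exponential(rate) on y > 0
    return math.log(rate) - rate * y


def unconstrained_log_prob(x, rate=1.0):
    """Log density of x = log(y) when y ~ Exponential(rate)."""
    t = ExpTransform()
    y = t(x)
    return exponential_log_prob(y, rate) + t.log_abs_det_jacobian(x, y)
```

Without the Jacobian term, the density over x would not integrate to 1. With it, numerically integrating `exp(unconstrained_log_prob(x))` over the real line gives approximately 1.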
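To make the effect-handler idea concrete, the following toy in plain Python shows how handlers can give side effects to a `sample` primitive. A stack of active handlers sees each sample site in turn, so one handler can supply randomness, another can fix values, and another can record what happened. The names `seed`, `substitute` and `trace` echo handlers found in numpyro.handlers, but everything here is a hypothetical simplification. It uses Python's `random.Random` instead of JAX PRNG keys and is not NumPyro's implementation.

```python
import random

# Active handlers; the last one pushed is the innermost.
HANDLER_STACK = []


class Messenger:
    """Base handler: active while its `with` block (or wrapped call) runs."""

    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)


class seed(Messenger):
    """Hands an explicit random generator to every sample site."""

    def __init__(self, fn, rng_seed):
        super().__init__(fn)
        self.rng = random.Random(rng_seed)

    def process_message(self, msg):
        msg["rng"] = self.rng


class substitute(Messenger):
    """Fixes named sites to given values instead of sampling them."""

    def __init__(self, fn, data):
        super().__init__(fn)
        self.data = data

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]


class trace(Messenger):
    """Records each sample site's message in a dict keyed by site name."""

    def __enter__(self):
        self.trace = {}
        return super().__enter__()

    def postprocess_message(self, msg):
        self.trace[msg["name"]] = dict(msg)

    def get_trace(self, *args, **kwargs):
        self(*args, **kwargs)
        return self.trace


def sample(name, sampler):
    """Toy primitive: `sampler(rng)` draws a value when no handler supplies one."""
    msg = {"name": name, "rng": None, "value": None}
    for handler in reversed(HANDLER_STACK):  # innermost handler first
        handler.process_message(msg)
    if msg["value"] is None:
        if msg["rng"] is None:
            raise RuntimeError(f"no seed handler active for site {name!r}")
        msg["value"] = sampler(msg["rng"])
    for handler in HANDLER_STACK:  # then post-process, outermost first
        handler.postprocess_message(msg)
    return msg["value"]


def model():
    mu = sample("mu", lambda rng: rng.gauss(0.0, 1.0))
    return sample("x", lambda rng: rng.gauss(mu, 0.1))
```

For example, `trace(seed(model, rng_seed=0)).get_trace()` returns a dict with the sites `"mu"` and `"x"`. Wrapping the seeded model in `substitute(..., {"mu": 5.0})` conditions the run on that value without changing the model code. Calling `model()` with no seed handler raises an error, because there is no global random state to fall back on.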

Installation

Limited Windows Support: Note that NumPyro is untested on Windows, and will require building jaxlib from source. See this JAX issue for more details.

To install NumPyro with a CPU version of JAX, you can use pip:

pip install numpyro

To use NumPyro on the GPU, you will need to first install jax and jaxlib with CUDA support.

You can also install NumPyro from source:

git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev]

Examples

For some examples of specifying models and running inference in NumPyro:

Users will note that the API for model specification is largely the same as Pyro's, including the distributions API, by design. The interface for inference algorithms and other utility functions might deviate from Pyro in favor of a more functional style that works better with JAX. For example, there is no global parameter store or random state.
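Because there is no global random state, randomness has to be passed explicitly. In JAX, a PRNG key is threaded through functions and split whenever independent streams are needed. The plain-Python sketch below imitates that calling convention with a hypothetical hash-based splittable key. It only illustrates the style and is not JAX's generator; real JAX code would use `jax.random.PRNGKey` and `jax.random.split` instead.

```python
import hashlib


def split(key, num=2):
    """Derive `num` new integer keys from `key` deterministically (toy, hash-based)."""
    return [
        int.from_bytes(hashlib.sha256(f"{key}:{i}".encode()).digest()[:8], "big")
        for i in range(num)
    ]


def uniform(key):
    """Map a key to a float in [0, 1); the same key always gives the same number."""
    digest = hashlib.sha256(f"uniform:{key}".encode()).digest()
    return (int.from_bytes(digest[:8], "big") >> 11) / 2**53  # 53 random bits


def noisy_sum(key, n):
    """Sum n uniform draws, splitting the key each time instead of mutating state."""
    total = 0.0
    for _ in range(n):
        key, subkey = split(key)
        total += uniform(subkey)
    return total
```

Calling `noisy_sum(0, 5)` twice gives identical results, and a different starting key gives a different result. This reproducibility without hidden state is what makes such functions safe to JIT compile and vectorize.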

Future Work

In the near term, we plan to work on the following. Please open new issues for feature requests and enhancements:

  • Improving robustness of inference on different models, profiling and performance tuning.
  • More inference algorithms, particularly those that require second-order derivatives or use HMC.
  • Integration with Funsor to support inference algorithms with delayed sampling.
  • Supporting more distributions, extending the distributions API, and adding more samplers to JAX.
  • Other areas motivated by Pyro's research goals and application focus, and interest from the community.



Download files


Source Distribution

numpyro-0.2.0.tar.gz (101.7 kB)

Built Distribution

numpyro-0.2.0-py3-none-any.whl (95.2 kB)

File details

Details for the file numpyro-0.2.0.tar.gz.

File metadata

  • Download URL: numpyro-0.2.0.tar.gz
  • Upload date:
  • Size: 101.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.7.2

File hashes

Hashes for numpyro-0.2.0.tar.gz
Algorithm Hash digest
SHA256 32b2dc6e0dc1c94a0b6590bfba51f605446162072ab73ca69e272dad1007aaf8
MD5 26e76fef203630d51dfb292ca94b3e29
BLAKE2b-256 5afa574b880cb719cd2b2310a6851ff334a2e0cbec1929b3871c16a37988e30f
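The SHA256 digests listed here can be used to check a download before installing it. The following is a minimal sketch using Python's standard `hashlib`. `EXPECTED` is copied from the SHA256 row for numpyro-0.2.0.tar.gz. The path passed to `verify` is an assumption about where the archive was saved.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Stream a file through SHA-256 and return its lowercase hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 digest of numpyro-0.2.0.tar.gz as listed on this page
EXPECTED = "32b2dc6e0dc1c94a0b6590bfba51f605446162072ab73ca69e272dad1007aaf8"


def verify(path, expected=EXPECTED):
    """True if the file at `path` matches the expected SHA-256 digest."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("numpyro-0.2.0.tar.gz")` should return True for an intact sdist. pip can also enforce this at install time in hash-checking mode, using a requirements line of the form `numpyro==0.2.0 --hash=sha256:<digest>` and listing both digests from this page.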


File details

Details for the file numpyro-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: numpyro-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 95.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.7.2

File hashes

Hashes for numpyro-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 4f7ed0d83cca73d67e2368132a5dbbb0f84094942f2468173dc1b6c3770a5769
MD5 93d947398501804cd7e21c5dcccd92fa
BLAKE2b-256 177c323ec59c52d6c49defcba944167843437976cbc7261cff058721e74fc77f

