Surjection layers for density estimation with normalizing flows
surjectors
About
Surjectors is a lightweight library for density estimation using inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality. Surjectors builds on Distrax and Haiku and is fully compatible with both of them.
Surjectors makes use of
- Haiku's module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.
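The slicing idea behind dimensionality-reducing flows can be illustrated without the library. The following is a minimal NumPy sketch, not surjectors' `Slice` implementation: it assumes a standard normal base and a fixed, hypothetical Gaussian decoder (`toy_decoder`), and all function names are illustrative. In the inference direction a slice keeps the first `n_keep` coordinates as the latent `z` and scores the dropped coordinates under a conditional density, so by the chain rule log p(y) = log p_Z(z) + log p(y_dropped | z).

```python
import numpy as np


def normal_log_prob(x, mean, scale):
    """Elementwise log density of a Gaussian N(mean, scale**2)."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(scale) - 0.5 * ((x - mean) / scale) ** 2


def slice_inverse(y, n_keep, decoder):
    """Inference direction of a slicing surjection (illustrative sketch).

    Keeps the first n_keep coordinates of y as the latent z and returns the
    log density of the dropped coordinates under the decoder p(y_dropped | z).
    """
    z, y_dropped = y[..., :n_keep], y[..., n_keep:]
    mean, scale = decoder(z)
    log_contribution = normal_log_prob(y_dropped, mean, scale).sum(axis=-1)
    return z, log_contribution


def toy_decoder(z):
    # Hypothetical decoder: dropped coordinates ~ N(z, 1). Only valid when as
    # many coordinates are dropped as are kept.
    return z, np.ones_like(z)


def sliced_log_density(y, n_keep):
    """log p(y) = log p_Z(z) + log p(y_dropped | z) with a standard normal base."""
    z, log_contribution = slice_inverse(y, n_keep, toy_decoder)
    base_log_prob = normal_log_prob(z, 0.0, 1.0).sum(axis=-1)
    return base_log_prob + log_contribution


y = np.array([[1.0, 0.0, 1.0, 0.0], [0.5, -0.5, 0.0, 0.0]])
# one log density per row; the latent z is 2-dimensional
print(sliced_log_density(y, n_keep=2))
```

In a learned flow the decoder would be a neural network (as `decoder_fn` is in the example below), and the kept coordinates would usually pass through further bijections before reaching the base distribution.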
Examples
You can, for instance, construct a simple normalizing flow like this:
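```python
import distrax
from jax import numpy as jnp
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp

# make_mlp creates Haiku modules, so this code is meant to run inside hk.transform
def decoder_fn(n_dim):
    # Gaussian over the n_dim sliced-off dimensions, conditioned on the kept ones
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
    return _fn

# 5-dimensional standard normal base distribution with event shape (5,)
base_distribution = distrax.Independent(
    distrax.Normal(jnp.zeros(5), jnp.ones(5)), reinterpreted_batch_ndims=1
)
# 10-dimensional data: keep 5 dimensions, model the other 5 with the decoder,
# then apply a 5-dimensional LU-decomposed linear bijection
transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
pushforward = TransformedDistribution(base_distribution, transform)
```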
More self-contained examples can be found in the examples folder of the repository.
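The example above uses `LULinear`. A common reason to parameterise a linear bijection through an LU decomposition is that its log-determinant becomes cheap. The NumPy sketch below illustrates that general idea only; it is not surjectors' `LULinear` code, and the parameter names (`lower_raw`, `upper_raw`, `log_diag`) and the positive-diagonal parameterisation are hypothetical choices made for illustration.

```python
import numpy as np


def lu_weight(lower_raw, upper_raw, log_diag):
    """Build W = L @ U from unconstrained (hypothetical) parameters.

    L is unit lower-triangular and U is upper-triangular with diagonal
    exp(log_diag) > 0, so W is always invertible.
    """
    d = log_diag.shape[0]
    lower = np.tril(lower_raw, k=-1) + np.eye(d)
    upper = np.triu(upper_raw, k=1) + np.diag(np.exp(log_diag))
    return lower @ upper


def lu_linear_forward(x, lower_raw, upper_raw, log_diag):
    """Apply y = x @ W.T and return log|det W| = sum(log_diag) in O(d)."""
    weight = lu_weight(lower_raw, upper_raw, log_diag)
    return x @ weight.T, np.sum(log_diag)


rng = np.random.default_rng(0)
d = 5
lower_raw, upper_raw = rng.normal(size=(d, d)), rng.normal(size=(d, d))
log_diag = rng.normal(size=d)
x = rng.normal(size=(3, d))
y, log_det = lu_linear_forward(x, lower_raw, upper_raw, log_diag)
# the cheap log-determinant agrees with a dense O(d^3) computation
sign, dense_log_det = np.linalg.slogdet(lu_weight(lower_raw, upper_raw, log_diag))
print(sign, np.isclose(log_det, dense_log_det))
```

Because det(L) = 1 and U is triangular, the determinant reduces to the product of U's diagonal, which is why the log-determinant is just a sum here.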
Documentation
Documentation can be found here.
Installation
Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU, or TPU, please follow the official JAX installation instructions.
To install the package from PyPI, call:
```shell
pip install surjectors
```
To install the latest release from GitHub, call the following on the command line, replacing `<RELEASE>` with the release tag:
```shell
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
```
Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".
In order to contribute:
- Clone Surjectors and install `hatch` via `pip install hatch`,
- create a new branch locally via `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
- implement your contribution and ideally a test case,
- test it by calling `hatch run test` on the (Unix) command line,
- submit a PR 🙂
Author
Simon Dirmeier sfyrbnd @ pm me
Download files
Source Distribution: surjectors-0.3.0.tar.gz
Built Distribution: surjectors-0.3.0-py3-none-any.whl
File details
Details for the file surjectors-0.3.0.tar.gz.
File metadata
- Download URL: surjectors-0.3.0.tar.gz
- Upload date:
- Size: 194.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4a25163e5f0b09200187144d188704d7bc4658bdd7e730744fcfb9c74f9ad2cc
MD5 | cb269083fea6a70dcba0603c871a3203
BLAKE2b-256 | e9635ab798b462c2c20e4b698b76263c5cb2d7df9ab34b1735a95ce94cb3c5b3
File details
Details for the file surjectors-0.3.0-py3-none-any.whl.
File metadata
- Download URL: surjectors-0.3.0-py3-none-any.whl
- Upload date:
- Size: 43.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | a6940c1974e3116fb189fe2c5292f2f009c36b16a2d61670c033e32de3c871fe
MD5 | a39e1c83e01c932e98ebb18ae49c760c
BLAKE2b-256 | 91cc7fb0bec8d92f6b07f3ec4d92e87897386ee53fe28c142688630ff54de883
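To check a downloaded file against the digests listed in the file hash tables, a small standard-library sketch like the following could be used. The helper names `file_digests` and `verify_sha256` are hypothetical; BLAKE2b-256 is computed as BLAKE2b with a 32-byte digest.

```python
import hashlib


def file_digests(path, chunk_size=1 << 16):
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file in one pass."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        # BLAKE2b-256 is BLAKE2b with a 32-byte (256-bit) digest
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),
    }
    with open(path, "rb") as handle:
        # read in chunks so large files do not have to fit into memory
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            for hasher in hashers.values():
                hasher.update(chunk)
    return {name: hasher.hexdigest() for name, hasher in hashers.items()}


def verify_sha256(path, expected):
    """Return True if the file's SHA256 digest matches the expected hex string."""
    return file_digests(path)["SHA256"] == expected.strip().lower()


# Usage, assuming the source distribution was downloaded to the working directory:
# verify_sha256("surjectors-0.3.0.tar.gz",
#               "4a25163e5f0b09200187144d188704d7bc4658bdd7e730744fcfb9c74f9ad2cc")
```

pip can also enforce this itself in hash-checking mode, by listing `--hash=sha256:...` entries in a requirements file and installing with `--require-hashes`.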