
Surjection layers for density estimation with normalizing flows


surjectors


About

Surjectors is a lightweight library for density estimation using inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality. Surjectors builds on Distrax and Haiku and is fully compatible with both of them.

Surjectors makes use of

  • Haiku's module system for neural networks,
  • Distrax for probability distributions and some base bijectors,
  • Optax for gradient-based optimization,
  • JAX for autodiff and XLA computation.

Examples

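For instance, the following builds a simple flow for 10-dimensional data: a Slice surjection keeps 5 of the coordinates and a learned decoder models the remaining 5, followed by an LU-decomposed linear layer on the 5-dimensional latent space. Because the layers are Haiku modules, the flow has to be constructed inside hk.transform:

import distrax
import haiku as hk
import jax
from jax import numpy as jnp
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp

def decoder_fn(n_dim):
    # Gaussian density of the n_dim dropped coordinates given the kept ones
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)), 1)
    return _fn

@hk.without_apply_rng
@hk.transform
def log_prob(y):
    # 5-dimensional standard normal base distribution
    base_distribution = distrax.Independent(distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1)
    # Slice keeps 5 of the 10 data coordinates; LULinear acts on the kept 5
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(y)

y = jnp.ones((8, 10))  # dummy batch of 8 ten-dimensional points
params = log_prob.init(jax.random.PRNGKey(0), y)
lps = log_prob.apply(params, y)  # one log-density per data point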

More self-contained examples can be found in the examples directory of the repository.
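To make concrete what a dimensionality-reducing slice layer such as Slice scores, here is a minimal pure-Python sketch that does not use the library. It assumes a standard normal base density on the kept coordinates and a Gaussian decoder for the dropped ones; the helper names (normal_logpdf, slice_log_prob) and the list-based decoder are hypothetical illustrations. For a slice surjection the density factorizes exactly as log p(y) = log p(z) + log p(y_dropped | z), where z is the kept part of y.

```python
import math

def normal_logpdf(x, mean=0.0, scale=1.0):
    # log N(x | mean, scale^2) for a scalar x
    return -0.5 * ((x - mean) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

def slice_log_prob(y, n_keep, decoder):
    # Hypothetical helper, not the surjectors API.
    # Data -> latent direction: keep the first n_keep coordinates, drop the rest.
    z, y_dropped = y[:n_keep], y[n_keep:]
    # Base log-density of the kept coordinates (assumed standard normal).
    log_pz = sum(normal_logpdf(zi) for zi in z)
    # Likelihood contribution: the decoder returns Gaussian means and scales
    # for the dropped coordinates, conditioned on the kept ones.
    means, scales = decoder(z)
    log_contrib = sum(
        normal_logpdf(yi, m, s) for yi, m, s in zip(y_dropped, means, scales)
    )
    return log_pz + log_contrib
```

In the library the decoder is a learned network (like decoder_fn above) and the computation runs on batched JAX arrays, but the bookkeeping is the same: the dropped coordinates contribute a conditional log-likelihood instead of a log-Jacobian term.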

Documentation

Documentation can be found here.

Installation

Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU or TPU, please follow the JAX installation instructions.

To install the package from PyPI, call:

pip install surjectors

To install a specific release from GitHub, replace <RELEASE> with the release tag and call the following on the command line:

pip install git+https://github.com/dirmeier/surjectors@<RELEASE>

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Clone Surjectors and install hatch via pip install hatch,
  2. create a new branch locally git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
  3. implement your contribution and ideally a test case,
  4. test it by calling hatch run test on the (Unix) command line,
  5. submit a PR 🙂

Author

Simon Dirmeier sfyrbnd @ pm me


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

surjectors-0.3.0.tar.gz (194.3 kB)

Built Distribution

surjectors-0.3.0-py3-none-any.whl (43.4 kB)

File details

Details for the file surjectors-0.3.0.tar.gz.

File metadata

  • Download URL: surjectors-0.3.0.tar.gz
  • Size: 194.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for surjectors-0.3.0.tar.gz
Algorithm Hash digest
SHA256 4a25163e5f0b09200187144d188704d7bc4658bdd7e730744fcfb9c74f9ad2cc
MD5 cb269083fea6a70dcba0603c871a3203
BLAKE2b-256 e9635ab798b462c2c20e4b698b76263c5cb2d7df9ab34b1735a95ce94cb3c5b3
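A downloaded file can be checked against the published SHA256 digest with Python's standard hashlib module. The sketch below streams the file in chunks so large archives need not fit in memory; the function names are hypothetical and the expected digest is the one listed above for the source distribution.

```python
import hashlib

# SHA256 digest published above for surjectors-0.3.0.tar.gz
SDIST_SHA256 = "4a25163e5f0b09200187144d188704d7bc4658bdd7e730744fcfb9c74f9ad2cc"

def sha256_of(path, chunk_size=1 << 16):
    # Hash the file incrementally, chunk_size bytes at a time
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_published_hash(path, expected_hex):
    # hexdigest() is lowercase, so normalize the expected value before comparing
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, matches_published_hash("surjectors-0.3.0.tar.gz", SDIST_SHA256) should return True for an intact download. Alternatively, pip can enforce digests itself in hash-checking mode via --hash=sha256:... entries in a requirements file together with --require-hashes.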


Details for the file surjectors-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: surjectors-0.3.0-py3-none-any.whl
  • Size: 43.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for surjectors-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 a6940c1974e3116fb189fe2c5292f2f009c36b16a2d61670c033e32de3c871fe
MD5 a39e1c83e01c932e98ebb18ae49c760c
BLAKE2b-256 91cc7fb0bec8d92f6b07f3ec4d92e87897386ee53fe28c142688630ff54de883
