Surjection layers for density estimation with normalizing flows

surjectors

About

Surjectors is a lightweight library for density estimation using inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality. Surjectors makes use of

  • Haiku's module system for neural networks,
  • Distrax for probability distributions and some base bijectors,
  • Optax for gradient-based optimization,
  • JAX for autodiff and XLA computation.

Examples

You can, for instance, construct a simple normalizing flow like this:

import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp

def decoder_fn(n_dim):
    # Conditional Gaussian over the n_dim dimensions that Slice drops,
    # parameterized by an MLP applied to the dimensions that are kept.
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
    return _fn

@hk.without_apply_rng
@hk.transform
def flow(x):
    # The base distribution lives in the reduced, 5-dimensional space.
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    # Slice keeps 5 of the 10 input dimensions; LULinear then acts on those 5.
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)

# A single 10-dimensional observation.
x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)  # log-density of x under the flow
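
To make the role of a dimensionality-reducing layer more concrete, the following library-free sketch computes the log-density of a slicing surjection by hand: the kept dimensions z are scored under the base distribution, and the dropped dimensions are scored under a conditional distribution given z, so that log p(x) = log p_base(z) + log p(x_dropped | z). It is only an illustration and not the code Surjectors runs. The names normal_logpdf, toy_decoder and slice_log_prob are hypothetical, the base distribution is assumed to be a standard normal, and the fixed toy_decoder stands in for a learned network and assumes that as many dimensions are dropped as are kept.

```python
import math


def normal_logpdf(x, mean=0.0, scale=1.0):
    """Log-density of a univariate Gaussian."""
    z = (x - mean) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def toy_decoder(z):
    """Hypothetical stand-in for a learned decoder.

    Returns one (mean, scale) pair per dropped dimension; here it
    assumes that as many dimensions are dropped as are kept.
    """
    return [(0.5 * zi, 1.0) for zi in z]


def slice_log_prob(x, n_keep):
    """log p(x) = log p_base(z) + log p(x_dropped | z), with z = x[:n_keep]."""
    z, dropped = x[:n_keep], x[n_keep:]
    # Kept dimensions are scored under the (assumed) standard normal base.
    base = sum(normal_logpdf(zi) for zi in z)
    # Dropped dimensions are scored under the conditional decoder.
    cond = sum(
        normal_logpdf(xi, mean, scale)
        for xi, (mean, scale) in zip(dropped, toy_decoder(z))
    )
    return base + cond


lp_toy = slice_log_prob([1.0, 2.0, 0.5, 1.0], n_keep=2)
```

Because no dimensions are transformed, only dropped, no Jacobian term appears in this sketch; the conditional density of the dropped dimensions takes its place.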

More self-contained examples can be found in examples.

Documentation

Documentation can be found here.

Installation

Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU, or TPU, please follow these instructions.

To install the package from PyPI, call:

pip install surjectors

To install the latest GitHub release, call the following on the command line, replacing <RELEASE> with the desired release tag:

pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
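
After either install command, one way to confirm that the package is visible to the current interpreter is to query its metadata. This is a minimal sketch using only the Python standard library; the helper name installed_version is hypothetical and not part of Surjectors.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the version string if Surjectors is installed, otherwise None.
print(installed_version("surjectors"))
```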

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled good first issue.

In order to contribute:

  1. clone Surjectors and install hatch via pip install hatch,
  2. create a new branch locally via git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
  3. implement your contribution and ideally a test case,
  4. test it by calling hatch run test on the (Unix) command line,
  5. submit a PR 🙂

Citing Surjectors

If you find our work relevant to your research, please consider citing:

@article{dirmeier2024surjectors,
    author = {Simon Dirmeier},
    title = {Surjectors: surjection layers for density estimation with normalizing flows},
    year = {2024},
    journal = {Journal of Open Source Software},
    publisher = {The Open Journal},
    volume = {9},
    number = {94},
    pages = {6188},
    doi = {10.21105/joss.06188}
}

Author

Simon Dirmeier sfyrbnd @ pm me
