
Surjection layers for density estimation with normalizing flows


surjectors


About

Surjectors is a light-weight library for density estimation using inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality. Surjectors makes use of

  • Haiku's module system for neural networks,
  • Distrax for probability distributions and some base bijectors,
  • Optax for gradient-based optimization,
  • JAX for autodiff and XLA computation.
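To make "reduce dimensionality" concrete, here is a minimal sketch of the log-density of a slice surjection written directly in JAX. It assumes the first `n_keep` coordinates are kept as the latent `z`, a standard-normal base over `z`, and a hypothetical linear-Gaussian decoder for the dropped coordinates. It illustrates the idea only and is not how Surjectors implements `Slice`. The log-density splits into a base term plus a decoder term, log p(x) = log p(z) + log p(x_dropped | z).

```python
import jax.numpy as jnp
from jax import random as jr
from jax.scipy.stats import norm


def slice_log_prob(x, n_keep, decoder):
    """Log-density of x under a slice surjection with a standard-normal base.

    Assumption: the first n_keep coordinates are kept as the latent z and the
    remaining ones are scored by a Gaussian decoder p(x_dropped | z).
    """
    z, x_dropped = x[..., :n_keep], x[..., n_keep:]
    # log p(z) under the standard-normal base distribution
    lp_base = norm.logpdf(z).sum(-1)
    # log p(x_dropped | z) under the conditional Gaussian decoder
    mean, log_scale = decoder(z)
    lp_decoder = norm.logpdf(x_dropped, loc=mean, scale=jnp.exp(log_scale)).sum(-1)
    return lp_base + lp_decoder


# Hypothetical linear decoder: mean = z @ W, unit scale for all dropped dims
W = 0.1 * jr.normal(jr.PRNGKey(2), (5, 5))
decoder = lambda z: (z @ W, jnp.zeros(5))

x = jr.normal(jr.PRNGKey(1), (4, 10))  # four 10-dimensional observations
lp = slice_log_prob(x, 5, decoder)
print(lp.shape)  # (4,)
```

In a learned model the decoder would be a neural network, as with `decoder_fn` in the example below.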

Examples

You can, for instance, construct a simple normalizing flow like this:

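```python
import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp

# Conditional decoder for the dropped coordinates: an MLP that outputs the
# means and log-scales of a diagonal Gaussian over n_dim dimensions
def decoder_fn(n_dim):
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        # reinterpret only the last axis as the event so log_prob is per sample
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)), 1)
    return _fn

@hk.without_apply_rng
@hk.transform
def flow(x):
    # 5-dimensional standard-normal base distribution
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    # Slice keeps 5 of the 10 input dimensions and models the other 5 with
    # the decoder; LULinear is a 5-dimensional linear bijection
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)

# a batch containing one 10-dimensional observation
x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)
```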
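Linear layers in a flow stay cheap when the weight is LU-parameterized, because the log-determinant becomes a sum over a diagonal; presumably this is what `LULinear` refers to. The sketch below assumes the common Glow-style parameterization W = L U, with L unit lower-triangular and U upper-triangular with a positive diagonal, and omits the permutation and any bias. `lu_weight` and `lu_linear` are hypothetical helpers, not the library's implementation.

```python
import jax.numpy as jnp
from jax import random as jr


def lu_weight(params):
    """Build W = L @ U from unconstrained parameters."""
    lower, upper, log_diag = params
    n = log_diag.shape[0]
    # unit lower-triangular: strictly lower entries plus ones on the diagonal
    L = jnp.tril(lower, k=-1) + jnp.eye(n)
    # upper-triangular with a strictly positive diagonal exp(log_diag)
    U = jnp.triu(upper, k=1) + jnp.diag(jnp.exp(log_diag))
    return L @ U


def lu_linear(params, x):
    """Return (x @ W^T, log|det W|) for a batch of row vectors x."""
    y = x @ lu_weight(params).T
    # det(L) = 1 and det(U) = prod(exp(log_diag)), so no O(n^3) determinant is needed
    log_det = jnp.sum(params[2])
    return y, log_det


k1, k2, k3, k4 = jr.split(jr.PRNGKey(0), 4)
params = (
    0.3 * jr.normal(k1, (5, 5)),
    0.3 * jr.normal(k2, (5, 5)),
    0.1 * jr.normal(k3, (5,)),
)
x = jr.normal(k4, (3, 5))
y, log_det = lu_linear(params, x)
```

A fixed permutation P, as in Glow's W = P L U, only flips the sign of the determinant, so it would not change log|det W|.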

More self-contained examples can be found in examples.

Documentation

Documentation can be found here.

Installation

Make sure you have a working JAX installation. Depending on whether you want to use CPU, GPU, or TPU, please follow these instructions.

To install the package from PyPI, call:

pip install surjectors

To install a release directly from GitHub, call the following on the command line, replacing <RELEASE> with a release tag:

pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
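After installing, a quick sanity check is to ask JAX which backend it picked up and to look up the installed Surjectors version through the standard library. `describe_install` is an illustrative helper, not part of the package.

```python
import importlib.metadata

import jax


def describe_install():
    """Report the active JAX backend and the installed surjectors version (or None)."""
    try:
        version = importlib.metadata.version("surjectors")
    except importlib.metadata.PackageNotFoundError:
        version = None
    return {"jax_backend": jax.default_backend(), "surjectors": version}


print(describe_install())
```

If the backend reads `cpu` although you expected a GPU or TPU, revisit the JAX installation instructions for your accelerator.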

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Clone surjectors and install uv from here.

  2. Install all dependencies using uv sync --all-groups.

  3. Install pre-commit and gitlint via:

    pre-commit install
    gitlint install-hook
    
  4. Create a new branch locally, e.g., git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug.

  5. Implement your contribution and ideally a test case.

  6. Test it by calling make tests, make lints and make format on the (Unix) command line.

  7. Submit a PR 🙂.


Citing Surjectors

If you find our work relevant to your research, please consider citing:

@article{dirmeier2024surjectors,
    author = {Simon Dirmeier},
    title = {Surjectors: surjection layers for density estimation with normalizing flows},
    year = {2024},
    journal = {Journal of Open Source Software},
    publisher = {The Open Journal},
    volume = {9},
    number = {94},
    pages = {6188},
    doi = {10.21105/joss.06188}
}

Author

Simon Dirmeier simd23 @ pm me



Download files


Source Distribution

surjectors-0.3.4.tar.gz (318.0 kB)

Uploaded Source

Built Distribution


surjectors-0.3.4-py3-none-any.whl (50.9 kB)

Uploaded Python 3

File details

Details for the file surjectors-0.3.4.tar.gz.

File metadata

  • Download URL: surjectors-0.3.4.tar.gz
  • Upload date:
  • Size: 318.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for surjectors-0.3.4.tar.gz
Algorithm Hash digest
SHA256 bcf2d854f152e6a6040b23bbe8330244ec79db6889d9503488821b3793c00fe0
MD5 60747b6daf02300783e2790f87c89a36
BLAKE2b-256 9c875cb51f0a9a3da817af1f05e3966be536e79b7102d243ddafb025b1e4399c
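To check a downloaded file against the SHA256 digests listed here, you can hash it with Python's standard library and compare. This is a minimal sketch: `sha256_of` and `verify_sha256` are illustrative helper names, and the commented call assumes the source distribution was downloaded into the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.strip().lower()


# SHA256 digest listed for surjectors-0.3.4.tar.gz
EXPECTED = "bcf2d854f152e6a6040b23bbe8330244ec79db6889d9503488821b3793c00fe0"
# verify_sha256("surjectors-0.3.4.tar.gz", EXPECTED)
```

Reading in chunks keeps memory use constant regardless of the file size.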


File details

Details for the file surjectors-0.3.4-py3-none-any.whl.

File metadata

  • Download URL: surjectors-0.3.4-py3-none-any.whl
  • Upload date:
  • Size: 50.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for surjectors-0.3.4-py3-none-any.whl
Algorithm Hash digest
SHA256 0c5a1e1b5ff13b04f2a580345ec6ef7af9d20645b523c2f27d5d73abbf8f4580
MD5 e9007590bc504a939b294d7aca81839f
BLAKE2b-256 1b94d11f0519c4ae6c835041fbf70f01a25cbba7aba73941cca719fcb7847b5f

