Library for normalizing flows and neural flows

Stribor

A package for easily defining normalizing flows and neural flows in PyTorch.

  • Normalizing flows define complicated high-dimensional densities as transformations of random variables.
  • Neural flows define continuous time dynamics with invertible neural networks.
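
The first bullet rests on the change-of-variables formula: for an invertible map y = f(x) in one dimension, log p_Y(y) = log p_X(x) - log|f'(x)| with x = f^{-1}(y). Below is a minimal sketch of that formula in plain Python, independent of stribor. The function name and the scalar affine map are illustrative assumptions, not part of the package.

```python
import math
from statistics import NormalDist


def affine_flow_log_density(y, scale, shift):
    """Log-density of y = scale * x + shift, where x ~ N(0, 1)."""
    if scale == 0:
        raise ValueError("scale must be non-zero for the map to be invertible")
    x = (y - shift) / scale                   # invert the transformation
    log_base = math.log(NormalDist().pdf(x))  # base log-density at x
    log_abs_det = math.log(abs(scale))        # log |dy/dx|
    return log_base - log_abs_det


# The result should agree with the closed-form density N(shift, scale^2).
y, scale, shift = 0.7, 2.0, -1.0
flow_value = affine_flow_log_density(y, scale, shift)
closed_form = math.log(NormalDist(mu=shift, sigma=abs(scale)).pdf(y))
```

Here `flow_value` and `closed_form` coincide up to floating-point error, which is the property a normalizing flow relies on when the transformation is more complicated than an affine map.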

Install package and dependencies

pip install git+https://github.com/mbilos/stribor.git

Normalizing flows

Base densities

  • Normal st.Normal, st.UnitNormal, and st.MultivariateNormal
  • Uniform st.UnitUniform
  • Other distributions from torch.distributions
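
Since base densities can also come from torch.distributions, the sketch below builds a diagonal unit normal directly with torch.distributions and checks its log-density against the closed form. It uses only PyTorch. Whether st.UnitNormal reduces over the last dimension in the same way is not stated here, so the shapes shown are an assumption about this sketch only.

```python
import math

import torch
from torch.distributions import Independent, Normal

dim = 2

# Independent(..., 1) treats the last dimension as one event, so log_prob
# returns a single value per sample rather than one per coordinate.
base = Independent(Normal(torch.zeros(dim), torch.ones(dim)), 1)

x = torch.tensor([[0.0, 0.0], [1.0, -2.0]])
log_p = base.log_prob(x)  # shape: (2,)

# Closed form of the standard normal log-density in `dim` dimensions.
manual = -0.5 * (x ** 2).sum(-1) - 0.5 * dim * math.log(2 * math.pi)

samples = base.sample((5,))  # shape: (5, dim)
```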

Invertible transformations

  • Activation functions
    • ELU st.ELU
    • Leaky ReLU st.LeakyReLU
  • Affine
    • Element-wise transformation st.Affine
    • Fixed (non-learnable) element-wise transformation st.AffineFixed
    • Linear layer with PLU factorization st.AffinePLU
    • Matrix exponential st.MatrixExponential
  • Coupling layer that can be combined with any element-wise transformation st.Coupling
  • Continuous normalizing flows st.ContinuousNormalizingFlow
    • Differential equations with stochastic trace estimation:
      • st.net.DiffeqMLP
      • st.net.DiffeqDeepset
      • st.net.DiffeqSelfAttention
    • Differential equations with fixed zero trace:
      • st.net.DiffeqZeroTraceMLP
      • st.net.DiffeqZeroTraceDeepSet
      • st.net.DiffeqZeroTraceAttention
    • Differential equations with exact trace computation:
      • st.net.DiffeqExactTraceMLP
      • st.net.DiffeqExactTraceDeepSet
      • st.net.DiffeqExactTraceAttention
  • Cumulative sum st.Cumsum and difference st.Diff
    • Across a single column st.CumsumColumn and st.DiffColumn
  • Permutations
    • Flipping the indices st.Flip
    • Random permutation of indices st.Permute
  • Sigmoid st.Sigmoid and logit st.Logit function
  • Spline (quadratic or cubic) element-wise transformation st.Spline
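
As one illustration of the list above, a coupling layer keeps part of the input fixed and uses it to parameterize an element-wise transformation of the rest, which makes both inversion and the log-determinant cheap. The following plain-Python sketch shows that idea with an affine element-wise map. It does not use stribor, and `toy_conditioner` is a hypothetical stand-in for a learned network.

```python
import math


def toy_conditioner(x_fixed):
    """Hypothetical stand-in for a neural network: maps the fixed half
    to a (log_scale, shift) pair for the transformed half."""
    s = sum(x_fixed)
    return math.tanh(s), 0.5 * s


def coupling_forward(x):
    """Return (y, log_det) for an affine coupling layer."""
    half = len(x) // 2
    x_fixed, x_moved = x[:half], x[half:]
    log_scale, shift = toy_conditioner(x_fixed)
    y_moved = [xi * math.exp(log_scale) + shift for xi in x_moved]
    # The Jacobian is triangular, so its log-determinant is the sum of
    # the log-scales applied to the transformed coordinates.
    log_det = log_scale * len(x_moved)
    return x_fixed + y_moved, log_det


def coupling_inverse(y):
    """Return (x, log_det) for the inverse map."""
    half = len(y) // 2
    y_fixed, y_moved = y[:half], y[half:]
    # The fixed half is unchanged, so the same parameters can be recomputed.
    log_scale, shift = toy_conditioner(y_fixed)
    x_moved = [(yi - shift) * math.exp(-log_scale) for yi in y_moved]
    return y_fixed + x_moved, -log_scale * len(y_moved)


x = [0.3, -1.2, 0.8, 2.0]
y, log_det = coupling_forward(x)
x_rec, log_det_inv = coupling_inverse(y)
```

The conditioner itself never needs to be inverted, which is why it can be an arbitrary network.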

Example

To define a normalizing flow, define a base distribution and a series of transformations, e.g.:

import stribor as st
import torch

dim = 2
base_dist = st.UnitNormal(dim)

transforms = [
    st.Coupling(
        flow=st.Affine(dim, latent_net=st.net.MLP(dim, [64], dim)),
        mask='ordered_right_half'
    ),
    st.ContinuousNormalizingFlow(
        dim,
        net=st.net.DiffeqMLP(dim + 1, [64], dim)
    )
]

flow = st.Flow(base_dist, transforms)

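x = torch.rand(1, dim)
y, ljd = flow(x)                  # forward pass: transformed sample and its log-Jacobian term
y_inv, ljd_inv = flow.inverse(y)  # inverse pass: y_inv should approximately recover x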
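
The forward and inverse calls above each return a transformed value together with a second output. The sketch below uses toy scalar transformations in plain Python to show how log-Jacobian terms add up across a series of transformations and cancel between the forward and inverse directions. The class and function names are hypothetical, and reading `ljd` as a log-Jacobian term is an assumption, since it is not defined here; stribor may return it in a different shape.

```python
import math


class ToyAffine:
    """y = a * x + b with a != 0."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def forward(self, x):
        return self.a * x + self.b, math.log(abs(self.a))

    def inverse(self, y):
        return (y - self.b) / self.a, -math.log(abs(self.a))


class ToyExp:
    """y = exp(x), which maps the real line to positive numbers."""

    def forward(self, x):
        return math.exp(x), x  # log |dy/dx| = log(exp(x)) = x

    def inverse(self, y):
        return math.log(y), -math.log(y)  # log |dx/dy| = -log(y)


def flow_forward(transforms, x):
    """Apply transforms in order, accumulating log-Jacobian terms."""
    total = 0.0
    for t in transforms:
        x, ljd = t.forward(x)
        total += ljd
    return x, total


def flow_inverse(transforms, y):
    """Undo the transforms in reverse order."""
    total = 0.0
    for t in reversed(transforms):
        y, ljd = t.inverse(y)
        total += ljd
    return y, total


transforms = [ToyAffine(2.0, -1.0), ToyExp(), ToyAffine(0.5, 3.0)]
x = 0.4
y, ljd = flow_forward(transforms, x)
x_rec, ljd_inv = flow_inverse(transforms, y)
```

In this toy setting `x_rec` matches `x` and `ljd + ljd_inv` is zero up to floating-point error. A similar round-trip check is a reasonable sanity test for any invertible transformation, with a looser tolerance when the flow is solved numerically.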

Run tests

pytest --pyargs stribor
