Library for normalizing flows and neural flows
Stribor
Package to easily define normalizing flows and neural flows for PyTorch.
- Normalizing flows define complicated high-dimensional densities as transformations of random variables.
- Neural flows define continuous time dynamics with invertible neural networks.
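The first bullet rests on the change-of-variables rule: if y = f(x) with f invertible and x drawn from a base density, the log-density of y is the base log-density at the inverse of y plus the log absolute derivative of the inverse map. The following is a minimal one-dimensional sketch of that rule using only the Python standard library. It does not use Stribor, and all function names are hypothetical and chosen for illustration.

```python
import math

def unit_normal_log_prob(x):
    """Log-density of a standard normal evaluated at x."""
    return -0.5 * (x * x + math.log(2.0 * math.pi))

def affine_forward(x, scale, shift):
    """Map x -> y = scale * x + shift and return (y, log|dy/dx|)."""
    return scale * x + shift, math.log(abs(scale))

def affine_inverse(y, scale, shift):
    """Map y -> x = (y - shift) / scale and return (x, log|dx/dy|)."""
    return (y - shift) / scale, -math.log(abs(scale))

def flow_log_prob(y, scale, shift):
    """Log-density of y = scale * x + shift when x is standard normal."""
    x, log_det_inv = affine_inverse(y, scale, shift)
    # Change of variables: base log-density plus log|dx/dy|.
    return unit_normal_log_prob(x) + log_det_inv

scale, shift = 2.0, 1.0
log_p = flow_log_prob(3.0, scale, shift)
```

Here the result agrees with the closed-form log-density of a normal distribution with mean 1 and standard deviation 2, which is what an affine map of a standard normal produces.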
Install package and dependencies
pip install git+https://github.com/mbilos/stribor.git
Normalizing flows
Base densities
- Normal: st.Normal, st.UnitNormal and st.MultivariateNormal
- Uniform: st.UnitUniform
- Other distributions from torch.distributions
Invertible transformations
- Activation functions
  - ELU: st.ELU
  - Leaky ReLU: st.LeakyReLU
- Affine
  - Element-wise transformation: st.Affine
  - Fixed (non-learnable) element-wise transformation: st.AffineFixed
  - Linear layer with PLU factorization: st.AffinePLU
  - Matrix exponential: st.MatrixExponential
- Coupling layer that can be combined with any element-wise transformation: st.Coupling
- Continuous normalizing flows: st.ContinuousNormalizingFlow
  - Differential equations with stochastic trace estimation: st.net.DiffeqMLP, st.net.DiffeqDeepset, st.net.DiffeqSelfAttention
  - Differential equations with fixed zero trace: st.net.DiffeqZeroTraceMLP, st.net.DiffeqZeroTraceDeepSet, st.net.DiffeqZeroTraceAttention
  - Differential equations with exact trace computation: st.net.DiffeqExactTraceMLP, st.net.DiffeqExactTraceDeepSet, st.net.DiffeqExactTraceAttention
- Cumulative sum (st.Cumsum) and difference (st.Diff)
  - Across single column: st.CumsumColumn and st.DiffColumn
- Permutations
  - Flipping the indices: st.Flip
  - Random permutation of indices: st.Permute
- Sigmoid (st.Sigmoid) and logit (st.Logit) functions
- Spline (quadratic or cubic) element-wise transformation: st.Spline
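To illustrate why a coupling layer can wrap any element-wise transformation, here is a minimal sketch of an affine coupling in plain Python. One part of the input passes through unchanged and determines the parameters of an element-wise affine map applied to the rest, so the inverse can recompute the same parameters. This is not Stribor's implementation: the `conditioner` function is a hypothetical stand-in for a neural network, and which half is held fixed is an arbitrary choice here, not a statement about Stribor's mask options.

```python
import math

def conditioner(x_fixed):
    """Hypothetical stand-in for a neural network: returns (log_scale, shift)."""
    s = sum(x_fixed)
    return math.tanh(s), 0.5 * s

def coupling_forward(x):
    """Keep the first half of x unchanged, affinely transform the second half."""
    half = len(x) // 2
    x_fixed, x_moved = x[:half], x[half:]
    log_scale, shift = conditioner(x_fixed)
    y_moved = [xi * math.exp(log_scale) + shift for xi in x_moved]
    # The Jacobian is triangular, so its log-determinant is the sum of the
    # log-scales applied to the transformed coordinates.
    log_det = log_scale * len(x_moved)
    return x_fixed + y_moved, log_det

def coupling_inverse(y):
    """Invert coupling_forward by recomputing parameters from the fixed half."""
    half = len(y) // 2
    y_fixed, y_moved = y[:half], y[half:]
    log_scale, shift = conditioner(y_fixed)
    x_moved = [(yi - shift) * math.exp(-log_scale) for yi in y_moved]
    return y_fixed + x_moved, -log_scale * len(y_moved)

x = [0.3, -1.2, 0.7, 2.0]
y, log_det = coupling_forward(x)
x_rec, log_det_inv = coupling_inverse(y)
```

Because the fixed half is identical in the input and the output, the conditioner never needs to be inverted, which is why it can be an arbitrary function.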
Example
To define a normalizing flow, specify a base distribution and a sequence of transformations, e.g.:
    import stribor as st
    import torch

    dim = 2

    # Base density over `dim` dimensions
    base_dist = st.UnitNormal(dim)

    # Transformations that make up the flow
    transforms = [
        st.Coupling(
            flow=st.Affine(dim, latent_net=st.net.MLP(dim, [64], dim)),
            mask='ordered_right_half'
        ),
        st.ContinuousNormalizingFlow(
            dim,
            net=st.net.DiffeqMLP(dim + 1, [64], dim)
        )
    ]

    flow = st.Flow(base_dist, transforms)

    # Forward and inverse passes each return the transformed
    # tensor together with a log-Jacobian term
    x = torch.rand(1, dim)
    y, ljd = flow(x)
    y_inv, ljd_inv = flow.inverse(y)
Run tests
pytest --pyargs stribor