Library for normalizing flows and neural flows
Stribor
Normalizing flows and neural flows for PyTorch.
- A normalizing flow defines a complicated probability density as an invertible transformation of a random variable drawn from a simple base density.
- A neural flow defines continuous-time dynamics with invertible neural networks.
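The first point rests on the standard change-of-variables formula. Assuming x is drawn from the base density p_base and y = f(x) for an invertible, differentiable f, the log-density of y is given below, together with its extension to a composition f = f_K ∘ ... ∘ f_1, where the log-determinants of the individual inverse maps simply add up:

$$
\log p(y) = \log p_{\text{base}}\big(f^{-1}(y)\big) + \log\left|\det \frac{\partial f^{-1}(y)}{\partial y}\right|
$$

$$
\log p(y) = \log p_{\text{base}}(x_0) + \sum_{k=1}^{K} \log\left|\det \frac{\partial f_k^{-1}(x_k)}{\partial x_k}\right|, \qquad x_K = y,\quad x_{k-1} = f_k^{-1}(x_k)
$$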
Install the package and its dependencies:
```
pip install stribor
```
Normalizing flows
Base densities
- Normal: st.Normal, st.UnitNormal and st.MultivariateNormal
- Uniform: st.UnitUniform
- Or, use distributions from torch.distributions
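For intuition, the two simplest base densities have closed-form log-densities. The plain-Python sketch below computes them for a single point. The function names are hypothetical, and the sketch assumes the uniform density lives on the closed unit cube; st.UnitNormal and st.UnitUniform themselves are PyTorch modules that work on batched tensors.

```python
import math
from typing import Sequence


def unit_normal_log_prob(x: Sequence[float]) -> float:
    """Log-density of the d-dimensional standard normal N(0, I) at x."""
    d = len(x)
    return -0.5 * sum(v * v for v in x) - 0.5 * d * math.log(2 * math.pi)


def unit_uniform_log_prob(x: Sequence[float]) -> float:
    """Log-density of the uniform distribution on [0, 1]^d: 0 inside, -inf outside."""
    return 0.0 if all(0.0 <= v <= 1.0 for v in x) else -math.inf


print(unit_normal_log_prob([0.0, 0.0]))   # equals -log(2 * pi)
print(unit_uniform_log_prob([0.2, 0.7]))  # 0.0
print(unit_uniform_log_prob([1.5, 0.7]))  # -inf
```

Because the standard normal factorizes over dimensions, its log-density is just a sum of one-dimensional terms, which is what makes it a convenient base density.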
Invertible transformations
- Activation functions
  - ELU: st.ELU
  - Leaky ReLU: st.LeakyReLU
  - Sigmoid: st.Sigmoid
  - Logit (inverse sigmoid): st.Logit
- Affine
  - Element-wise transformation: st.Affine
  - Linear layer with LU factorization: st.AffineLU
  - Matrix exponential: st.MatrixExponential
- Coupling layer that can be combined with any element-wise transformation: st.Coupling
- Continuous normalizing flows: st.ContinuousTransform
  - Differential equations with stochastic trace estimation: st.net.DiffeqMLP, st.net.DiffeqDeepset, st.net.DiffeqSelfAttention
  - Differential equations with fixed zero trace: st.net.DiffeqZeroTraceMLP, st.net.DiffeqZeroTraceDeepSet, st.net.DiffeqZeroTraceAttention
  - Differential equations with exact trace computation: st.net.DiffeqExactTraceMLP, st.net.DiffeqExactTraceDeepSet, st.net.DiffeqExactTraceAttention
- Cumulative sum st.Cumsum and difference st.Diff
  - Across a single column: st.CumsumColumn and st.DiffColumn
- Permutations
  - Flipping the indices: st.Flip
  - Random permutation of indices: st.Permute
- Spline (quadratic or cubic) element-wise transformation: st.Spline
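The coupling idea behind st.Coupling can be illustrated without PyTorch. One part of the input passes through unchanged and parameterizes an element-wise affine map of the rest, so both the inverse and the log-determinant are cheap to compute. This is a minimal sketch under assumptions: the fixed split point and toy_conditioner (a placeholder for a network such as st.net.MLP) are hypothetical. st.Coupling itself operates on tensors, uses masks such as 'ordered_right_half', and accepts any element-wise transformation, not only the affine one shown here.

```python
import math
from typing import List, Tuple


def toy_conditioner(x_a: List[float], n_out: int) -> Tuple[List[float], List[float]]:
    """Hypothetical stand-in for a neural network.

    Returns a (log_scale, shift) pair for each of the n_out transformed entries.
    """
    s = sum(x_a)
    return [0.5 * math.tanh(s)] * n_out, [s * s] * n_out


def coupling_forward(x: List[float], split: int) -> Tuple[List[float], float]:
    """Affine coupling: keep x[:split], transform x[split:] element-wise."""
    x_a, x_b = x[:split], x[split:]
    log_scale, shift = toy_conditioner(x_a, len(x_b))
    y_b = [v * math.exp(ls) + sh for v, ls, sh in zip(x_b, log_scale, shift)]
    # The Jacobian is block triangular with an identity block for x_a,
    # so log|det J| is simply the sum of the log-scales.
    return x_a + y_b, sum(log_scale)


def coupling_inverse(y: List[float], split: int) -> Tuple[List[float], float]:
    """Inverse map: y[:split] equals x[:split], so the conditioner sees the same input."""
    y_a, y_b = y[:split], y[split:]
    log_scale, shift = toy_conditioner(y_a, len(y_b))
    x_b = [(v - sh) * math.exp(-ls) for v, ls, sh in zip(y_b, log_scale, shift)]
    return y_a + x_b, -sum(log_scale)


x = [0.3, -1.2]
y, log_det = coupling_forward(x, split=1)
x_rec, inv_log_det = coupling_inverse(y, split=1)
print(y, log_det)
print(x_rec)  # matches x up to floating-point rounding
```

The conditioner never has to be inverted, which is why it can be an arbitrary network.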
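The continuous-flow variants differ in how they handle the trace of the Jacobian of the dynamics, since in a continuous normalizing flow the log-density changes at the rate of minus that trace. Stochastic trace estimation avoids forming the full Jacobian. A framework-free sketch of Hutchinson's estimator, tr(A) ≈ mean of vᵀAv over random ±1 vectors v, is below. The explicit matrix is a hypothetical stand-in for a Jacobian that a real implementation would only access through vector-Jacobian products, and this is not necessarily the exact estimator stribor uses.

```python
import random
from typing import List

Matrix = List[List[float]]


def exact_trace(a: Matrix) -> float:
    """Sum of the diagonal entries."""
    return sum(a[i][i] for i in range(len(a)))


def hutchinson_trace(a: Matrix, num_samples: int, rng: random.Random) -> float:
    """Unbiased estimate of tr(A): average of v^T A v over Rademacher vectors v."""
    n = len(a)
    total = 0.0
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        av = [sum(a[i][j] * v[j] for j in range(n)) for i in range(n)]  # A v
        total += sum(v[i] * av[i] for i in range(n))                   # v^T (A v)
    return total / num_samples


jacobian = [
    [2.0, 0.5, 0.0],
    [0.1, -1.0, 0.3],
    [0.0, 0.2, 0.5],
]
print(exact_trace(jacobian))                                   # 1.5
print(hutchinson_trace(jacobian, 10_000, random.Random(0)))    # close to 1.5
```

The estimate is unbiased but noisy; only the off-diagonal entries contribute to its variance, so for a diagonal matrix a single sample is already exact.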
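Cumulative sum and difference are exact inverses of each other. Because the Jacobian of a cumulative sum is lower triangular with ones on the diagonal, its log-determinant is zero, so the map preserves volume. A plain-Python sketch on a single vector follows; the function names are illustrative and not stribor's API.

```python
from itertools import accumulate
from typing import List


def cumsum(x: List[float]) -> List[float]:
    """Running sums: y[i] = x[0] + ... + x[i]."""
    return list(accumulate(x))


def diff(y: List[float]) -> List[float]:
    """Inverse of cumsum: first element, then successive differences."""
    if not y:
        return []
    return [y[0]] + [b - a for a, b in zip(y, y[1:])]


x = [1.0, 2.0, -0.5, 4.0]
y = cumsum(x)
print(y)        # [1.0, 3.0, 2.5, 6.5]
print(diff(y))  # [1.0, 2.0, -0.5, 4.0]
```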
Example: Normalizing flow
To define a normalizing flow, define a base distribution and a series of transformations, e.g.:
```python
import stribor as st
import torch

dim = 2
base_dist = st.UnitNormal(dim)

transforms = [
    st.Coupling(
        transform=st.Affine(dim, latent_net=st.net.MLP(dim, [64], 2 * dim)),
        mask='ordered_right_half',
    ),
    st.ContinuousTransform(
        dim,
        net=st.net.DiffeqMLP(dim + 1, [64], dim),
    ),
]

flow = st.NormalizingFlow(base_dist, transforms)

x = torch.rand(10, dim)
y = flow(x)  # Forward transformation
log_prob = flow.log_prob(y)  # Log-probability p(y)
```
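The example above delegates the bookkeeping to st.NormalizingFlow. The plain-Python sketch below shows the computation a flow of this kind performs, under simplifying assumptions: scalar data, a standard normal base density, and a hypothetical Affine1D transform. Samples are pushed through the transforms in order, while the log-probability inverts them in reverse order and adds the log-determinants of the inverse maps to the base log-density. It mirrors the structure only and is not stribor's implementation.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class Affine1D:
    """Hypothetical scalar transform y = scale * x + shift."""
    scale: float
    shift: float

    def forward(self, x: float) -> float:
        return self.scale * x + self.shift

    def inverse(self, y: float) -> float:
        return (y - self.shift) / self.scale

    def inverse_log_det(self, y: float) -> float:
        # log |dx/dy| of the inverse map, evaluated at y
        return -math.log(abs(self.scale))


def standard_normal_log_prob(x: float) -> float:
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


def flow_forward(x: float, transforms: List[Affine1D]) -> float:
    """Push a base sample through the transforms in order."""
    for t in transforms:
        x = t.forward(x)
    return x


def flow_log_prob(y: float, transforms: List[Affine1D]) -> float:
    """Invert the transforms last-to-first, accumulating log-determinants."""
    log_det = 0.0
    for t in reversed(transforms):
        log_det += t.inverse_log_det(y)
        y = t.inverse(y)
    return standard_normal_log_prob(y) + log_det


transforms = [Affine1D(2.0, 1.0), Affine1D(0.5, -3.0)]
y = flow_forward(0.7, transforms)  # 0.5 * (2 * 0.7 + 1) - 3, i.e. about -1.8
print(y, flow_log_prob(y, transforms))
```

Here the two affine maps compose to y = x - 2.5, so the result must match a standard normal evaluated at 0.7; any other transform exposing the same three methods plugs into the same loops.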
Example: Neural flow
Neural flows are defined similarly, but they do not need a base density, and all invertible transformations must depend on time. In particular, at t=0 each transformation reduces to the identity.
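One way to meet the identity-at-t=0 requirement is to let time enter only through a function φ with φ(0) = 0. The sketch below shows this for a two-dimensional, time-dependent affine coupling in plain Python. time_embedding, toy_conditioner and the specific form x_b · exp(φ(t) · s(x_a)) + φ(t) · u(x_a) are illustrative assumptions, loosely following the idea behind st.ContinuousAffineCoupling rather than copied from it.

```python
import math
from typing import List, Tuple


def time_embedding(t: float) -> float:
    """Hypothetical time function with phi(0) = 0, here a linear map phi(t) = w * t."""
    w = 0.8
    return w * t


def toy_conditioner(x_a: float) -> Tuple[float, float]:
    """Hypothetical stand-in for a network: (log_scale, shift) computed from x_a."""
    return math.tanh(x_a), 0.5 * x_a


def continuous_coupling(x: List[float], t: float) -> List[float]:
    """Time-dependent affine coupling on 2-D input; the identity when t = 0."""
    x_a, x_b = x
    log_scale, shift = toy_conditioner(x_a)
    phi = time_embedding(t)
    y_b = x_b * math.exp(phi * log_scale) + phi * shift
    return [x_a, y_b]


def continuous_coupling_inverse(y: List[float], t: float) -> List[float]:
    """Inverse for a fixed t; y_a equals x_a, so the conditioner output is unchanged."""
    y_a, y_b = y
    log_scale, shift = toy_conditioner(y_a)
    phi = time_embedding(t)
    x_b = (y_b - phi * shift) * math.exp(-phi * log_scale)
    return [y_a, x_b]


x = [0.4, -1.1]
print(continuous_coupling(x, 0.0))  # [0.4, -1.1], the identity at t = 0
y = continuous_coupling(x, 1.5)
print(continuous_coupling_inverse(y, 1.5))  # recovers x up to rounding
```

For every fixed t the map stays invertible in x, which is the second property a neural flow needs.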
```python
import torch
import stribor as st

dim = 2
f = st.NeuralFlow([
    st.ContinuousAffineCoupling(
        latent_net=st.net.MLP(dim, [32], 2 * dim),
        time_net=st.net.TimeLinear(dim),
        mask='ordered_0',
        concatenate_time=False,
    ),
])

x = torch.randn(10, 4, dim)
t = torch.randn_like(x[..., :1])
y = f(x, t=t)  # Outputs the same dimension as x
```
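In the example above, x has shape (10, 4, dim) and t supplies one time value per point, with shape (10, 4, 1). A neural flow models the solution curve of an ODE directly, so evaluating it at an arbitrary t is a single function call instead of a numerical solve, and at t = 0 it returns the initial condition. The plain-Python sketch below illustrates this with the scalar linear ODE dx/dt = a · x, whose flow x0 · exp(a · t) is known in closed form, and compares it with explicit Euler integration. The ODE and the helper names are illustrative assumptions, not part of stribor.

```python
import math


def closed_form_flow(x0: float, t: float, a: float) -> float:
    """Solution of dx/dt = a * x at time t, starting from x0."""
    return x0 * math.exp(a * t)


def euler_solve(x0: float, t: float, a: float, steps: int) -> float:
    """Numerically integrate dx/dt = a * x from 0 to t with explicit Euler steps."""
    x, dt = x0, t / steps
    for _ in range(steps):
        x += dt * a * x
    return x


x0, a, t = 1.3, -0.7, 2.0
exact = closed_form_flow(x0, t, a)
approx = euler_solve(x0, t, a, steps=10_000)
print(exact, approx)  # agree to within the Euler discretization error
```

The closed-form flow also composes in time, F(F(x, s), t) = F(x, s + t), which a numerical solver only satisfies approximately.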
Run tests
```
pytest --pyargs stribor
```
Download files
Source Distribution
- stribor-0.2.0.tar.gz (35.8 kB)
Built Distribution
- stribor-0.2.0-py3-none-any.whl (57.7 kB)
File details
Details for the file stribor-0.2.0.tar.gz.
File metadata
- Download URL: stribor-0.2.0.tar.gz
- Upload date:
- Size: 35.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.7.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | d6941d681b0975e9dbb944c0cb220c2920e6979264b8475067ddd792fe537507
MD5 | 2147f6aafa015f02f3f7a3edfa51f7f2
BLAKE2b-256 | 2bc486f4fdfa7d53a0205e235dfcf468b0b924b3c2b623a7226426ab97664d75
File details
Details for the file stribor-0.2.0-py3-none-any.whl.
File metadata
- Download URL: stribor-0.2.0-py3-none-any.whl
- Upload date:
- Size: 57.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.7.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 468e7029e8dae90a1952ecc23874aaf177e622b602f80e1ef710cb583e3f0c26
MD5 | 2129502e8b608966b0eadcdc4ae4264f
BLAKE2b-256 | 1efed050225996c44f7d9292b9d02a2cc18565edf698ca760d8718bb53967ac8