Library for normalizing flows and neural flows
Stribor
Normalizing flows and neural flows for PyTorch.
- A normalizing flow defines a complicated probability density function as an invertible transformation of a random variable with a simple base density.
- A neural flow defines continuous-time dynamics with invertible neural networks.
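The normalizing-flow idea rests on the change-of-variables formula: if y = f(x) for an invertible f, then log p(y) = log p(x) - log|det(df/dx)|, evaluated at x = f^{-1}(y). As a minimal illustration (plain Python, not stribor code; all function names are hypothetical), the sketch below applies the formula to the 1-D affine map y = a*x + b with a standard normal x. The result must match the closed-form density of N(b, a^2).

```python
import math


def std_normal_log_prob(x: float) -> float:
    """Log-density of the standard normal N(0, 1) at x."""
    return -0.5 * (x * x + math.log(2 * math.pi))


def affine_flow_log_prob(y: float, a: float, b: float) -> float:
    """Log-density of y = a * x + b with x ~ N(0, 1), via change of variables.

    Invert the map to recover x, then subtract log|dy/dx| = log|a|.
    """
    x = (y - b) / a
    return std_normal_log_prob(x) - math.log(abs(a))


def normal_log_prob(y: float, mean: float, std: float) -> float:
    """Closed-form log-density of N(mean, std**2), used only as a reference."""
    z = (y - mean) / std
    return -0.5 * (z * z + math.log(2 * math.pi)) - math.log(std)


a, b = 2.0, -1.0
for y in (-3.0, 0.0, 1.5):
    # The two columns agree: the flow recovers the density of N(b, a**2).
    print(affine_flow_log_prob(y, a, b), normal_log_prob(y, b, abs(a)))
```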
Install package and dependencies
pip install stribor
Normalizing flows
Base densities
- Normal: st.Normal, st.UnitNormal, and st.MultivariateNormal
- Uniform: st.UnitUniform
- Or use distributions from torch.distributions
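The base densities above have simple closed-form log-densities. The sketch below is a plain-Python illustration with hypothetical helper names, not stribor's classes; it assumes that a unit normal means N(0, I) and that a unit uniform means Uniform[0, 1]^d, which may not match stribor's exact conventions.

```python
import math
from typing import Sequence


def unit_normal_log_prob(x: Sequence[float]) -> float:
    """Log-density of N(0, I) at x; coordinates are independent, so terms add."""
    d = len(x)
    return -0.5 * d * math.log(2 * math.pi) - 0.5 * sum(v * v for v in x)


def unit_uniform_log_prob(x: Sequence[float]) -> float:
    """Log-density of Uniform[0, 1]^d: 0 inside the unit cube, -inf outside."""
    return 0.0 if all(0.0 <= v <= 1.0 for v in x) else -math.inf


print(unit_normal_log_prob([0.0, 0.0]))   # equals -log(2*pi)
print(unit_uniform_log_prob([0.3, 0.9]))  # 0.0
print(unit_uniform_log_prob([1.2, 0.5]))  # -inf
```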
Invertible transformations
- Activation functions
  - ELU: st.ELU
  - Leaky ReLU: st.LeakyReLU
  - Sigmoid: st.Sigmoid
  - Logit (inverse sigmoid): st.Logit
- Affine
  - Element-wise transformation: st.Affine
  - Linear layer with LU factorization: st.AffineLU
  - Matrix exponential: st.MatrixExponential
- Coupling layer that can be combined with any element-wise transformation: st.Coupling
- Continuous normalizing flows: st.ContinuousTransform
  - Differential equations with stochastic trace estimation: st.net.DiffeqMLP, st.net.DiffeqDeepset, st.net.DiffeqSelfAttention
  - Differential equations with fixed zero trace: st.net.DiffeqZeroTraceMLP, st.net.DiffeqZeroTraceDeepSet, st.net.DiffeqZeroTraceAttention
  - Differential equations with exact trace computation: st.net.DiffeqExactTraceMLP, st.net.DiffeqExactTraceDeepSet, st.net.DiffeqExactTraceAttention
- Cumulative sum and difference: st.Cumsum and st.Diff
  - Across a single column: st.CumsumColumn and st.DiffColumn
- Permutations
  - Flipping the indices: st.Flip
  - Random permutation of indices: st.Permute
- Spline (quadratic or cubic) element-wise transformation: st.Spline
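Each element-wise transformation applies a scalar bijection to every coordinate independently, so the Jacobian is diagonal and log|det J| is the sum of per-coordinate log-derivatives. The minimal sketch below shows this for sigmoid/logit and leaky ReLU in plain Python; it is not stribor's implementation, and the leaky ReLU slope of 0.01 is an assumed value.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def logit(y: float) -> float:
    """Inverse of sigmoid on (0, 1)."""
    return math.log(y) - math.log1p(-y)


def sigmoid_log_abs_det(x: float) -> float:
    """log|d sigmoid / dx| = log(s * (1 - s)) with s = sigmoid(x)."""
    s = sigmoid(x)
    return math.log(s) + math.log1p(-s)


def leaky_relu(x: float, slope: float = 0.01) -> float:
    return x if x >= 0 else slope * x


def leaky_relu_inverse(y: float, slope: float = 0.01) -> float:
    # A positive slope preserves the sign, so the branch can be read off y.
    return y if y >= 0 else y / slope


def leaky_relu_log_abs_det(x: float, slope: float = 0.01) -> float:
    return 0.0 if x >= 0 else math.log(slope)


xs = [-2.0, -0.5, 0.0, 1.3]
# Element-wise maps act per coordinate, so the vector log|det J| is a sum.
print(sum(sigmoid_log_abs_det(x) for x in xs))
print([logit(sigmoid(x)) for x in xs])  # recovers xs up to rounding
```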
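For the linear layers, the usual motivation is a cheap log-determinant. Assuming an LU parameterization W = P L U, with P a permutation, L unit lower-triangular, and U upper-triangular, log|det W| is the sum of log|U_ii|; for W = exp(A), det W = exp(tr A), so log|det W| = tr A. The plain-Python sketch below checks both identities on 2x2 examples (the permutation is omitted). It is illustrative only and is not meant to reflect how st.AffineLU or st.MatrixExponential are actually parameterized.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def det2(m: Matrix) -> float:
    """Determinant of a 2x2 matrix."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]


def expm(a: Matrix, terms: int = 30) -> Matrix:
    """Matrix exponential via a truncated Taylor series (adequate for small matrices)."""
    n = len(a)
    result = [[float(i == j) for j in range(n)] for i in range(n)]
    term = [row[:] for row in result]
    for k in range(1, terms):
        # After this step, term holds a**k / k!
        term = [[v / k for v in row] for row in matmul(term, a)]
        result = [[r + t for r, t in zip(r_row, t_row)] for r_row, t_row in zip(result, term)]
    return result


# LU parameterization W = L @ U (permutation omitted): L has a unit diagonal,
# so log|det W| only needs the diagonal of U.
L = [[1.0, 0.0], [0.7, 1.0]]
U = [[2.0, -1.0], [0.0, 3.0]]
W = matmul(L, U)
log_det_lu = sum(math.log(abs(U[i][i])) for i in range(2))
print(math.log(abs(det2(W))), log_det_lu)  # both approximately log(6)

# Matrix exponential: det(exp(A)) = exp(trace(A)), so log|det| = trace(A).
A = [[0.3, -1.2], [0.4, -0.1]]
print(math.log(det2(expm(A))), A[0][0] + A[1][1])  # both approximately 0.2
```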
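Cumulative sums, differences, and permutations are all volume-preserving: the cumulative sum has a unit lower-triangular Jacobian and a permutation matrix has |det| = 1, so each contributes log|det J| = 0. The plain-Python sketch below checks the inverse pairs. It assumes the difference keeps the first element unchanged, so that it exactly inverts the cumulative sum; st.Diff may use a different convention.

```python
import random
from itertools import accumulate
from typing import List, Sequence


def cumsum(x: Sequence[float]) -> List[float]:
    # Jacobian is unit lower-triangular, so log|det J| = 0.
    return list(accumulate(x))


def diff(y: Sequence[float]) -> List[float]:
    # Keep the first element so that diff exactly inverts cumsum (assumed convention).
    return [y[0]] + [y[i] - y[i - 1] for i in range(1, len(y))]


def flip(x: Sequence[float]) -> List[float]:
    # Reversing the indices is its own inverse.
    return list(reversed(x))


def permute(x: Sequence[float], perm: Sequence[int]) -> List[float]:
    return [x[p] for p in perm]


def inverse_permutation(perm: Sequence[int]) -> List[int]:
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv


x = [0.5, -1.0, 2.0, 3.5]
print(cumsum(x))           # [0.5, -0.5, 1.5, 5.0]
print(diff(cumsum(x)))     # [0.5, -1.0, 2.0, 3.5]
print(flip(flip(x)) == x)  # True

perm = list(range(len(x)))
random.Random(0).shuffle(perm)  # a fixed "random" permutation
print(permute(permute(x, perm), inverse_permutation(perm)) == x)  # True
```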
Example: Normalizing flow
To define a normalizing flow, define a base distribution and a series of transformations, e.g.:
import stribor as st
import torch
dim = 2
base_dist = st.UnitNormal(dim)
transforms = [
    st.Coupling(
        transform=st.Affine(dim, latent_net=st.net.MLP(dim, [64], 2 * dim)),
        mask='ordered_right_half',
    ),
    st.ContinuousTransform(
        dim,
        net=st.net.DiffeqMLP(dim + 1, [64], dim),
    ),
]
flow = st.NormalizingFlow(base_dist, transforms)
x = torch.rand(10, dim)
y = flow(x)  # Forward transformation
log_prob = flow.log_prob(y)  # Log-probability log p(y)
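In the example above, st.Coupling conditions an element-wise st.Affine transformation on part of the input, and the latent network outputs 2 * dim values, which suggests one scale and one shift per dimension. The sketch below is a plain-Python illustration of an affine coupling layer under that assumption, not stribor's code: the first half of the input passes through unchanged and parameterizes the second half, and conditioner is a hypothetical stand-in for st.net.MLP. The mask semantics of 'ordered_right_half' are not modeled.

```python
import math
from typing import List, Sequence, Tuple


def conditioner(x_cond: Sequence[float], n_out: int) -> Tuple[List[float], List[float]]:
    """Hypothetical stand-in for the latent network; any function of x_cond works.

    Returns a log-scale and a shift for each transformed coordinate.
    """
    s = sum(x_cond)
    log_scale = [math.tanh(0.5 * s + 0.1 * i) for i in range(n_out)]
    shift = [s - 0.2 * i for i in range(n_out)]
    return log_scale, shift


def coupling_forward(x: Sequence[float]) -> Tuple[List[float], float]:
    """Affine coupling: the first half passes through and parameterizes the second half."""
    k = len(x) // 2
    x1, x2 = list(x[:k]), list(x[k:])
    log_scale, shift = conditioner(x1, len(x2))
    y2 = [xi * math.exp(ls) + t for xi, ls, t in zip(x2, log_scale, shift)]
    # The Jacobian is triangular, so log|det J| is the sum of the log-scales.
    return x1 + y2, sum(log_scale)


def coupling_inverse(y: Sequence[float]) -> List[float]:
    k = len(y) // 2
    y1, y2 = list(y[:k]), list(y[k:])
    # y1 equals x1, so the same parameters can be recomputed exactly.
    log_scale, shift = conditioner(y1, len(y2))
    x2 = [(yi - t) * math.exp(-ls) for yi, ls, t in zip(y2, log_scale, shift)]
    return y1 + x2


x = [0.3, -1.1, 0.8, 2.0]
y, log_det = coupling_forward(x)
print(y, log_det)
print(coupling_inverse(y))  # recovers x up to rounding
```

Because the conditioner only ever sees the untouched half, it can be an arbitrary network without affecting invertibility.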
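For a continuous normalizing flow, the log-density changes over time at the rate given by minus the trace of the Jacobian of the dynamics. A standard way to avoid computing that trace exactly, used for example in FFJORD, is Hutchinson's estimator tr(J) = E[v^T J v] for random v with E[v v^T] = I; the "stochastic trace estimation" networks listed above presumably rely on something similar, though that is an assumption. The plain-Python sketch below applies the estimator to a fixed matrix. With Rademacher (±1) probes, averaging over all 2^d sign patterns cancels the off-diagonal terms exactly, which gives a deterministic check.

```python
import itertools
import random
from typing import Sequence


def quad_form(v: Sequence[float], m: Sequence[Sequence[float]]) -> float:
    """Compute v^T M v."""
    n = len(v)
    return sum(v[i] * m[i][j] * v[j] for i in range(n) for j in range(n))


def hutchinson_trace(m: Sequence[Sequence[float]], num_samples: int,
                     rng: random.Random) -> float:
    """Unbiased Monte Carlo estimate of tr(M) using Rademacher probe vectors."""
    n = len(m)
    total = 0.0
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        total += quad_form(v, m)
    return total / num_samples


def exact_rademacher_average(m: Sequence[Sequence[float]]) -> float:
    """Average of v^T M v over all 2^n sign vectors; equals tr(M) exactly."""
    signs = list(itertools.product((-1.0, 1.0), repeat=len(m)))
    return sum(quad_form(v, m) for v in signs) / len(signs)


J = [[1.5, -0.3, 2.0],
     [0.7, -0.5, 0.1],
     [-1.2, 0.4, 2.0]]
true_trace = sum(J[i][i] for i in range(3))  # 3.0
print(exact_rademacher_average(J), true_trace)
print(hutchinson_trace(J, num_samples=10_000, rng=random.Random(0)))  # close to 3.0
```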
Example: Neural flow
Neural flows are defined similarly, except that no base density is needed and every invertible transformation must depend on time. In particular, at t=0 each transformation reduces to the identity.
import torch
import stribor as st
dim = 2
f = st.NeuralFlow([
    st.ContinuousAffineCoupling(
        latent_net=st.net.MLP(dim, [32], 2 * dim),
        time_net=st.net.TimeLinear(dim),
        mask='ordered_0',
        concatenate_time=False,
    ),
])
x = torch.randn(10, 4, dim)
t = torch.randn_like(x[..., :1])
y = f(x, t=t)  # Output has the same shape as x
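Unlike a neural ODE, a neural flow models the solution curve directly, so it can be evaluated at any time t without a numerical solver. The plain-Python sketch below is a hypothetical time-dependent affine flow, not stribor's st.ContinuousAffineCoupling: the time embedding phi(t) = tanh(t) is an assumption chosen only because phi(0) = 0, which makes the map equal to the identity at t = 0, and the log-scale and shift are constants rather than outputs of a latent network.

```python
import math
from typing import List, Sequence


def time_net(t: float) -> float:
    """Hypothetical time embedding with phi(0) = 0; tanh keeps it bounded."""
    return math.tanh(t)


def neural_flow(x0: Sequence[float], t: float, log_scale: Sequence[float],
                shift: Sequence[float]) -> List[float]:
    """Closed-form F(x0, t) of a hypothetical time-dependent affine flow.

    Because phi(0) = 0, F(x0, 0) = x0, i.e. the initial condition holds.
    """
    phi = time_net(t)
    return [x * math.exp(phi * a) + phi * b for x, a, b in zip(x0, log_scale, shift)]


def neural_flow_inverse(y: Sequence[float], t: float, log_scale: Sequence[float],
                        shift: Sequence[float]) -> List[float]:
    """Recover x0 from y = F(x0, t) for a fixed time t."""
    phi = time_net(t)
    return [(v - phi * b) * math.exp(-phi * a) for v, a, b in zip(y, log_scale, shift)]


x0 = [0.4, -1.3]
log_scale, shift = [0.5, -0.2], [1.0, 0.3]
print(neural_flow(x0, 0.0, log_scale, shift))  # [0.4, -1.3]: identity at t = 0
y = neural_flow(x0, 0.8, log_scale, shift)
print(neural_flow_inverse(y, 0.8, log_scale, shift))  # recovers x0 up to rounding
```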
Run tests
pytest --pyargs stribor
File details
Details for the file stribor-0.2.0.tar.gz.
File metadata
- Download URL: stribor-0.2.0.tar.gz
- Upload date:
- Size: 35.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.7.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d6941d681b0975e9dbb944c0cb220c2920e6979264b8475067ddd792fe537507 |
| MD5 | 2147f6aafa015f02f3f7a3edfa51f7f2 |
| BLAKE2b-256 | 2bc486f4fdfa7d53a0205e235dfcf468b0b924b3c2b623a7226426ab97664d75 |
File details
Details for the file stribor-0.2.0-py3-none-any.whl.
File metadata
- Download URL: stribor-0.2.0-py3-none-any.whl
- Upload date:
- Size: 57.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.7.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 468e7029e8dae90a1952ecc23874aaf177e622b602f80e1ef710cb583e3f0c26 |
| MD5 | 2129502e8b608966b0eadcdc4ae4264f |
| BLAKE2b-256 | 1efed050225996c44f7d9292b9d02a2cc18565edf698ca760d8718bb53967ac8 |