Torchflows: normalizing flows in PyTorch
Torchflows is a library for generative modeling and density estimation using normalizing flows. It implements many normalizing flow architectures and their building blocks for:
- easy use of normalizing flows as trainable distributions;
- easy implementation of new normalizing flows.
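As background, a normalizing flow models data x through an invertible map f into a simple base distribution, typically a standard normal, and evaluates densities with the change-of-variables formula log p(x) = log p_z(f(x)) + log|det ∂f/∂x|. The plain-Python sketch below illustrates this for a one-dimensional affine map z = (x - mu) / sigma. The function names and the parameters `mu` and `sigma` are illustrative and are not part of the Torchflows API.

```python
import math


def standard_normal_log_prob(z: float) -> float:
    """Log density of the base distribution N(0, 1) at z."""
    return -0.5 * (z * z + math.log(2.0 * math.pi))


def affine_flow_log_prob(x: float, mu: float, sigma: float) -> float:
    """log p(x) for the flow z = (x - mu) / sigma with a N(0, 1) base.

    Change of variables: log p(x) = log p_z(f(x)) + log|df/dx|.
    Here df/dx = 1 / sigma, so log|df/dx| = -log(sigma).
    """
    z = (x - mu) / sigma
    log_det = -math.log(sigma)
    return standard_normal_log_prob(z) + log_det


def affine_flow_sample(z: float, mu: float, sigma: float) -> float:
    """Sampling runs the inverse map on a base sample: x = mu + sigma * z."""
    return mu + sigma * z
```

For this particular map the result coincides with the log density of N(mu, sigma^2); learned flows replace the affine map with a more expressive bijection but evaluate densities the same way.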
Example usage:
```python
import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP

torch.manual_seed(0)

n_data = 1000
n_dim = 3

x = torch.randn(n_data, n_dim)  # Generate some training data
bijection = RealNVP(n_dim)  # Create the bijection
flow = Flow(bijection)  # Create the normalizing flow

flow.fit(x)  # Fit the normalizing flow to training data
log_prob = flow.log_prob(x)  # Compute the log probability of each training point
x_new = flow.sample(50)  # Sample 50 new data points

print(log_prob.shape)  # torch.Size([1000])
print(x_new.shape)  # torch.Size([50, 3])
```
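RealNVP, the bijection used above, is built from affine coupling layers. The plain-Python sketch below shows the idea for a two-dimensional input: the first coordinate passes through unchanged and parameterizes a scale and shift for the second, so the inverse is cheap and the log-determinant is simply the log-scale. The fixed functions `scale_net` and `shift_net` are hypothetical stand-ins for the neural networks a real implementation learns; this is not Torchflows' internal code.

```python
import math
from typing import Tuple


def scale_net(x1: float) -> float:
    # Hypothetical stand-in for a learned network that outputs a log-scale.
    return math.tanh(x1)


def shift_net(x1: float) -> float:
    # Hypothetical stand-in for a learned network that outputs a shift.
    return 0.5 * x1


def coupling_forward(x: Tuple[float, float]) -> Tuple[Tuple[float, float], float]:
    """Map x -> z and return (z, log|det J|)."""
    x1, x2 = x
    s = scale_net(x1)
    z2 = x2 * math.exp(s) + shift_net(x1)
    # The Jacobian is triangular with diagonal (1, exp(s)), so log|det J| = s.
    return (x1, z2), s


def coupling_inverse(z: Tuple[float, float]) -> Tuple[float, float]:
    """Invert the layer exactly without inverting scale_net or shift_net."""
    z1, z2 = z
    x2 = (z2 - shift_net(z1)) * math.exp(-scale_net(z1))
    return (z1, x2)
```

A full RealNVP stacks several such layers, alternating which coordinates are held fixed, and sums their log-determinants.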
See the documentation, including the list of supported architectures, and the provided examples for more details.
Installing
We support Python versions 3.7 and upwards.
Install Torchflows via pip:
```shell
pip install torchflows
```
Install Torchflows directly from GitHub:
```shell
pip install git+https://github.com/davidnabergoj/torchflows.git
```
Setup for development:
```shell
git clone https://github.com/davidnabergoj/torchflows.git
cd torchflows
pip install -r requirements.txt
```
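After installing, one way to confirm which version pip installed is the standard library's `importlib.metadata`. This module requires Python 3.8 or later; on Python 3.7 the `importlib_metadata` backport package would be needed instead. The helper name `installed_version` is illustrative.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints the installed version string, or None if torchflows is not installed.
    print(installed_version("torchflows"))
```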
Download files
Source distribution: torchflows-1.1.0.tar.gz (75.8 kB)
Built distribution: torchflows-1.1.0-py3-none-any.whl (92.9 kB)
Hashes for torchflows-1.1.0-py3-none-any.whl:

| Algorithm | Hash digest |
|---|---|
| SHA256 | 03cc117c89c65e513bd5418e4991c229a64fd466e8aa52ceccbea47b490a9208 |
| MD5 | b342c81d4ac85ce18a0fa176c79f6424 |
| BLAKE2b-256 | eb8415fdcfa1a93b9268f559dfb5a86613f9e42f235488795859050ca3436ff8 |
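To check that a downloaded wheel matches the SHA256 digest listed above, the standard library's `hashlib` is sufficient. The sketch below assumes the wheel was saved under its original filename in the current working directory; the helper name `sha256_of_file` is illustrative.

```python
import hashlib
from pathlib import Path

# SHA256 digest of torchflows-1.1.0-py3-none-any.whl, copied from the table above.
EXPECTED_SHA256 = "03cc117c89c65e513bd5418e4991c229a64fd466e8aa52ceccbea47b490a9208"


def sha256_of_file(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    wheel = Path("torchflows-1.1.0-py3-none-any.whl")  # assumed download location
    if wheel.exists():
        print("OK" if sha256_of_file(wheel) == EXPECTED_SHA256 else "MISMATCH")
    else:
        print(f"{wheel} not found")
```

pip can also enforce such digests at install time through its hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file).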