Normalizing flows in PyTorch

Torchflows: normalizing flows in PyTorch

Torchflows is a library for generative modeling and density estimation using normalizing flows. It implements many normalizing flow architectures and their building blocks for:

  • easy use of normalizing flows as trainable distributions;
  • easy implementation of new normalizing flows.
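Both uses rest on the change-of-variables formula, which is standard for normalizing flows. It is stated here under an assumed convention: a bijection f maps data x to a latent variable z = f(x) with a simple base density p_Z, typically a standard Gaussian. Torchflows may orient its bijections the other way internally, and the formula adapts accordingly. Density evaluation uses the forward direction and sampling uses the inverse:

```latex
\log p_X(x) = \log p_Z\big(f(x)\big) + \log \left| \det \frac{\partial f(x)}{\partial x} \right|,
\qquad x_{\text{new}} = f^{-1}(z), \quad z \sim p_Z
```

Architectures such as RealNVP are designed so that both f^{-1} and the Jacobian log-determinant are cheap to compute.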

Example use:

import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP

torch.manual_seed(0)

n_data = 1000
n_dim = 3

x = torch.randn(n_data, n_dim)  # Generate some training data
bijection = RealNVP(n_dim)  # Create the bijection
flow = Flow(bijection)  # Create the normalizing flow

flow.fit(x)  # Fit the normalizing flow to training data
log_prob = flow.log_prob(x)  # Compute the log probability of training data
x_new = flow.sample(50)  # Sample 50 new data points

print(log_prob.shape)  # (1000,)
print(x_new.shape)  # (50, 3)
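For intuition about what `flow.fit`, `flow.log_prob` and `flow.sample` compute, here is a dependency-free sketch using the simplest possible flow: an elementwise affine bijection with a standard Gaussian base distribution. The class `AffineFlow` and its methods are hypothetical illustrations, not part of torchflows. RealNVP replaces the elementwise map with stacked coupling layers, and a PyTorch library would typically rely on autograd rather than the hand-derived gradients used here.

```python
import math
import random


class AffineFlow:
    """Hypothetical elementwise affine flow with a standard Gaussian base.

    Forward map (data -> latent): z = (x - shift) * exp(-log_scale).
    Not part of torchflows; for illustration only.
    """

    def __init__(self, n_dim):
        self.shift = [0.0] * n_dim      # learnable location per dimension
        self.log_scale = [0.0] * n_dim  # learnable log-scale per dimension

    def log_prob(self, x):
        # Change of variables: log p(x) = log N(z; 0, I) + log|det dz/dx|,
        # where log|det dz/dx| = -sum(log_scale) for this elementwise map.
        out = []
        for row in x:
            lp = 0.0
            for xi, m, s in zip(row, self.shift, self.log_scale):
                z = (xi - m) * math.exp(-s)
                lp += -0.5 * z * z - 0.5 * math.log(2 * math.pi) - s
            out.append(lp)
        return out  # one log-density per data point

    def sample(self, n):
        # Draw z ~ N(0, I) and apply the inverse map x = shift + exp(log_scale) * z.
        return [
            [m + math.exp(s) * random.gauss(0.0, 1.0)
             for m, s in zip(self.shift, self.log_scale)]
            for _ in range(n)
        ]

    def fit(self, x, lr=0.1, n_steps=500):
        # Maximize the mean log-likelihood with plain gradient ascent,
        # using gradients derived by hand for this simple map.
        n = len(x)
        for _ in range(n_steps):
            for d in range(len(self.shift)):
                m, s = self.shift[d], self.log_scale[d]
                zs = [(row[d] - m) * math.exp(-s) for row in x]
                grad_m = sum(zs) / n * math.exp(-s)        # d(mean log p)/d shift
                grad_s = sum(z * z for z in zs) / n - 1.0  # d(mean log p)/d log_scale
                self.shift[d] += lr * grad_m
                self.log_scale[d] += lr * grad_s


# Usage mirroring the README example (standard Gaussian toy data).
random.seed(0)
x = [[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(1000)]
flow = AffineFlow(3)
flow.fit(x)
log_prob = flow.log_prob(x)
x_new = flow.sample(50)
print(len(log_prob), len(x_new), len(x_new[0]))  # 1000 50 3
```

At convergence, `shift` matches the per-dimension sample mean and `exp(log_scale)` the per-dimension (population) standard deviation, which is the maximum-likelihood fit for this family. More expressive flows such as RealNVP are needed once the data is not a shifted and scaled Gaussian.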

Check out the documentation, including the list of supported architectures, here. We also provide examples here.

Installing

We support Python versions 3.7 and upwards.

Install Torchflows via pip:

pip install torchflows

Install Torchflows directly from GitHub:

pip install git+https://github.com/davidnabergoj/torchflows.git

Setup for development:

git clone https://github.com/davidnabergoj/torchflows.git
cd torchflows
pip install -r requirements.txt


Download files


Source distribution: torchflows-1.1.0.tar.gz (75.8 kB)

Built distribution: torchflows-1.1.0-py3-none-any.whl (92.9 kB, Python 3)
