Normalizing flows in PyTorch

Project description

Torchflows: normalizing flows in PyTorch

Torchflows is a library for generative modeling and density estimation using normalizing flows. It implements many normalizing flow architectures and their building blocks for:

  • Easy use of normalizing flows as trainable distributions.
  • Easy implementation of new normalizing flows.
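A normalizing flow defines a density on data x by mapping it through an invertible function f to a latent variable z = f(x) that has a simple base density. The density of x is then log p(x) = log p_base(f(x)) + log|det ∂f(x)/∂x|, and sampling runs the inverse of f on base samples. The dependency-free sketch below shows this with the simplest possible flow: an elementwise affine map with a standard normal base. The class `ElementwiseAffineFlow` and its methods are hypothetical. They only mirror the roles of `fit`, `log_prob`, and `sample` in the example below, and this is not Torchflows code.

```python
import math
import random


class ElementwiseAffineFlow:
    """Toy flow (hypothetical, not Torchflows code): z_i = (x_i - shift_i) / scale_i."""

    def __init__(self, shift, scale):
        if any(s <= 0 for s in scale):
            raise ValueError("scales must be positive so the map is invertible")
        self.shift = list(shift)
        self.scale = list(scale)

    @classmethod
    def fit(cls, data):
        # Maximum-likelihood fit for this family: per-dimension mean and standard deviation.
        n, dim = len(data), len(data[0])
        shift = [sum(row[j] for row in data) / n for j in range(dim)]
        scale = [math.sqrt(sum((row[j] - shift[j]) ** 2 for row in data) / n) for j in range(dim)]
        return cls(shift, scale)

    def forward(self, x):
        # Data space -> latent space.
        return [(xi - m) / s for xi, m, s in zip(x, self.shift, self.scale)]

    def inverse(self, z):
        # Latent space -> data space.
        return [zi * s + m for zi, m, s in zip(z, self.shift, self.scale)]

    def log_prob(self, x):
        z = self.forward(x)
        # Standard normal base log-density, summed over dimensions.
        log_base = sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)
        # The forward Jacobian is diagonal with entries 1 / scale_i.
        log_det = -sum(math.log(s) for s in self.scale)
        return log_base + log_det

    def sample(self, n, rng=random):
        # Draw z from the base and push it through the inverse map.
        return [self.inverse([rng.gauss(0.0, 1.0) for _ in self.scale]) for _ in range(n)]
```

This family can only represent axis-aligned Gaussians. Modeling anything richer requires stacking more expressive invertible maps, which is what the architectures in Torchflows provide.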

Example use:

import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP

torch.manual_seed(0)

n_data = 1000
n_dim = 3

x = torch.randn(n_data, n_dim)  # Generate some training data
bijection = RealNVP(n_dim)  # Create the bijection
flow = Flow(bijection)  # Create the normalizing flow

flow.fit(x)  # Fit the normalizing flow to training data
log_prob = flow.log_prob(x)  # Compute the log probability of training data
x_new = flow.sample(50)  # Sample 50 new data points

print(log_prob.shape)  # (1000,)
print(x_new.shape)  # (50, 3)
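In the normalizing flow literature, RealNVP is a stack of affine coupling layers. The input is split into two parts: the first part passes through unchanged, and the second is scaled and shifted by amounts computed from the first. The Jacobian is therefore triangular, and its log-determinant is simply the sum of the log-scales. The sketch below shows one such layer in plain Python. It assumes a fixed, hand-written `conditioner` in place of the trained neural network a real implementation would use. The function names are hypothetical and do not describe Torchflows' internals.

```python
import math


def conditioner(x_a, out_dim):
    # Hypothetical stand-in for the neural network RealNVP would train.
    # Any function of x_a works, because x_a is passed through unchanged.
    total = sum(x_a)
    log_scale = [math.tanh(0.5 * total + 0.1 * k) for k in range(out_dim)]
    shift = [0.3 * total - 0.2 * k for k in range(out_dim)]
    return log_scale, shift


def coupling_forward(x):
    # Split the input, keep the first half, scale-and-shift the second half.
    d = len(x) // 2
    x_a, x_b = x[:d], x[d:]
    log_scale, shift = conditioner(x_a, len(x_b))
    y_b = [xb * math.exp(s) + t for xb, s, t in zip(x_b, log_scale, shift)]
    # Triangular Jacobian: the diagonal is 1 for x_a and exp(log_scale) for x_b.
    log_det = sum(log_scale)
    return x_a + y_b, log_det


def coupling_inverse(y):
    d = len(y) // 2
    y_a, y_b = y[:d], y[d:]
    # y_a equals x_a, so the conditioner reproduces the forward parameters exactly.
    log_scale, shift = conditioner(y_a, len(y_b))
    x_b = [(yb - t) * math.exp(-s) for yb, s, t in zip(y_b, log_scale, shift)]
    return y_a + x_b, -sum(log_scale)
```

With n_dim = 3 as in the example, the first coordinate conditions the other two. A single layer never changes its conditioning half, so a full model alternates or permutes which coordinates are transformed from one layer to the next.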

Check out examples and the documentation, including the list of supported architectures.

Installing

We support Python versions 3.7 and upwards.

Install Torchflows via pip:

pip install torchflows

Install Torchflows directly from GitHub:

pip install git+https://github.com/davidnabergoj/torchflows.git

Setup for development:

git clone https://github.com/davidnabergoj/torchflows.git
cd torchflows
pip install -r requirements.txt

Citation

If you use this code in your work, we kindly ask that you cite the accompanying paper:

Nabergoj and Štrumbelj: Empirical evaluation of normalizing flows in Markov Chain Monte Carlo, 2024. arXiv:2412.17136.

BibTeX entry:

@misc{nabergoj_nf_mcmc_evaluation_2024,
    author = {Nabergoj, David and \v{S}trumbelj, Erik},
    title = {Empirical evaluation of normalizing flows in {Markov} {Chain} {Monte} {Carlo}},
    publisher = {arXiv},
    month = dec,
    year = {2024},
    note = {arxiv:2412.17136}
}

Contributions

We warmly welcome all contributions and comments. Please do not hesitate to submit issues and pull requests.

Some options to start contributing include:

  • Adding references to the documentation page for architecture presets.
  • Implementing new normalizing flow architectures (see the developer guide).
  • Adding more automated tests for numerical stability and optimization.
  • Adding docstrings to undocumented classes.
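The automated tests for numerical stability mentioned above could start from checks like the one sketched below. It confirms that the inverse undoes the forward pass and that the analytic log|det J| agrees with a finite-difference estimate. The helper names (`check_bijection`, `finite_difference_log_abs_det`, `log_abs_det`) and the tolerances are assumptions for illustration. They are not part of Torchflows' existing test suite.

```python
import math


def log_abs_det(matrix):
    """log|det A| via Gaussian elimination with partial pivoting (row swaps only flip the sign)."""
    a = [row[:] for row in matrix]
    n = len(a)
    total = 0.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if a[pivot][col] == 0.0:
            return -math.inf  # singular matrix
        a[col], a[pivot] = a[pivot], a[col]
        total += math.log(abs(a[col][col]))
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return total


def finite_difference_log_abs_det(f, x, eps=1e-6):
    """Estimate log|det J_f(x)| from a central-difference Jacobian."""
    n = len(x)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        f_plus, f_minus = f(x_plus), f(x_minus)
        for i in range(n):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return log_abs_det(jac)


def check_bijection(forward, inverse, x, log_det, atol=1e-5):
    """Hypothetical test helper: round-trip and log-determinant consistency at one point."""
    x_back = inverse(forward(x))
    assert all(abs(a - b) < atol for a, b in zip(x, x_back)), "inverse(forward(x)) != x"
    numeric = finite_difference_log_abs_det(forward, x)
    analytic = log_det(x)
    assert abs(numeric - analytic) < atol, f"log|det J| mismatch: {numeric} vs {analytic}"


# Example: a two-dimensional coupling map y = (x0, x1 * exp(x0)) with log|det J| = x0.
check_bijection(
    forward=lambda x: [x[0], x[1] * math.exp(x[0])],
    inverse=lambda y: [y[0], y[1] * math.exp(-y[0])],
    x=[0.3, -1.7],
    log_det=lambda x: x[0],
)
```

A finite-difference Jacobian costs 2n function evaluations and becomes inaccurate in high dimensions or for very steep maps. Checks like this are best run on small inputs and several random points.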

Download files

Source Distribution

torchflows-1.2.0.tar.gz (82.3 kB)

Built Distribution

torchflows-1.2.0-py3-none-any.whl (102.3 kB)

File details

Details for the file torchflows-1.2.0.tar.gz.

File metadata

  • Download URL: torchflows-1.2.0.tar.gz
  • Size: 82.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for torchflows-1.2.0.tar.gz:

  • SHA256: 4e85ce1d66186eb1bd8c07af175fac21ee7c3408f519dea20e02bdb10d1b007e
  • MD5: fdb232749ebc815a1bb17cde1dfeacc9
  • BLAKE2b-256: d79739bd390e0e820e56efbdc82e4a4619631f91a50d899dd73e3456ce33c0f4
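The sketch below checks a downloaded file against the SHA256 digest listed above. It uses only Python's standard `hashlib` module. The function names are illustrative, and the file is assumed to have been downloaded to a local path.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected_hex):
    """True if the file's SHA256 digest matches the published hex digest."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, verify_sha256("torchflows-1.2.0.tar.gz", "4e85ce1d66186eb1bd8c07af175fac21ee7c3408f519dea20e02bdb10d1b007e") should return True for an unmodified download. pip can also enforce digests itself through its hash-checking mode, which uses --require-hashes together with --hash entries in a requirements file.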

Provenance

The following attestation bundles were made for torchflows-1.2.0.tar.gz:

Publisher: python-publish.yml on davidnabergoj/torchflows

File details

Details for the file torchflows-1.2.0-py3-none-any.whl.

File metadata

  • Download URL: torchflows-1.2.0-py3-none-any.whl
  • Size: 102.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for torchflows-1.2.0-py3-none-any.whl:

  • SHA256: 170bb1d098429b8c3024ec57c8c0ac42a3e6213385470debbf66b2ee8d4bae0f
  • MD5: 94c2aca5d684c00fa482e081fbde362c
  • BLAKE2b-256: 40d81cd8fba083bfee8bba49ffb16d9f24f79a38723b2cf27987d762f811c287

Provenance

The following attestation bundles were made for torchflows-1.2.0-py3-none-any.whl:

Publisher: python-publish.yml on davidnabergoj/torchflows
