
Normalizing flows in PyTorch

Project description

Zuko - Normalizing flows in PyTorch

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot send their internal tensors to the GPU with .to('cuda') or retrieve their parameters with .parameters().
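The limitation is easy to observe. The following minimal check uses only standard torch calls: it builds a Normal distribution and an AffineTransform from plain tensors and inspects them. Neither object is an nn.Module, and neither offers .to() or .parameters().

```python
import torch
import torch.nn as nn
from torch.distributions import AffineTransform, Normal

# Tensors we would like to train and move to a GPU.
loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3)

dist = Normal(loc, scale)
transform = AffineTransform(loc, scale)

# For each object: is it a Module? does it have .to()? does it have .parameters()?
report = {
    type(obj).__name__: (
        isinstance(obj, nn.Module),
        hasattr(obj, "to"),
        hasattr(obj, "parameters"),
    )
    for obj in (dist, transform)
}
print(report)
```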

To solve this problem, zuko defines two abstract classes: DistributionModule and TransformModule. The former is any Module whose forward pass returns a Distribution, and the latter is any Module whose forward pass returns a Transform. A normalizing flow is then the composition of a list of TransformModule instances and a base DistributionModule. This design allows flows to behave like distributions while retaining the benefits of Module. It also makes the implementations easy to understand and extend.
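The pattern can be sketched in plain PyTorch. The classes below (StandardNormalModule, AffineTransformModule and ToyFlow) are hypothetical names chosen for illustration, not zuko's API: each is a Module whose forward pass returns a torch Distribution or Transform, and the flow composes them with torch's TransformedDistribution. A single elementwise affine layer stands in for the spline transformations zuko uses, to keep the sketch short.

```python
import torch
import torch.nn as nn
from torch.distributions import (
    AffineTransform,
    ComposeTransform,
    Distribution,
    Independent,
    Normal,
    Transform,
    TransformedDistribution,
)


class StandardNormalModule(nn.Module):
    """Hypothetical DistributionModule: forward() returns a Distribution."""

    def __init__(self, features: int):
        super().__init__()
        # Buffers follow .to(device) / .to(dtype) but are not trained.
        self.register_buffer("loc", torch.zeros(features))
        self.register_buffer("scale", torch.ones(features))

    def forward(self) -> Distribution:
        # Independent treats the last dimension as the event dimension.
        return Independent(Normal(self.loc, self.scale), 1)


class AffineTransformModule(nn.Module):
    """Hypothetical TransformModule: forward() returns a Transform."""

    def __init__(self, features: int):
        super().__init__()
        self.shift = nn.Parameter(torch.zeros(features))
        self.log_scale = nn.Parameter(torch.zeros(features))

    def forward(self) -> Transform:
        # exp keeps the scale positive, so the transform stays invertible.
        return AffineTransform(self.shift, self.log_scale.exp(), event_dim=1)


class ToyFlow(nn.Module):
    """Composes transform modules on top of a base distribution module."""

    def __init__(self, transforms, base):
        super().__init__()
        self.transforms = nn.ModuleList(transforms)
        self.base = base

    def forward(self) -> Distribution:
        # Fresh Distribution/Transform objects are built from the current parameters.
        transform = ComposeTransform([t() for t in self.transforms])
        return TransformedDistribution(self.base(), transform)


flow = ToyFlow([AffineTransformModule(3) for _ in range(2)], StandardNormalModule(3))
x = flow().sample((64,))
log_p = flow().log_prob(x)
n_params = sum(p.numel() for p in flow.parameters())
```

Because every piece is a Module, flow.parameters() yields the shift and log-scale of each layer, and flow.to('cuda') would move them together with the base buffers. The Distribution objects are rebuilt on every call, so they always reflect the current parameters.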

In the Avatar cartoon, Zuko is a powerful firebender 🔥

Installation

The zuko package is available on PyPI, so it can be installed with pip.

pip install zuko

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/zuko

Getting started

Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Then, feeding a context y to the flow returns a conditional distribution p(x | y) which can be evaluated and sampled from.

import torch
import zuko

x = torch.randn(3)
y = torch.randn(5)

# Neural spline flow (NSF) with 3 sample features, 5 context features and 3 transformations
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)

# Evaluate log p(x | y)
log_p = flow(y).log_prob(x)

# Sample 64 points x ~ p(x | y)
x = flow(y).sample((64,))
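Since flow(y) is a conditional distribution and the flow is a Module, it can be fitted by minimizing the negative log-likelihood -log p(x | y) over a dataset. The sketch below illustrates that training loop under stated assumptions: it uses a hypothetical ConditionalGaussian (a diagonal Gaussian whose mean and scale are predicted from y by a small MLP) in place of a spline flow, and synthetic data, so it is not zuko's implementation. It only mirrors the model(y).log_prob(x) and model(y).sample(...) interface shown above.

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Independent, Normal


class ConditionalGaussian(nn.Module):
    """Hypothetical conditional model: forward(y) returns p(x | y)."""

    def __init__(self, features: int, context: int, hidden: int = 64):
        super().__init__()
        # The MLP predicts a mean and a log-scale for each feature of x.
        self.net = nn.Sequential(
            nn.Linear(context, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * features),
        )

    def forward(self, y: torch.Tensor) -> Distribution:
        loc, log_scale = self.net(y).chunk(2, dim=-1)
        return Independent(Normal(loc, log_scale.exp()), 1)


torch.manual_seed(0)
model = ConditionalGaussian(features=3, context=5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Synthetic data (assumption): x depends on the first 3 features of y, plus noise.
y = torch.randn(256, 5)
x = y[:, :3] + 0.1 * torch.randn(256, 3)

losses = []
for step in range(200):
    # Negative log-likelihood of x under p(x | y), averaged over the batch.
    loss = -model(y).log_prob(x).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# 64 draws from p(x | y[0]); the shape is (64, 1, 3).
samples = model(y[:1]).sample((64,))
```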

For more information about the available features, check out the documentation at francois-rozet.github.io/zuko.

Contributing

If you have a question, an issue or would like to contribute, please read our contributing guidelines.

Download files

Source Distribution

zuko-0.0.3.tar.gz (17.6 kB)

Uploaded Source

Built Distribution

zuko-0.0.3-py3-none-any.whl (18.5 kB)

Uploaded Python 3

File details

Details for the file zuko-0.0.3.tar.gz.

File metadata

  • Download URL: zuko-0.0.3.tar.gz
  • Size: 17.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for zuko-0.0.3.tar.gz
  • SHA256: 9c11edcc18ad07be21aaceb32e6e3f7d63f6506e00ab45288d610b863e208bd4
  • MD5: ef7752572db1685af5b8c45b3de5f301
  • BLAKE2b-256: aef50a780a29c7c389b5cca4dfd46abbe912a12ffb6443e44606e893657310f2

File details

Details for the file zuko-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: zuko-0.0.3-py3-none-any.whl
  • Size: 18.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for zuko-0.0.3-py3-none-any.whl
  • SHA256: afa358ec84d6f1026587ef38fd9a795fd935ae6c55a143c1db1b2e59aef40781
  • MD5: b5c4aaa7927d476163da8eb732083c40
  • BLAKE2b-256: ec89138ef6b16ab451354ff497adc33a9aa0b6a8f23df54d7ac995985d16e040
