
Normalizing flows in PyTorch


Zuko - Normalizing flows in PyTorch

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot send their internal tensors to the GPU with .to('cuda') or retrieve their parameters with .parameters().
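The limitation can be observed with plain PyTorch: a `torch.distributions.Normal` stores its tensors as ordinary attributes, so it is not a Module and its tensors are invisible to `.parameters()` and `.to()`. The sketch below works around this by storing the parameters in a small `nn.Module` and rebuilding the distribution in `forward()`. The class name `LearnableNormal` is hypothetical and is not part of zuko.

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, ExpTransform, Independent, Normal

# Plain torch distributions and transforms are not Modules.
dist = Normal(torch.zeros(3), torch.ones(3))
print(isinstance(dist, nn.Module))           # False
print(isinstance(ExpTransform(), nn.Module)) # False


class LearnableNormal(nn.Module):
    """Hypothetical Module that owns the parameters and builds the
    Distribution on the fly in forward()."""

    def __init__(self, features: int):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(features))
        self.log_scale = nn.Parameter(torch.zeros(features))

    def forward(self) -> Distribution:
        # Rebuilt at every call, so it always uses the current parameters
        # on whatever device the module lives on.
        return Independent(Normal(self.loc, self.log_scale.exp()), 1)


module = LearnableNormal(3)
print(len(list(module.parameters())))          # 2
module = module.to("cpu")                      # .to('cuda') works the same way on a GPU
print(module().log_prob(torch.zeros(3)).shape) # torch.Size([])
```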

To solve this problem, zuko defines two abstract classes: DistributionModule and TransformModule. The former is any Module whose forward pass returns a Distribution, and the latter is any Module whose forward pass returns a Transform. A normalizing flow is then the composition of a list of TransformModule instances and a base DistributionModule. This design yields flows that behave like distributions while keeping the benefits of a Module, such as .to() and .parameters(). It also makes the implementations easy to understand and extend.
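The design can be sketched with plain PyTorch components. In the minimal example below, `StandardNormal` plays the role of a DistributionModule, `ConditionalAffine` the role of a TransformModule, and `Flow` composes them with `torch.distributions.TransformedDistribution`. All three classes are hypothetical stand-ins written for illustration; they are not zuko's implementations, and zuko's flows use more expressive transformations than a single conditional affine map.

```python
import torch
import torch.nn as nn
from torch.distributions import (
    AffineTransform,
    Distribution,
    Independent,
    Normal,
    Transform,
    TransformedDistribution,
)


class StandardNormal(nn.Module):
    """Hypothetical DistributionModule: forward() returns a Distribution."""

    def __init__(self, features: int):
        super().__init__()
        # Buffers (not parameters), so that .to(device) moves them too.
        self.register_buffer("loc", torch.zeros(features))
        self.register_buffer("scale", torch.ones(features))

    def forward(self, y: torch.Tensor) -> Distribution:
        # The base distribution ignores the context y.
        return Independent(Normal(self.loc, self.scale), 1)


class ConditionalAffine(nn.Module):
    """Hypothetical TransformModule: forward() returns a Transform whose
    shift and scale are predicted from the context y."""

    def __init__(self, features: int, context: int):
        super().__init__()
        self.hyper = nn.Linear(context, 2 * features)

    def forward(self, y: torch.Tensor) -> Transform:
        shift, log_scale = self.hyper(y).chunk(2, dim=-1)
        # event_dim=1 sums the log-Jacobian over the feature dimension.
        return AffineTransform(shift, log_scale.exp(), event_dim=1)


class Flow(nn.Module):
    """Composition of TransformModules on top of a base DistributionModule."""

    def __init__(self, transforms: list, base: nn.Module):
        super().__init__()
        self.transforms = nn.ModuleList(transforms)
        self.base = base

    def forward(self, y: torch.Tensor) -> Distribution:
        # Base samples are pushed through the transforms in list order.
        return TransformedDistribution(self.base(y), [t(y) for t in self.transforms])


flow = Flow([ConditionalAffine(3, 5) for _ in range(3)], StandardNormal(3))
x, y = torch.randn(3), torch.randn(5)

log_p = flow(y).log_prob(x)          # scalar log p(x | y)
samples = flow(y).sample((64,))      # shape (64, 3)
print(len(list(flow.parameters())))  # 6: a weight and a bias per nn.Linear
```

Because `Flow` is itself a `Module`, `flow.to('cuda')` moves the linear layers and the base buffers together, and `flow.parameters()` can be handed directly to an optimizer.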

In the Avatar cartoon, Zuko is a powerful firebender 🔥

Installation

The zuko package is available on PyPI, which means it is installable via pip.

```
pip install zuko
```

Alternatively, if you need the latest features, you can install it from the repository.

```
pip install git+https://github.com/francois-rozet/zuko
```

Getting started

Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Then, feeding a context y to the flow returns a conditional distribution p(x | y), which can be evaluated and sampled from.

```python
import torch
import zuko

x = torch.randn(3)
y = torch.randn(5)

# Neural spline flow (NSF) with 3 transformations
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)

# Evaluate log p(x | y)
log_p = flow(y).log_prob(x)

# Sample 64 points x ~ p(x | y)
x = flow(y).sample((64,))
```
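Evaluating `log_prob` on a flow relies on the change-of-variables formula: if x = f(z) with z drawn from a base distribution p_Z, then log p(x) = log p_Z(f^{-1}(x)) - log |det J_f(f^{-1}(x))|. The sketch below checks this by hand against `torch.distributions.TransformedDistribution` for a single affine map with arbitrarily chosen shift and scale. It uses plain PyTorch rather than zuko, so it only illustrates the principle behind the flows above.

```python
import torch
from torch.distributions import AffineTransform, Independent, Normal, TransformedDistribution

# Base distribution p_Z: standard normal over 3 features.
base = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)

# An invertible map f(z) = shift + scale * z (values chosen arbitrarily).
shift = torch.tensor([1.0, -2.0, 0.5])
scale = torch.tensor([2.0, 0.5, 3.0])
f = AffineTransform(shift, scale, event_dim=1)

dist = TransformedDistribution(base, [f])
x = torch.randn(3)

# Change of variables by hand: invert f, evaluate the base density,
# then subtract log|det J_f|, which for an affine map is sum(log|scale|).
z = (x - shift) / scale
manual = base.log_prob(z) - scale.abs().log().sum()

print(torch.allclose(dist.log_prob(x), manual))  # True
```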

For more information about the available features, check out the documentation at francois-rozet.github.io/zuko.

Contributing

If you have a question or an issue, or would like to contribute, please read our contributing guidelines.

Download files

Source Distribution

zuko-0.0.4.tar.gz (17.6 kB)

Uploaded Source

Built Distribution


zuko-0.0.4-py3-none-any.whl (18.5 kB)

Uploaded Python 3

File details

Details for the file zuko-0.0.4.tar.gz.

File metadata

  • Download URL: zuko-0.0.4.tar.gz
  • Size: 17.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for zuko-0.0.4.tar.gz
Algorithm Hash digest
SHA256 a4876e865393f46d29a84c9bd549e35215af5232c2f8b03c0ddf1ae0e4c8e3ea
MD5 03cb8e580c9f457b69e3b758a6399c9f
BLAKE2b-256 c1325e503163ad34d2dfbac28825a03ebe41a7ab03fc075c5167d257737b200f
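The digests above can be used to check a downloaded archive before installing it. A minimal sketch with Python's standard `hashlib` module follows; the helper name `sha256_of` and the local file path are hypothetical.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Published SHA256 digest of zuko-0.0.4.tar.gz (from the table above).
EXPECTED = "a4876e865393f46d29a84c9bd549e35215af5232c2f8b03c0ddf1ae0e4c8e3ea"

archive = Path("zuko-0.0.4.tar.gz")  # hypothetical local download location
if archive.exists():
    print("OK" if sha256_of(archive) == EXPECTED else "MISMATCH")
```

pip can also enforce digests itself through its hash-checking mode (`--require-hashes`).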


File details

Details for the file zuko-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: zuko-0.0.4-py3-none-any.whl
  • Size: 18.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for zuko-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 7a341ef302d7d7ffd30ef58c11339dee35d98b4c8f2b77abc09fef695673dc1b
MD5 05c7689c7a25cec9d4a89f73c1bc095a
BLAKE2b-256 15f273b55cc9d46f71b8a69450d03ab4ae29b5e5f78ef0d53f5ed5e7d2ad60c6

