Normalizing flows in PyTorch

Zuko - Normalizing flows in PyTorch

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot move their internal tensors to the GPU with .to('cuda') or retrieve their parameters with .parameters().
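
A quick check of this limitation using only PyTorch (an illustrative snippet, not part of zuko): a torch.distributions object is not an nn.Module, so the Module machinery for devices and parameters is simply absent.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.distributions.transforms import AffineTransform

# A distribution and a transform built from plain tensors
loc = torch.zeros(3, requires_grad=True)
dist = Normal(loc, torch.ones(3))
transform = AffineTransform(loc=torch.zeros(3), scale=torch.ones(3))

# Neither object is a Module, so .to(...) and .parameters() are unavailable
print(isinstance(dist, nn.Module))       # False
print(isinstance(transform, nn.Module))  # False
print(hasattr(dist, "parameters"))       # False
print(hasattr(dist, "to"))               # False
```

Any tensor held by such an object (like loc above) must therefore be tracked and moved by hand.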

To solve this problem, zuko defines two abstract classes: DistributionModule and TransformModule. The former is any Module whose forward pass returns a Distribution, and the latter is any Module whose forward pass returns a Transform. A normalizing flow is then simply a DistributionModule that contains a list of TransformModule objects and a base DistributionModule. This design lets flows behave like distributions while retaining the benefits of Module, and it makes the implementations easier to understand and extend.
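
The design can be sketched with plain PyTorch. The classes ToyDiagNormal, ToyAffine and ToyFlow below are hypothetical stand-ins, not zuko's DistributionModule, TransformModule or FlowModule. The sketch relies on torch's TransformedDistribution, whose transforms map base samples to data space; zuko's own direction convention may differ.

```python
import torch
import torch.nn as nn
from torch.distributions import (
    Distribution, Independent, Normal, Transform, TransformedDistribution,
)
from torch.distributions.transforms import AffineTransform


class ToyDiagNormal(nn.Module):
    """Hypothetical distribution module: forward returns a Distribution."""

    def __init__(self, features: int):
        super().__init__()
        # Buffers follow .to(...) but are not trainable parameters
        self.register_buffer("loc", torch.zeros(features))
        self.register_buffer("scale", torch.ones(features))

    def forward(self, c=None) -> Distribution:
        # The base ignores the context; the last dimension is the event
        return Independent(Normal(self.loc, self.scale), 1)


class ToyAffine(nn.Module):
    """Hypothetical transform module: forward returns a Transform."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # Predicts a shift and a log-scale per feature from the context
        self.net = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Transform:
        shift, log_scale = self.net(c).chunk(2, dim=-1)
        return AffineTransform(shift, log_scale.exp(), event_dim=1)


class ToyFlow(nn.Module):
    """Hypothetical flow: a base module plus a list of transform modules."""

    def __init__(self, transforms, base):
        super().__init__()
        self.transforms = nn.ModuleList(transforms)
        self.base = base

    def forward(self, c: torch.Tensor) -> Distribution:
        # Build the conditional distribution p(x | c) on every call
        return TransformedDistribution(self.base(c), [t(c) for t in self.transforms])


flow = ToyFlow([ToyAffine(3, 5), ToyAffine(3, 5)], ToyDiagNormal(3))
c = torch.randn(4, 5)
log_p = flow(c).log_prob(torch.randn(4, 3))  # one log-density per batch element
```

Because ToyFlow is a Module, flow.parameters() returns the weights of both Linear layers and flow.to(...) also moves the base buffers, which is exactly what the bare Distribution objects cannot offer.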

In the Avatar cartoon, Zuko is a powerful firebender 🔥

Installation

The zuko package is available on PyPI, which means it is installable via pip.

pip install zuko

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/zuko

Getting started

Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Then, feeding a context $c$ to the flow returns a conditional distribution $p(x | c)$ which can be evaluated and sampled from.

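import torch
import zuko

# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)

# Placeholder data for illustration: replace with your own iterable of
# (x, c) batches, e.g. a torch.utils.data.DataLoader
trainset = [(torch.randn(64, 3), torch.randn(64, 5)) for _ in range(16)]

# Train to maximize the log-likelihood
optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3)

for x, c in trainset:
    loss = -flow(c).log_prob(x)  # -log p(x | c)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample 64 points x ~ p(x | c*) for a fixed context c*
c_star = torch.randn(5)  # placeholder context
x = flow(c_star).sample((64,))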
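
The log-density returned by flow(c).log_prob(x) rests on the standard change-of-variables formula. In the notation below, introduced only for this note, $f_c$ is the composition of the flow's transformations given the context $c$, assumed here to map a sample $x$ to the base space with density $p_Z$, and $J_{f_c}$ is its Jacobian:

$$\log p(x \mid c) = \log p_Z\big(f_c(x)\big) + \log \left| \det J_{f_c}(x) \right|$$

The training loop then minimizes the batch estimate of the negative log-likelihood, where $N$ is the batch size (the .mean() call):

$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log p(x_i \mid c_i)$$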

Alternatively, flows can be built as custom FlowModule objects.

import torch
from zuko.flows import FlowModule, MaskedAutoregressiveTransform, Unconditional
from zuko.distributions import DiagNormal
from zuko.transforms import PermutationTransform

flow = FlowModule(
    transforms=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=[128] * 3),
        Unconditional(PermutationTransform, torch.randperm(3), buffer=True),
        MaskedAutoregressiveTransform(3, 5, hidden_features=[128] * 3),
    ],
    base=Unconditional(
        DiagNormal,
        torch.zeros(3),
        torch.ones(3),
        buffer=True,
    ),
)
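
Unconditional turns a constructor and its tensor arguments into a Module that ignores the context. A minimal sketch of that idea follows, assuming (from its usage above) that buffer=True stores the tensors as buffers and that they are otherwise trainable parameters. ToyUnconditional and diag_normal are hypothetical names, not zuko's implementation.

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal


class ToyUnconditional(nn.Module):
    """Hypothetical: wraps meta(*args) as a context-free module."""

    def __init__(self, meta, *args: torch.Tensor, buffer: bool = False):
        super().__init__()
        self.meta = meta
        self.n_args = len(args)
        for i, arg in enumerate(args):
            if buffer:
                # Fixed tensors: moved by .to(...), excluded from .parameters()
                self.register_buffer(f"_{i}", arg)
            else:
                # Trainable tensors: returned by .parameters()
                self.register_parameter(f"_{i}", nn.Parameter(arg))

    def forward(self, c=None):
        # The context is ignored; rebuild the object from the stored tensors
        args = [getattr(self, f"_{i}") for i in range(self.n_args)]
        return self.meta(*args)


def diag_normal(loc, scale):
    # Diagonal Gaussian treating the last dimension as the event
    return Independent(Normal(loc, scale), 1)


base = ToyUnconditional(diag_normal, torch.zeros(3), torch.ones(3), buffer=True)
trainable = ToyUnconditional(diag_normal, torch.zeros(3), torch.ones(3))
samples = base().sample((5,))  # 5 samples of 3 features
```

Buffers are the natural choice for integer tensors such as the torch.randperm(3) permutation above, since integer tensors cannot be trainable parameters.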

For more information, check out the documentation at zuko.readthedocs.io.

Available flows

Class  Year  Reference
MAF    2017  Masked Autoregressive Flow for Density Estimation
NSF    2019  Neural Spline Flows
NCSF   2020  Normalizing Flows on Tori and Spheres
SOSPF  2019  Sum-of-Squares Polynomial Flow
NAF    2018  Neural Autoregressive Flows
UNAF   2019  Unconstrained Monotonic Neural Networks
CNF    2018  Neural Ordinary Differential Equations

Contributing

If you have a question, an issue or would like to contribute, please read our contributing guidelines.

Download files

Source Distribution

zuko-0.2.2.tar.gz (28.9 kB)

Built Distribution

zuko-0.2.2-py3-none-any.whl (27.1 kB)

File details

Details for the file zuko-0.2.2.tar.gz.

File metadata

  • Download URL: zuko-0.2.2.tar.gz
  • Size: 28.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.4

File hashes

Hashes for zuko-0.2.2.tar.gz
Algorithm Hash digest
SHA256 cb224d8d4d2985ef8ee62dfff269e056bc1af885d3a0fad6ba9050344b428bfd
MD5 b5abed6f48a67718eaa90eae18d8fb12
BLAKE2b-256 6a5dc93442b488cf51b6f14a1599263007538c114cbcd53ed676b8bdda46138d
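
The SHA256 digests listed here can be checked locally after downloading a file. A minimal sketch using Python's standard hashlib module; the helper names sha256_digest and verify are hypothetical.

```python
import hashlib


def sha256_digest(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected):
    # hexdigest() is lowercase, so normalize the expected digest
    return sha256_digest(path) == expected.lower()


# Usage, assuming the sdist was downloaded into the current directory:
# verify("zuko-0.2.2.tar.gz",
#        "cb224d8d4d2985ef8ee62dfff269e056bc1af885d3a0fad6ba9050344b428bfd")
```

Reading in chunks keeps memory use constant regardless of the file size.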

File details

Details for the file zuko-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: zuko-0.2.2-py3-none-any.whl
  • Size: 27.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.4

File hashes

Hashes for zuko-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 e251420a32c31f5df106aca99c9681f34b2c4f98657428e46a86e88d02aec55c
MD5 c8f38c4cc568a9b854e7dc702bc62699
BLAKE2b-256 6fed5878bec22497209f463612acd7e504ce6edafea7f05232388c935cfca16c
