
Causal Normalizing Flows

[!warning] This is a work in progress. Expect rough edges and possibly bugs (though none are currently known).

CausalFlows is a Python package that implements Causal Normalizing Flows in PyTorch. At present, it is essentially a wrapper around the Zuko library with a number of quality-of-life changes that improve its usability.

Citation

To cite this library, please cite the original manuscript that preceded it:

@article{javaloy2024causal,
    title={Causal normalizing flows: from theory to practice},
    author={Javaloy, Adri{\'a}n and S{\'a}nchez-Mart{\'\i}n, Pablo and Valera, Isabel},
    journal={Advances in {Neural} {Information} {Processing} {Systems}},
    volume={36},
    year={2024}
}

Installation

The package is not yet publicly available, so you need to install it locally from the source folder of this repository using

pip install -e .

Alternatively, you can install it directly from the repository:

pip install git+https://github.com/adrianjav/causal-flows

Getting started

Normalizing flows are provided in the flows module. To build one, supply the number of sample and context features as well as the transformations' hyperparameters. Calling the flow with a context $c$ then returns a conditional distribution $p(x | c)$ that can be evaluated and sampled from.

import torch
import causalflows

# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = causalflows.flows.CausalNSF(3, 5, order=(0, 1, 2), hidden_features=[128] * 3)

# Train to maximize the log-likelihood
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

for x, c in trainset:
    loss = -flow(c).log_prob(x)  # -log p(x | c)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample 64 factual points x ~ p(x | c*)
x = flow(c_star).sample((64,))

# Intervene using the context manager (the context must always be provided)
with flow(c_star).intervene(index=1, value=2.5) as int_flow:
    x_int = int_flow.sample((64,))

# We could also sample with the helper method
x_int = flow(c_star).sample_interventional(index=1, value=2.5, sample_shape=(64,))

# And we can compute counterfactuals using the helper methods (or the context manager)
x_cf = flow(c_star).compute_counterfactual(x, index=1, value=2.5)
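To make the counterfactual helpers above less opaque, here is a minimal pure-Python sketch (deliberately not using the library) of the three steps a causal flow performs conceptually: abduction, action, and prediction. The linear structural causal model and the `forward`, `inverse`, and `counterfactual` helpers are hypothetical toys chosen for illustration only.

```python
# Toy SCM with a known causal ordering (x1 causes x2):
#   x1 = u1
#   x2 = 2 * x1 + u2

def forward(u1, u2):
    x1 = u1
    x2 = 2 * x1 + u2
    return x1, x2

def inverse(x1, x2):
    # Abduction: recover the exogenous noise from an observation.
    u1 = x1
    u2 = x2 - 2 * x1
    return u1, u2

def counterfactual(x1, x2, index, value):
    # 1. Abduction: infer the noise that produced the factual sample.
    u1, u2 = inverse(x1, x2)
    # 2. Action: clamp the intervened variable to `value`.
    # 3. Prediction: push the same noise through the modified model.
    if index == 0:
        cf_x1 = value
        cf_x2 = 2 * cf_x1 + u2
    else:
        cf_x1 = u1
        cf_x2 = value
    return cf_x1, cf_x2

# Factual sample from noise (u1, u2) = (1.0, 0.5): x = (1.0, 2.5)
x1, x2 = forward(1.0, 0.5)
# "What would x2 have been had x1 been 3.0?" -> 2 * 3.0 + 0.5
print(counterfactual(x1, x2, index=0, value=3.0))  # (3.0, 6.5)
```

A causal flow replaces these hand-written `forward`/`inverse` maps with a learned invertible transform, but the counterfactual recipe is the same.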

Alternatively, flows can be built as custom CausalFlow objects. As the snippet below shows, the library can easily be combined with custom flows from the Zuko library.

[!warning] Note that custom flows may not be causally consistent (i.e., they may exhibit spurious correlations) unless they are carefully designed (see the original paper for an explanation).

import torch

from causalflows.flows import CausalFlow
from zuko.flows import UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform

flow = CausalFlow(
    transform=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    base=UnconditionalDistribution(
        DiagNormal,
        torch.zeros(3),
        torch.ones(3),
        buffer=True,
    ),
)
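The causal-consistency warning above has a concrete meaning: a transform is compatible with a causal ordering only if its Jacobian is triangular in that ordering (each variable may depend only on its predecessors). As a rough intuition check, independent of the library, the sketch below verifies this numerically for a toy 2D map; `triangular_map` and `jacobian` are illustrative helpers, not part of the API.

```python
# A map consistent with the ordering (x1, x2): x1 depends only on itself,
# x2 depends on x1 and x2, so the Jacobian should be lower triangular.

def triangular_map(x):
    x1, x2 = x
    return [x1 ** 3, x1 + 2 * x2]

def jacobian(f, x, eps=1e-6):
    # Finite-difference Jacobian: J[i][j] = d f_i / d x_j.
    n = len(x)
    f0 = f(x)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        xp = list(x)
        xp[j] += eps
        fp = f(xp)
        for i in range(n):
            J[i][j] = (fp[i] - f0[i]) / eps
    return J

J = jacobian(triangular_map, [1.0, -0.5])
# The upper-triangular entry d f1 / d x2 must vanish for causal consistency.
print(abs(J[0][1]) < 1e-4)  # True
```

Inserting a dense transform such as a rotation between autoregressive layers (as in the snippet above) generally destroys this triangular structure, which is why such designs need the care the warning calls for.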

For more information, check out the tutorials or the documentation.

References

Causal normalizing flows: from theory to practice (Javaloy et al., 2024)

NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)

Variational Inference with Normalizing Flows (Rezende et al., 2015)

Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)

Neural Spline Flows (Durkan et al., 2019)

Neural Autoregressive Flows (Huang et al., 2018)
