Causal Normalizing flows in PyTorch
Causal Normalizing Flows
[!warning] This is a work in progress. Expect bugs (although we do not know of any yet) and rough edges.
CausalFlows is a Python package that implements Causal Normalizing Flows in PyTorch. For now, it is essentially a wrapper around the Zuko library with a number of quality-of-life changes that improve its usability.
Citation
To cite this library, please cite the original manuscript that preceded it:
```bibtex
@article{javaloy2024causal,
  title={Causal normalizing flows: from theory to practice},
  author={Javaloy, Adri{\'a}n and S{\'a}nchez-Mart{\'\i}n, Pablo and Valera, Isabel},
  journal={Advances in {Neural} {Information} {Processing} {Systems}},
  volume={36},
  year={2024}
}
```
Installation
The package is not yet publicly available, so you need to install it locally from the source folder of this repository:

```bash
pip install -e .
```

Alternatively, you can install it directly from the repository:

```bash
pip install git+https://github.com/adrianjav/causal-flows
```
Getting started
Normalizing flows are provided in the `flows` module. To build one, supply the number of sample and context features, as well as the hyperparameters of the transformations. Feeding a context $c$ to the flow then returns a conditional distribution $p(x \mid c)$ that can be evaluated and sampled from.
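For reference, the quantity maximized in the training loop is the standard change-of-variables log-likelihood of a normalizing flow. The sketch below assumes, as in Zuko, that the flow's transform $T$ maps a sample $x$ to base noise $u = T(x; c)$ with base density $p_U$, and writes $d$ for the number of sample features (3 in the example):

```math
\begin{aligned}
\log p(x \mid c) &= \log p_U\big(T(x; c)\big) + \log \left| \det \frac{\partial T(x; c)}{\partial x} \right|, \\
\det \frac{\partial T(x; c)}{\partial x} &= \prod_{i=1}^{d} \frac{\partial T_i(x; c)}{\partial x_i} \quad \text{(when } T \text{ is triangular in the variable order)}.
\end{aligned}
```

When $T$ is autoregressive, its Jacobian is triangular, so the determinant reduces to the product of the diagonal entries and the likelihood stays cheap to evaluate.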
```python
import torch
import causalflows

# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = causalflows.flows.CausalNSF(3, 5, order=(0, 1, 2), hidden_features=[128] * 3)

# Hypothetical placeholder data for illustration: replace with your own (x, c) batches
trainset = [(torch.randn(256, 3), torch.randn(256, 5)) for _ in range(100)]
c_star = torch.randn(5)

# Train to maximize the log-likelihood
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

for x, c in trainset:
    loss = -flow(c).log_prob(x)  # -log p(x | c)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample 64 factual points x ~ p(x | c*)
x = flow(c_star).sample((64,))

# Intervene using the context manager (the context must always be given)
with flow(c_star).intervene(index=1, value=2.5) as int_flow:
    x_int = int_flow.sample((64,))

# We can also sample with the helper method
x_int = flow(c_star).sample_interventional(index=1, value=2.5, sample_shape=(64,))

# And we can compute counterfactuals using the helper method (or the context manager)
x_cf = flow(c_star).compute_counterfactual(x, index=1, value=2.5)
```
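Conceptually, interventions and counterfactuals follow the usual structural-causal-model recipe: sample (or, for counterfactuals, recover) the exogenous noise, clamp the intervened variable, and recompute its descendants in causal order. The minimal sketch below illustrates this in plain PyTorch on a hand-written, unconditional linear SCM. It does not use causalflows, and the names `A`, `scm_forward`, and `scm_inverse` are hypothetical stand-ins for a trained flow's generative and inverse maps, not the library's internals.

```python
import torch

# Hypothetical linear SCM with causal order x0 -> x1 -> x2.
# A[i, j] is the effect of x_j on x_i (strictly lower triangular).
A = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.8, 0.0, 0.0],
    [0.0, -0.5, 0.0],
])


def scm_forward(u, index=None, value=None):
    """Generate x from exogenous noise u, visiting variables in causal order.

    If `index` is given, x[:, index] is clamped to `value` (a do-intervention),
    and its descendants are computed from the clamped value.
    """
    x = torch.zeros_like(u)
    for i in range(u.shape[-1]):
        x[..., i] = x @ A[i] + u[..., i]
        if i == index:
            x[..., i] = value
    return x


def scm_inverse(x):
    """Abduction: recover the exogenous noise that generated x."""
    return x - x @ A.T


torch.manual_seed(0)

# Factual samples
u = torch.randn(64, 3)
x = scm_forward(u)

# Interventional samples under do(x1 = 2.5): fresh noise, clamped x1
x_int = scm_forward(torch.randn(64, 3), index=1, value=2.5)

# Counterfactuals for the factual x: abduction, action, prediction
u_hat = scm_inverse(x)
x_cf = scm_forward(u_hat, index=1, value=2.5)
```

The counterfactual leaves `x[:, 0]` (a non-descendant of the intervened variable) unchanged and reuses the recovered noise for `x[:, 2]`, whereas the interventional samples draw fresh noise for every variable.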
Alternatively, flows can be built as custom `CausalFlow` objects. As the snippet below shows, the library can easily be combined with custom flows from the Zuko library.
[!warning] Custom flows may not be causally consistent (i.e., they may introduce spurious correlations) unless they are carefully designed; see the original paper for an explanation.
```python
import torch
from causalflows.flows import CausalFlow
from zuko.flows import UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform

flow = CausalFlow(
    transform=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
        UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
        MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
    ],
    base=UnconditionalDistribution(
        DiagNormal,
        torch.zeros(3),
        torch.ones(3),
        buffer=True,
    ),
)
```
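The warning about causal consistency is relevant to this example. `MaskedAutoregressiveTransform` is autoregressive, so its Jacobian is triangular with respect to its variable order, whereas Zuko's `RotationTransform` applies a dense orthogonal matrix that mixes all features. Composing them therefore generally yields a dense Jacobian that no longer encodes a single causal ordering. The minimal sketch below checks this numerically in plain PyTorch; it does not use causalflows or Zuko, and `autoregressive_map`, `Q`, and `is_lower_triangular` are hypothetical stand-ins for the layers above.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)


def autoregressive_map(x):
    """Hypothetical triangular map: output i depends only on x[0..i]."""
    return torch.stack([
        x[0],
        x[1] + torch.tanh(x[0]),
        x[2] + torch.sin(x[0] * x[1]),
    ])


# Random orthogonal matrix standing in for a dense rotation layer
Q, _ = torch.linalg.qr(torch.randn(3, 3))


def rotated_map(x):
    """Autoregressive layer, dense rotation, autoregressive layer."""
    return autoregressive_map(Q @ autoregressive_map(x))


def is_lower_triangular(J, tol=1e-6):
    """True if every entry above the diagonal is (numerically) zero."""
    return bool(torch.all(torch.triu(J, diagonal=1).abs() < tol))


x = torch.randn(3)
J_ar = jacobian(autoregressive_map, x)
J_rot = jacobian(rotated_map, x)

print(is_lower_triangular(J_ar))   # True
print(is_lower_triangular(J_rot))  # False: the rotation mixes all features
```

A triangular Jacobian is what ties each output to the variables preceding it in the order; once it becomes dense, that correspondence with a causal graph is lost.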
For more information, check out the tutorials or the documentation.
References
- Causal normalizing flows: from theory to practice (Javaloy et al., 2024)
- NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)
- Variational Inference with Normalizing Flows (Rezende et al., 2015)
- Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)
Download files
File details
Details for the file causalflows-0.1.0.tar.gz.
File metadata
- Download URL: causalflows-0.1.0.tar.gz
- Size: 14.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | d286be1b2a9e31058ca741b37d35a82531f0aa5de31c3b4a9c35191cea84a409 |
| MD5 | 6bb94dcc544a22fcbf7ebcaa4447f9d1 |
| BLAKE2b-256 | f5806dc00654feebea2bc196207f6c939568a9e10654286f378db021dc03211a |
File details
Details for the file causalflows-0.1.0-py3-none-any.whl.
File metadata
- Download URL: causalflows-0.1.0-py3-none-any.whl
- Size: 12.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 74dc0056e23ebe5af64c825f884725170fa52fb1fa718af0faf3d3037e42886a |
| MD5 | cf999858f10e58cf338a492c8b769617 |
| BLAKE2b-256 | e04f53e1239e7c8d02b9b9be194643c1310eac0c4a107ddca21572384ccdf86b |