CausalPFN: Amortized Causal Effect Estimation via In-Context Learning
An easy-to-use library for causal effect estimation using transformer-based in-context learning
🛠️ Installation • 🚀 Quick Start • 📊 Examples • 🔬 Reproducibility
🌟 Overview
CausalPFN uses a transformer to perform amortized causal effect estimation. It estimates effects through in-context learning, so it can be applied to diverse causal scenarios quickly and without retraining, combining the flexibility of in-context learning with the rigor of causal inference.
✨ Key Features
- 🚀 Fast Inference: Amortized learning enables rapid causal effect estimation without retraining
- 🧮 Uncertainty Quantification: Built-in calibration and confidence estimation
- ⚡ GPU Accelerated: Optimized for modern hardware with CUDA support
- 📈 Benchmarked: Competitive performance against state-of-the-art causal inference methods
- 📊 Uplift Modeling: Supports treatment effect estimation for personalized decision-making in real-world applications
Installation
Via PyPI
pip install causalpfn
Requirements
- Python 3.10+
- PyTorch 2.3+
- NumPy
- scikit-learn
- tqdm
- faiss-cpu
- huggingface_hub
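To confirm that an environment satisfies the list above, a small standard-library check can help. The sketch below is not part of CausalPFN: `check_environment` and `REQUIRED` are hypothetical names, and the distribution names are assumed to match the requirements as listed (with PyTorch published as `torch`). It reports installed versions but does not compare them against the minimums.

```python
import sys
from importlib import metadata

# Distribution names assumed from the Requirements list (PyTorch is published as "torch").
REQUIRED = ["torch", "numpy", "scikit-learn", "tqdm", "faiss-cpu", "huggingface_hub"]


def check_environment(required=REQUIRED, min_python=(3, 10)):
    """Return (python_ok, versions); versions maps each name to a version string or None."""
    python_ok = sys.version_info[:2] >= min_python
    versions = {}
    for name in required:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in the current environment.
            versions[name] = None
    return python_ok, versions


python_ok, versions = check_environment()
print(f"Python >= 3.10: {python_ok}")
for name, version in versions.items():
    print(f"{name}: {version if version is not None else 'MISSING'}")
```

Any entry reported as MISSING would need to be installed before running the Quick Start.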
Quick Start
Here's a complete example demonstrating CausalPFN for causal effect estimation:
import numpy as np
import torch
import time
from causalpfn import CATEEstimator, ATEEstimator
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 1. Generate synthetic data
np.random.seed(42)
n, d = 20000, 5
X = np.random.normal(1, 1, size=(n, d)).astype(np.float32)
# Define true causal effects
def true_cate(x):
    return np.sin(x[:, 0]) + 0.5 * x[:, 1]
def true_ate():
    return np.mean(true_cate(X))
# Generate treatment and outcomes
tau = true_cate(X).astype(np.float32)
T = np.random.binomial(1, p=0.5, size=n).astype(np.float32)
Y0 = X[:, 0] - X[:, 1] + np.random.normal(0, 0.1, size=n).astype(np.float32)
Y1 = Y0 + tau
Y = Y0 * (1 - T) + Y1 * T
# 2. Train/test split
train_idx = np.random.choice(n, size=int(0.7 * n), replace=False)
test_idx = np.setdiff1d(np.arange(n), train_idx)
X_train, X_test = X[train_idx], X[test_idx]
T_train, Y_train = T[train_idx], Y[train_idx]
tau_test = tau[test_idx]
# 3. CATE Estimation
start_time = time.time()
causalpfn_cate = CATEEstimator(
    device=device,
    verbose=True,
)
causalpfn_cate.fit(X_train, T_train, Y_train)
cate_hat = causalpfn_cate.estimate_cate(X_test)
cate_time = time.time() - start_time
# 4. ATE Estimation
causalpfn_ate = ATEEstimator(
    device=device,
    verbose=True,
)
causalpfn_ate.fit(X, T, Y)
ate_hat = causalpfn_ate.estimate_ate()
# 5. Evaluation
pehe = np.sqrt(np.mean((cate_hat - tau_test) ** 2))
ate_rel_error = np.abs((ate_hat - true_ate()) / true_ate())
print("Results:")
print(f"ATE Relative Error: {ate_rel_error:.4f}")
print(f"PEHE: {pehe:.4f}")
print(f"CATE estimation time per 1000 samples: {cate_time / (len(X) / 1000):.4f}s")
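The two evaluation metrics in step 5 can be factored into standalone helpers. This is a minimal sketch using only NumPy: `pehe` and `ate_relative_error` are hypothetical helper names, not CausalPFN functions, and the toy arrays are made up purely to show the arithmetic.

```python
import numpy as np


def pehe(cate_hat, cate_true):
    """Root mean squared error between estimated and true per-unit effects."""
    cate_hat = np.asarray(cate_hat, dtype=float)
    cate_true = np.asarray(cate_true, dtype=float)
    return float(np.sqrt(np.mean((cate_hat - cate_true) ** 2)))


def ate_relative_error(ate_hat, ate_true):
    """Absolute error of the ATE estimate, relative to the true ATE (must be nonzero)."""
    return float(abs((ate_hat - ate_true) / ate_true))


# Toy values, for illustration only.
toy_true = np.array([1.0, 2.0, 3.0])
toy_hat = np.array([1.0, 2.0, 5.0])

toy_pehe = pehe(toy_hat, toy_true)  # sqrt((0 + 0 + 4) / 3)
toy_rel = ate_relative_error(toy_hat.mean(), toy_true.mean())  # |8/3 - 2| / 2
print(f"PEHE: {toy_pehe:.4f}, ATE relative error: {toy_rel:.4f}")
```

Note that the relative error is undefined when the true ATE is zero, so an absolute error is the safer choice for data where the average effect may vanish.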
Examples
Explore our notebook collection below. Before running the notebooks, install the additional dependencies from the repository root with pip install ".[dev]" (the quotes keep shells such as zsh from expanding the brackets).
| Notebook | Description | Features |
|---|---|---|
| Causal Effect Estimation | Compare CausalPFN with baseline methods | CATE/ATE estimation, benchmarking |
| Hillstrom Marketing | Uplift modeling case study | Real-world marketing application |
| Calibration Analysis | Uncertainty quantification demo | Confidence intervals, calibration |
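For the uplift modeling use case, a common way to act on CATE estimates is to treat only the units with the largest estimated effect. The sketch below illustrates that idea with plain NumPy and is not taken from the notebooks: `select_top_fraction` is a hypothetical helper, and the array stands in for the output of a call such as `estimate_cate` in the Quick Start.

```python
import numpy as np


def select_top_fraction(cate_hat, fraction):
    """Return a boolean mask marking the `fraction` of units with the highest estimated effect."""
    cate_hat = np.asarray(cate_hat, dtype=float)
    n_treat = int(np.floor(fraction * len(cate_hat)))
    # Sort in descending order of estimated effect; a stable sort keeps ties in input order.
    order = np.argsort(-cate_hat, kind="stable")
    mask = np.zeros(len(cate_hat), dtype=bool)
    mask[order[:n_treat]] = True
    return mask


# Made-up estimates, for illustration only.
example_cate = np.array([0.1, 0.9, -0.2, 0.5])
treat_mask = select_top_fraction(example_cate, fraction=0.5)
print(treat_mask)
```

In practice the fraction would be set by a budget or by a cost threshold on the estimated effect.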
Performance Benchmark
Time vs. Performance. Comparison across 130 causal inference tasks from IHDP, ACIC, and Lalonde. CausalPFN achieves the best average rank, measured by precision in estimation of heterogeneous effects (PEHE), while being much faster than the other baselines.
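An average rank of this kind is typically computed by ranking the methods within each task and then averaging each method's rank over all tasks. The sketch below shows one simple way to do that. It is an assumption about the general procedure, not the paper's exact protocol, and the method names and error values are placeholders rather than benchmark results.

```python
def average_ranks(errors_by_task):
    """Average per-task rank of each method; lower error gets a better (smaller) rank.

    errors_by_task: list of dicts mapping method name -> error on that task.
    Assumes every method appears in every task; ties are broken by dict order.
    """
    totals = {}
    for task_errors in errors_by_task:
        ordered = sorted(task_errors, key=task_errors.get)
        for rank, method in enumerate(ordered, start=1):
            totals[method] = totals.get(method, 0) + rank
    n_tasks = len(errors_by_task)
    return {method: total / n_tasks for method, total in totals.items()}


# Placeholder errors for two hypothetical tasks (not real benchmark numbers).
toy_tasks = [
    {"method_a": 0.2, "method_b": 0.5, "method_c": 0.9},
    {"method_a": 0.4, "method_b": 0.3, "method_c": 0.8},
]
toy_ranks = average_ranks(toy_tasks)
print(toy_ranks)
```

Averaging ranks rather than raw errors keeps tasks with large error scales from dominating the comparison.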
Reproducibility
To fully reproduce the paper results, see the REPRODUCE file.
Citation
If you use CausalPFN in your research, please cite our paper:
@misc{causalpfn2025,
  title={CausalPFN: Amortized Causal Effect Estimation via In-Context Learning},
  author={Vahid Balazadeh and Hamidreza Kamkari and Valentin Thomas and Benson Li and Junwei Ma and Jesse C. Cresswell and Rahul G. Krishnan},
  year={2025},
  primaryClass={cs.LG},
}
Contributing
We welcome contributions! Please feel free to submit a Pull Request.
- Fork the repository
- Create a feature branch (git checkout -b feature/amazing-feature)
- Commit your changes (git commit -m 'Add amazing feature')
- Push to the branch (git push origin feature/amazing-feature)
- Open a Pull Request
License
This project is licensed under the Apache-2.0 License - see the LICENSE file for details.
⭐ Star us on GitHub • 🐛 Report Bug • 💡 Request Feature
Made with ❤️ for better causal inference