
Deep Koopman operators for causal discovery


Kausal: Deep Koopman Operators for Causal Discovery

Kausal is a PyTorch package for causal inference in nonlinear, high-dimensional dynamical systems using a deep Koopman operator-theoretic formalism.


Features

Additional features include:

Abstract

[Figure: Overview of Kausal]

Causal discovery aims to identify cause-effect mechanisms for better scientific understanding, explainable decision-making, and more accurate modeling. Standard statistical frameworks, such as Granger causality, lack the ability to quantify causal relationships in nonlinear dynamics due to the presence of complex feedback mechanisms, timescale mixing, and nonstationarity. Thus, applying these methods to study causal dynamics in real-world systems, such as the Earth, is a major challenge. Addressing this shortcoming, we leverage deep learning and a Koopman operator-theoretic formalism to present a new class of causal discovery algorithms. Kausal uses deep Koopman operator methods to approximate nonlinear dynamics in a linearized vector space in which traditional causal inference methods, such as Granger causality, can be more easily applied. Our idealized experiments demonstrate Kausal's superior ability to discover and characterize causal signals compared with existing state-of-the-art deep learning and non-deep learning approaches. Finally, the successful identification of major El Niño and La Niña events in observations showcases Kausal's ability to handle real-world applications.
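The core idea (lift the dynamics into a space where a linear model suffices, then run a Granger-style comparison there) can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions, not Kausal's implementation: the observables are a fixed, hypothetical dictionary (state, squared state, constant) rather than learned deep features, and the helper names `lift`, `one_step_mse`, `granger_in_lifted_space`, and `simulate_driven_system` are illustrative only.

```python
import numpy as np


def lift(x: np.ndarray) -> np.ndarray:
    """Hypothetical fixed dictionary of observables: the state, its square, and a constant.

    x has shape (n_channels, n_timesteps); the output has shape (2 * n_channels + 1, n_timesteps).
    """
    return np.vstack([x, x**2, np.ones((1, x.shape[1]))])


def one_step_mse(features: np.ndarray, target: np.ndarray, shift: int = 1) -> float:
    """Least-squares linear map from features at t to target at t + shift; return its in-sample MSE."""
    past, future = features[:, :-shift], target[:, shift:]
    coef, *_ = np.linalg.lstsq(past.T, future.T, rcond=None)
    return float(np.mean((future - coef.T @ past) ** 2))


def granger_in_lifted_space(cause: np.ndarray, effect: np.ndarray, shift: int = 1) -> float:
    """Marginal minus joint prediction error for the effect (positive means the cause helps)."""
    marginal = one_step_mse(lift(effect), effect, shift)
    joint = one_step_mse(lift(np.vstack([effect, cause])), effect, shift)
    return marginal - joint


def simulate_driven_system(n: int = 2000, seed: int = 0):
    """Toy system where the cause drives the effect through a quadratic (nonlinear) term."""
    rng = np.random.default_rng(seed)
    cause = rng.standard_normal((1, n))
    effect = np.zeros((1, n))
    for t in range(1, n):
        effect[0, t] = 0.5 * effect[0, t - 1] + 0.8 * cause[0, t - 1] ** 2 + 0.1 * rng.standard_normal()
    return cause, effect


cause, effect = simulate_driven_system()
print("cause -> effect:", granger_in_lifted_space(cause, effect))
print("effect -> cause:", granger_in_lifted_space(effect, cause))
```

Because the quadratic term is linear in the lifted coordinates, the joint model captures the drive almost exactly, while the reverse direction gains essentially nothing.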

Installation

Kausal is available on PyPI, so installation is as easy as:

pip install kausal

If you use conda, please use the following commands:

conda create --name venv python=3.10
conda activate venv
pip install kausal

Quickstart Guide

Please refer to the tutorial notebooks in the tutorial/ folder for full demonstrations.

Causal estimation

The most basic functionality is causal estimation, which is useful for, e.g., event detection and measuring the relative causal strength between variables.

import torch
from kausal.koopman import Kausal

# Define cause-effect variables to be tested.
x_cause = torch.randn(3, 1000) # (n_channels, n_timesteps)
x_effect = torch.randn(3, 1000) # (n_channels, n_timesteps)

# Initialize the Kausal object
causal_koopman = Kausal(cause = x_cause, effect = x_effect)

# Evaluate (with e.g., time_shift = 1)
causal_effect, p_values = causal_koopman.evaluate(
    time_shift=1,
    bootstrap_ratio=0.9,  # Subtrajectory length (fraction of the series) for uncertainty quantification
    bootstrap_nums=100    # Number of bootstrap resamples for uncertainty quantification
)
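One simple way to attach an uncertainty and a p-value to a causal score, in the spirit of the bootstrap_ratio and bootstrap_nums arguments, is to re-score random contiguous sub-trajectories. The NumPy sketch below is our own illustration, not Kausal's internals: the observables, the held-out error comparison, and the p-value rule (fraction of bootstrap scores at or below zero) are assumptions, and the helper names are hypothetical. Heavily overlapping windows make this p-value optimistic.

```python
import numpy as np


def _lift(x):
    # Hypothetical observables: state, squared state, and a constant row.
    return np.vstack([x, x**2, np.ones((1, x.shape[1]))])


def _heldout_mse(features, target, shift, train_frac=0.7):
    # Fit on the first part of the window and score on the rest, so useless features can hurt.
    past, future = features[:, :-shift], target[:, shift:]
    n_fit = int(train_frac * past.shape[1])
    coef, *_ = np.linalg.lstsq(past[:, :n_fit].T, future[:, :n_fit].T, rcond=None)
    return float(np.mean((future[:, n_fit:] - coef.T @ past[:, n_fit:]) ** 2))


def causal_score(cause, effect, shift=1):
    # Marginal (effect only) minus joint (effect + cause) held-out error.
    marginal = _heldout_mse(_lift(effect), effect, shift)
    joint = _heldout_mse(_lift(np.vstack([effect, cause])), effect, shift)
    return marginal - joint


def bootstrap_causal_effect(cause, effect, shift=1, bootstrap_ratio=0.9, bootstrap_nums=100, seed=0):
    """Score random contiguous sub-trajectories and return (mean, std, p_value)."""
    rng = np.random.default_rng(seed)
    n = cause.shape[1]
    length = int(bootstrap_ratio * n)
    starts = rng.integers(0, n - length + 1, size=bootstrap_nums)
    scores = np.array(
        [causal_score(cause[:, s:s + length], effect[:, s:s + length], shift) for s in starts]
    )
    # One-sided heuristic: how often the cause fails to improve the held-out prediction.
    return float(scores.mean()), float(scores.std()), float(np.mean(scores <= 0.0))


# Toy data: the cause drives the effect through a quadratic term.
rng = np.random.default_rng(0)
n = 2000
x_cause = rng.standard_normal((1, n))
x_effect = np.zeros((1, n))
for t in range(1, n):
    x_effect[0, t] = 0.5 * x_effect[0, t - 1] + 0.8 * x_cause[0, t - 1] ** 2 + 0.1 * rng.standard_normal()

effect_mean, effect_std, p_value = bootstrap_causal_effect(x_cause, x_effect)
print(effect_mean, effect_std, p_value)
```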

Causal emulation

Once the Koopman operators are fitted under a given time shift, you can perform rollouts (forecasts) from both the marginal and joint models.

import torch
from kausal.koopman import Kausal

# Define cause-effect variables to be tested.
x_cause = torch.randn(3, 1000) # (n_channels, n_timesteps)
x_effect = torch.randn(3, 1000) # (n_channels, n_timesteps)

# Initialize the Kausal object
causal_koopman = Kausal(cause = x_cause, effect = x_effect)

# Forecast (with e.g., time_shift = 1)
x_forecast_marginal, x_forecast_joint = causal_koopman.forecast(
    n_train = int(0.8 * 1000), # Number of time samples used for training
    time_shift = 1
)
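As a hedged illustration of what a Koopman rollout is (not of how Kausal's forecast is implemented), the sketch below fits a linear operator on lifted training data, the first n_train samples, and then applies it repeatedly from the last training state. The affine dictionary and the helper names `lift`, `fit_koopman`, and `rollout` are assumptions. With a time shift greater than 1, each rollout step would advance time by that many samples.

```python
import numpy as np


def lift(x):
    """Hypothetical observables: the state plus a constant row (an affine dictionary)."""
    return np.vstack([x, np.ones((1, x.shape[1]))])


def fit_koopman(x_train, shift=1):
    """Least-squares Koopman matrix K with lift(x)[:, t + shift] ~ K @ lift(x)[:, t]."""
    z = lift(x_train)
    past, future = z[:, :-shift], z[:, shift:]
    coef, *_ = np.linalg.lstsq(past.T, future.T, rcond=None)
    return coef.T


def rollout(K, x0, n_steps):
    """Iterate K from the lifted initial state and read the state back from the first rows."""
    n_channels = x0.shape[0]
    z = lift(x0.reshape(-1, 1))[:, 0]
    states = []
    for _ in range(n_steps):
        z = K @ z
        states.append(z[:n_channels])
    return np.stack(states, axis=1)  # (n_channels, n_steps)


# Toy data: a noise-free rotation, for which an affine Koopman model is exact.
theta = 0.1
A = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
n_timesteps, n_train = 1000, 800
x = np.zeros((2, n_timesteps))
x[:, 0] = [1.0, 0.0]
for t in range(1, n_timesteps):
    x[:, t] = A @ x[:, t - 1]

K = fit_koopman(x[:, :n_train])
x_forecast = rollout(K, x[:, n_train - 1], n_steps=n_timesteps - n_train)
print("max forecast error:", np.abs(x_forecast - x[:, n_train:]).max())
```

On real, noisy data the rollout error grows with the horizon, which is why comparing marginal and joint forecasts on held-out samples is informative.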

Causal graph discovery

Finally, we can iterate through pairwise combinations of variables to infer their overall causal structure.

import torch
from kausal import Graph

# Define the variables whose pairwise causal links will be tested.
x = torch.randn(10, 3, 1000) # (n_vars, n_channels, n_timesteps)

# Initialize Graph object
graph_model = Graph()

# Evaluate
graph_model.infer(
    X = x,
    time_shift = 100,
    bootstrap_kwargs = {'bootstrap_ratio': 0.9, 'bootstrap_nums': 30}
)

# Get some results
graph_model.get_adjacency() # Print out graph adjacency
graph_model.print_result()  # Print out p-values, causal measures, and their uncertainties
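Graph.infer handles this loop internally; the NumPy sketch below only illustrates the pairwise idea under our own simplifications: fixed polynomial observables instead of Koopman or deep features, a held-out error comparison, and circular-shift surrogates for the p-values. All function names are hypothetical. With 49 surrogates the smallest attainable p-value is 0.02, and with several null pairs tested at α = 0.05, occasional false positives are expected unless the p-values are corrected for multiple testing.

```python
import numpy as np


def _lift(x):
    # Hypothetical observables: state, squared state, and a constant row.
    return np.vstack([x, x**2, np.ones((1, x.shape[1]))])


def _heldout_mse(features, target, shift, train_frac=0.7):
    # Fit on the first part of the series and score on the remainder.
    past, future = features[:, :-shift], target[:, shift:]
    n_fit = int(train_frac * past.shape[1])
    coef, *_ = np.linalg.lstsq(past[:, :n_fit].T, future[:, :n_fit].T, rcond=None)
    return float(np.mean((future[:, n_fit:] - coef.T @ past[:, n_fit:]) ** 2))


def causal_score(cause, effect, shift=1):
    marginal = _heldout_mse(_lift(effect), effect, shift)
    joint = _heldout_mse(_lift(np.vstack([effect, cause])), effect, shift)
    return marginal - joint


def surrogate_p_value(cause, effect, shift=1, n_surrogates=49, seed=0):
    # Circularly shifting the cause breaks its timing relative to the effect (null model).
    rng = np.random.default_rng(seed)
    n = cause.shape[1]
    observed = causal_score(cause, effect, shift)
    offsets = rng.integers(shift + 10, n - shift - 10, size=n_surrogates)
    null = [causal_score(np.roll(cause, int(k), axis=1), effect, shift) for k in offsets]
    return (1 + sum(score >= observed for score in null)) / (1 + n_surrogates)


def infer_graph(X, shift=1, alpha=0.05):
    """X has shape (n_vars, n_channels, n_timesteps); edge i -> j if p_values[i, j] < alpha."""
    n_vars = X.shape[0]
    p_values = np.ones((n_vars, n_vars))
    for i in range(n_vars):
        for j in range(n_vars):
            if i != j:
                p_values[i, j] = surrogate_p_value(X[i], X[j], shift)
    adjacency = (p_values < alpha).astype(int)
    return adjacency, p_values


# Toy data: variable 0 drives variable 1 nonlinearly; variable 2 is independent noise.
rng = np.random.default_rng(1)
n_timesteps = 1500
X = np.zeros((3, 1, n_timesteps))
X[0, 0] = rng.standard_normal(n_timesteps)
X[2, 0] = rng.standard_normal(n_timesteps)
for t in range(1, n_timesteps):
    X[1, 0, t] = 0.5 * X[1, 0, t - 1] + 0.8 * X[0, 0, t - 1] ** 2 + 0.1 * rng.standard_normal()

adjacency, p_values = infer_graph(X)
print(adjacency)
```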

Advanced Guides

Using deep learning

You can use deep learning-based features (e.g., MLPs) as the Koopman observables.

import torch
from kausal.koopman import Kausal
from kausal.observables import MLPFeatures

# Define cause-effect variables to be tested.
x_cause = torch.randn(3, 1000) # (n_channels, n_timesteps)
x_effect = torch.randn(3, 1000) # (n_channels, n_timesteps)

# Hidden width of the MLP observables (64 is an illustrative value, assumed to be an int)
hidden_channels = 64

# Initialize Kausal object (note the extra observable arguments)
causal_koopman = Kausal(
    marginal_observable = MLPFeatures(in_channels=3, hidden_channels=hidden_channels, out_channels=3),
    joint_observable = MLPFeatures(in_channels=6, hidden_channels=hidden_channels, out_channels=3),
    cause = x_cause,
    effect = x_effect,
)

# Fit the observables
marginal_loss_ce, joint_loss_ce = causal_koopman.fit(
    n_train = int(0.8 * 1000), 
    epochs = 500, 
    lr = 1e-2, 
    batch_size = int(0.8 * 1000)
)

# Evaluate (with e.g., time_shift = 1)
causal_effect, p_values = causal_koopman.evaluate(time_shift=1)
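To make the observable shapes concrete, here is a NumPy stand-in for an MLP observable together with the kind of linear-prediction residual a Koopman fit tries to reduce. The class `MLPObservable`, the function `koopman_residual`, hidden_channels=16, and the effect-then-cause stacking order are our assumptions, not kausal.observables.MLPFeatures or Kausal.fit. The weights here are random and never trained. On its own, this residual admits trivial solutions such as constant features, so practical deep Koopman methods usually add a reconstruction or state-space prediction term.

```python
import numpy as np


class MLPObservable:
    """Hypothetical one-hidden-layer MLP observable (NumPy stand-in, random weights).

    Maps inputs of shape (in_channels, n_timesteps) to (out_channels, n_timesteps).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.standard_normal((hidden_channels, in_channels)) / np.sqrt(in_channels)
        self.b1 = np.zeros((hidden_channels, 1))
        self.W2 = rng.standard_normal((out_channels, hidden_channels)) / np.sqrt(hidden_channels)
        self.b2 = np.zeros((out_channels, 1))

    def __call__(self, x):
        hidden = np.tanh(self.W1 @ x + self.b1)
        return self.W2 @ hidden + self.b2


def koopman_residual(observable, x, shift=1):
    """MSE of the best linear map advancing the features by `shift` steps.

    Training would update the observable's weights to lower a loss like this one;
    this sketch only evaluates it for fixed weights.
    """
    z = observable(x)
    past, future = z[:, :-shift], z[:, shift:]
    coef, *_ = np.linalg.lstsq(past.T, future.T, rcond=None)
    return float(np.mean((future - coef.T @ past) ** 2))


rng = np.random.default_rng(0)
x_cause = rng.standard_normal((3, 1000))   # (n_channels, n_timesteps)
x_effect = rng.standard_normal((3, 1000))

marginal_observable = MLPObservable(in_channels=3, hidden_channels=16, out_channels=3)
joint_observable = MLPObservable(in_channels=6, hidden_channels=16, out_channels=3, seed=1)

# The joint observable sees effect and cause stacked along the channel axis: 3 + 3 = 6 inputs.
joint_input = np.vstack([x_effect, x_cause])
marginal_loss = koopman_residual(marginal_observable, x_effect)
joint_loss = koopman_residual(joint_observable, joint_input)
print(joint_observable(joint_input).shape, marginal_loss, joint_loss)
```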

Using low-rank estimators

Low-rank estimators are also available, e.g., dynamic mode decomposition (DMD) with a rank-truncated SVD:

import torch
from kausal.koopman import Kausal
from kausal.regressors import DMD

# Initialize Kausal object
model = Kausal(
    regressor = DMD(svd_rank = 4),
    cause = torch.tensor(...),
    effect = torch.tensor(...)
)
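For reference, standard exact DMD with a rank-r truncated SVD fits in a few lines of NumPy. This is the textbook algorithm, shown as a sketch; we have not confirmed that kausal.regressors.DMD follows exactly these steps, and the function name `dmd` is our own. The toy data is a damped rotation embedded in 10 dimensions, so rank 2 is enough and the DMD eigenvalues should match r·exp(±iθ).

```python
import numpy as np


def dmd(X, svd_rank):
    """Exact DMD with a rank-`svd_rank` truncated SVD.

    X has shape (n_features, n_timesteps). Returns (eigenvalues, modes, A_tilde), where
    A_tilde is the reduced operator approximating X[:, 1:] ~ A @ X[:, :-1].
    """
    X1, X2 = X[:, :-1], X[:, 1:]
    U, s, Vh = np.linalg.svd(X1, full_matrices=False)
    U_r, s_r, V_r = U[:, :svd_rank], s[:svd_rank], Vh[:svd_rank].conj().T
    A_tilde = U_r.conj().T @ X2 @ V_r / s_r  # U_r^* X2 V_r Sigma_r^{-1}
    eigenvalues, W = np.linalg.eig(A_tilde)
    modes = (X2 @ V_r / s_r) @ W  # exact DMD modes
    return eigenvalues, modes, A_tilde


# Toy data: a damped 2-D rotation embedded in 10 dimensions, so the true rank is 2.
rng = np.random.default_rng(0)
r, theta = 0.99, 0.3
B = r * np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
P = rng.standard_normal((10, 2))
y = np.zeros((2, 100))
y[:, 0] = [1.0, 0.0]
for t in range(1, 100):
    y[:, t] = B @ y[:, t - 1]
X = P @ y

eigenvalues, modes, _ = dmd(X, svd_rank=2)
print(np.sort_complex(eigenvalues))
```

Truncating the SVD discards directions with small singular values, which regularizes the fitted operator when the lifted features are many or redundant.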

Experimental Results

You can find accompanying code to reproduce the experimental results in the experiments/ folder.

Developer's Guide

We welcome and appreciate any contribution to improve the codebase! You can open a Pull Request or raise an Issue. During development, install the package in editable mode:

git clone https://github.com/juannat7/kausal.git
cd kausal/
pip install -e .

Citation

If you find the code or datasets useful, please consider citing our work:

@article{nathaniel2025deepkoopmanoperatorframework,
  title={Deep Koopman operator framework for causal discovery in nonlinear dynamical systems},
  author={Juan Nathaniel and Carla Roesch and Jatan Buch and Derek DeSantis and Adam Rupe and Kara Lamb and Pierre Gentine},
  journal={arXiv preprint arXiv:2505.14828},
  year={2025}
}

@article{rupe2024causal,
  title={Causal Discovery in Nonlinear Dynamical Systems using Koopman Operators},
  author={Rupe, Adam and DeSantis, Derek and Bakker, Craig and Kooloth, Parvathi and Lu, Jian},
  journal={arXiv preprint arXiv:2410.10103},
  year={2024}
}



Download files


Source Distribution

kausal-1.0.1.tar.gz (69.0 kB)

Uploaded Source

Built Distribution


kausal-1.0.1-py3-none-any.whl (85.3 kB)

Uploaded Python 3

File details

Details for the file kausal-1.0.1.tar.gz.

File metadata

  • Download URL: kausal-1.0.1.tar.gz
  • Upload date:
  • Size: 69.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.16

File hashes

Hashes for kausal-1.0.1.tar.gz
Algorithm Hash digest
SHA256 655d9684f08ac2aa38ffa83621e5cc34391cba5e9fd8fb98f44e66e2977d53da
MD5 ab23bf62c3247f30c7bb500483801f63
BLAKE2b-256 3068df00cc69b541a309d4b22b13227c1a2d5e8baddb986c29958160f7e1ef70


File details

Details for the file kausal-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: kausal-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 85.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.16

File hashes

Hashes for kausal-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b5b202131ccf905d88dc3383d3f70f6cc75dcef355cd8e78da7642b802fe9dec
MD5 3a6eb54823140c45548314d1313cad42
BLAKE2b-256 1af0a606eaf3283b829a17bba1128f76c26d3b87745572fbc66ef697b677b2d4

