
High-performance causal discovery using PCMCI algorithms with JAX acceleration


JAX-PCMCI

High-Performance Causal Discovery from Time Series using JAX

Python 3.9+ | JAX | License: MIT

Author's note: Hello, and thanks for stopping by! If you are using this library, I would love to hear your thoughts on it and how you are using it, so feel free to email me anytime. If you run into any problems, please open an issue or just let me know, and I will usually get it fixed fairly quickly.

Also, a side note: performance depends heavily on the parameters. Settings such as batch size, tau_max, precision, and the other options can make a large difference to speed.

JAX-PCMCI is a library for causal discovery from time series data. It implements the PCMCI family of algorithms with GPU/TPU acceleration through JAX. It provides significant speedups over CPU-based implementations while keeping the statistical methodology of the original algorithms.

Key Features

  • GPU/TPU Acceleration: Leverages JAX for massive parallelization
  • PCMCI & PCMCI+: Both lagged and contemporaneous causal discovery
  • Multiple Independence Tests:
    • ParCorr: Partial correlation for linear dependencies
    • CMIKnn: k-NN conditional mutual information for nonlinear dependencies
    • GPDCond: Gaussian Process distance correlation for complex nonlinear relationships
  • Parallel Test Execution: Vectorized batch testing with vmap/pmap
  • Multiple-Testing Correction: Built-in Benjamini-Hochberg (false discovery rate) and Bonferroni (family-wise error rate) corrections
  • Publication-Ready Visualization: Graph and time series plots
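The two corrections listed above control different error rates. Bonferroni bounds the probability of any false positive, and Benjamini-Hochberg bounds the expected fraction of false positives among the reported links. The plain-Python sketch below shows how each adjusts a list of p-values. It is independent of the library: the function names are illustrative and are not part of the jax_pcmci API.

```python
from typing import List


def bonferroni(pvals: List[float]) -> List[float]:
    """Bonferroni: multiply each p-value by the number of tests (capped at 1)."""
    m = len(pvals)
    return [min(1.0, p * m) for p in pvals]


def benjamini_hochberg(pvals: List[float]) -> List[float]:
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    m = len(pvals)
    # Visit p-values from largest to smallest, keeping a running minimum of
    # p * m / rank so the adjusted values are monotone in the raw p-values
    order = sorted(range(m), key=lambda i: pvals[i], reverse=True)
    adjusted = [0.0] * m
    running_min = 1.0
    for position, i in enumerate(order):
        rank = m - position  # 1-based rank in ascending order
        running_min = min(running_min, pvals[i] * m / rank)
        adjusted[i] = running_min
    return adjusted


pvals = [0.001, 0.012, 0.039, 0.041, 0.20]
print(bonferroni(pvals))
print(benjamini_hochberg(pvals))
```

At a threshold of 0.05, Bonferroni keeps only the first of these five tests, while Benjamini-Hochberg keeps the first two. This is the usual trade-off: FDR control is less conservative when many links are tested.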

Installation

Basic Installation (CPU)

pip install jax-pcmci

With GPU Support (CUDA 12)

pip install "jax-pcmci[gpu]"

With TPU Support

pip install "jax-pcmci[tpu]"

From Source

git clone https://github.com/gpgabriel25/jax-pcmci.git
cd jax-pcmci
pip install -e ".[dev]"

Quick Start

Basic PCMCI Analysis

import jax
import jax.numpy as jnp
from jax_pcmci import PCMCI, ParCorr, DataHandler

# Generate sample data (T time points, N variables).
# This is independent Gaussian noise, so few or no causal links are expected.
key = jax.random.PRNGKey(42)
T, N = 1000, 5
data = jax.random.normal(key, (T, N))

# Create data handler (automatically normalizes data)
datahandler = DataHandler(data, normalize=True)

# Run PCMCI with partial correlation test
pcmci = PCMCI(datahandler, cond_ind_test=ParCorr())
results = pcmci.run(tau_max=3, pc_alpha=0.05)

# View results
print(results.summary())

# Visualize causal graph
results.plot_graph()
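The sample data above is pure noise. To check that an analysis recovers known structure, you can simulate a small linear vector-autoregressive process with planted lagged links. The NumPy sketch below does this. The coefficients and link structure are arbitrary choices for illustration and do not come from the library.

```python
import numpy as np


def simulate_var(T: int = 1000, seed: int = 42) -> np.ndarray:
    """Simulate 3 variables with planted links X0(t-1) -> X1(t) and
    X1(t-2) -> X2(t); each variable also depends on its own lag-1 value."""
    rng = np.random.default_rng(seed)
    burn_in = 100  # discard the initial transient from the zero start
    n = T + burn_in
    x = np.zeros((n, 3))
    noise = rng.standard_normal((n, 3))
    for t in range(2, n):
        x[t, 0] = 0.5 * x[t - 1, 0] + noise[t, 0]
        x[t, 1] = 0.4 * x[t - 1, 1] + 0.6 * x[t - 1, 0] + noise[t, 1]
        x[t, 2] = 0.3 * x[t - 1, 2] - 0.5 * x[t - 2, 1] + noise[t, 2]
    return x[burn_in:]


data = simulate_var()
print(data.shape)
```

Suppose you run this array through DataHandler(data, normalize=True) and PCMCI with ParCorr, as in the example above. You would then expect X0 at lag 1 as a parent of X1, X1 at lag 2 as a parent of X2, and each variable's lag-1 self-dependence. Whether DataHandler accepts a NumPy array directly, or needs jnp.asarray(data) first, is an assumption to check.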

PCMCI+ for Contemporaneous Effects

from jax_pcmci import PCMCIPlus, ParCorr, DataHandler

# PCMCI+ discovers both lagged AND contemporaneous causal links
pcmci_plus = PCMCIPlus(datahandler, cond_ind_test=ParCorr())
results = pcmci_plus.run(tau_max=3)

# Get contemporaneous links specifically
contemp_links = results.get_contemporaneous_links()
for src, tgt, stat, pval in contemp_links:
    print(f"X{src}(t) -> X{tgt}(t): stat={stat:.3f}, p={pval:.4f}")

Nonlinear Causal Discovery

from jax_pcmci import PCMCI, CMIKnn, DataHandler

# Use CMI-kNN for nonlinear relationships
test = CMIKnn(k=10, significance='permutation', n_permutations=200)
pcmci = PCMCI(datahandler, cond_ind_test=test)
results = pcmci.run(tau_max=3)
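The reason a nonlinear test matters: a dependence such as Y = X² with symmetric X has essentially zero linear correlation. A linear test like ParCorr would typically miss it, even though Y is almost a deterministic function of X. The NumPy sketch below only illustrates this motivation and does not call jax_pcmci.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(5000)
y = x ** 2 + 0.1 * rng.standard_normal(5000)  # strong, purely nonlinear dependence

# Pearson correlation is near zero: Cov(X, X^2) = E[X^3] = 0 for symmetric, zero-mean X
linear_corr = np.corrcoef(x, y)[0, 1]

# Correlating with |x| instead exposes the dependence
abs_corr = np.corrcoef(np.abs(x), y)[0, 1]
print(f"corr(x, y)   = {linear_corr:.3f}")
print(f"corr(|x|, y) = {abs_corr:.3f}")
```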

Available Independence Tests

ParCorr (Partial Correlation)

Best for linear dependencies. Fastest test with analytical p-values.

from jax_pcmci import ParCorr

test = ParCorr(
    significance='analytic',  # or 'permutation'
    alpha=0.05
)
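Partial correlation works in two steps. It removes the linear influence of the conditioning set Z from both X and Y by least-squares regression, then correlates the residuals. The NumPy sketch below illustrates this. For the analytic p-value it uses Fisher's z-transform with a normal approximation, a common choice that may differ from the exact null distribution ParCorr uses internally (for example, a Student t distribution). The function name partial_corr_test is hypothetical, not part of the library.

```python
import math
from typing import Optional, Tuple

import numpy as np


def partial_corr_test(x: np.ndarray, y: np.ndarray,
                      z: Optional[np.ndarray] = None) -> Tuple[float, float]:
    """Partial correlation of x and y given z, with a two-sided Fisher-z p-value.

    x, y: shape (T,); z: shape (T, k) or None.
    """
    T = x.shape[0]
    k = 0 if z is None else z.shape[1]
    # Design matrix: intercept plus conditioning variables
    design = np.ones((T, 1)) if z is None else np.column_stack([np.ones(T), z])
    # Residuals after regressing out Z (least squares)
    rx = x - design @ np.linalg.lstsq(design, x, rcond=None)[0]
    ry = y - design @ np.linalg.lstsq(design, y, rcond=None)[0]
    r = float(np.dot(rx, ry) / np.sqrt(np.dot(rx, rx) * np.dot(ry, ry)))
    r = max(min(r, 1 - 1e-12), -1 + 1e-12)  # keep atanh finite
    # Fisher z-transform: approximately N(0, 1/(T - k - 3)) under independence
    z_stat = math.atanh(r) * math.sqrt(T - k - 3)
    pval = math.erfc(abs(z_stat) / math.sqrt(2))  # two-sided normal p-value
    return r, pval


rng = np.random.default_rng(1)
zc = rng.standard_normal((500, 1))            # common driver
x = zc[:, 0] + 0.5 * rng.standard_normal(500)
y = zc[:, 0] + 0.5 * rng.standard_normal(500)
print(partial_corr_test(x, y))       # strongly correlated through the driver
print(partial_corr_test(x, y, zc))   # near zero once the driver is conditioned on
```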

CMIKnn (Conditional Mutual Information with k-NN)

Captures nonlinear dependencies. Uses permutation testing.

from jax_pcmci import CMIKnn

test = CMIKnn(
    k=10,                        # Number of neighbors
    significance='permutation',   # Required for accurate p-values
    n_permutations=500,
    metric='chebyshev'           # or 'euclidean'
)
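Conditional mutual information has no simple closed-form null distribution, so p-values come from permutations. The idea is to shuffle one variable to destroy its dependence on the other and recompute the statistic many times. The p-value is the fraction of shuffled statistics at least as large as the observed one. The NumPy sketch below shows a plain, unconditional permutation test, with absolute correlation standing in for CMI. Conditional versions typically restrict shuffles to samples with similar Z values (a local permutation scheme). This sketch does not do that and is not the library's implementation.

```python
from typing import Callable

import numpy as np


def permutation_pvalue(x: np.ndarray, y: np.ndarray,
                       statistic: Callable[[np.ndarray, np.ndarray], float],
                       n_permutations: int = 500, seed: int = 0) -> float:
    """One-sided permutation p-value for statistic(x, y) under independence."""
    rng = np.random.default_rng(seed)
    observed = statistic(x, y)
    null_stats = np.array([statistic(rng.permutation(x), y)
                           for _ in range(n_permutations)])
    # The +1 terms count the observed statistic as one of the permutations,
    # so the p-value is never exactly zero
    return (1 + np.sum(null_stats >= observed)) / (1 + n_permutations)


def abs_corr(a: np.ndarray, b: np.ndarray) -> float:
    return abs(np.corrcoef(a, b)[0, 1])


rng = np.random.default_rng(2)
x = rng.standard_normal(300)
y_dep = 0.5 * x + rng.standard_normal(300)
y_ind = rng.standard_normal(300)
p_dep = permutation_pvalue(x, y_dep, abs_corr)  # small: dependence detected
p_ind = permutation_pvalue(x, y_ind, abs_corr)  # typically large
print(p_dep, p_ind)
```

The smallest attainable p-value is 1 / (1 + n_permutations). This is one reason larger n_permutations values give finer resolution at the cost of runtime.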

GPDCond (Gaussian Process Distance Correlation)

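Nonlinear test based on Gaussian process (GP) regression: X and Y are each regressed on the conditioning variables with a GP, and the dependence between the residuals is measured with distance correlation.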

from jax_pcmci import GPDCond

test = GPDCond(
    kernel='rbf',           # or 'matern32', 'matern52'
    length_scale=1.0,
    significance='permutation'
)
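The distance-correlation step is what lets GPDC detect nonlinear dependence. Unlike Pearson correlation, the population distance correlation is zero only when the variables are independent (for variables with finite first moments). The NumPy sketch below computes the sample distance correlation, in its V-statistic form, for two 1-D samples. It omits the GP regression step and is not the library's implementation.

```python
import numpy as np


def distance_correlation(x: np.ndarray, y: np.ndarray) -> float:
    """Sample distance correlation (V-statistic form) of two 1-D samples."""
    def centered_distances(v: np.ndarray) -> np.ndarray:
        d = np.abs(v[:, None] - v[None, :])  # pairwise distance matrix
        # Double centering: subtract row and column means, add the grand mean
        return d - d.mean(axis=0, keepdims=True) - d.mean(axis=1, keepdims=True) + d.mean()

    a = centered_distances(np.asarray(x, dtype=float))
    b = centered_distances(np.asarray(y, dtype=float))
    dcov2 = (a * b).mean()                        # squared distance covariance
    dvar_x, dvar_y = (a * a).mean(), (b * b).mean()
    if dvar_x == 0 or dvar_y == 0:
        return 0.0                                # a constant sample carries no information
    return float(np.sqrt(max(dcov2, 0.0) / np.sqrt(dvar_x * dvar_y)))


rng = np.random.default_rng(3)
x = rng.standard_normal(1000)
dep = distance_correlation(x, x ** 2)                      # clearly above zero
ind = distance_correlation(x, rng.standard_normal(1000))   # close to zero
print(dep, ind)
```

The pairwise distance matrices take O(T²) memory, which is one reason kernel- and distance-based tests cost far more than ParCorr on long series.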

Configuration

Device Selection

from jax_pcmci import set_device, get_device_info

# Check available devices
info = get_device_info()
print(f"GPUs available: {info['gpu_count']}")
print(f"Default backend: {info['default_backend']}")

# Force specific device
set_device('gpu')   # Use GPU
set_device('tpu')   # Use TPU
set_device('cpu')   # Force CPU
set_device('auto')  # Auto-select best

Global Configuration

from jax_pcmci import PCMCIConfig

config = PCMCIConfig(
    precision='float64',       # 'float32' for speed, 'float64' for accuracy
    parallelization='auto',    # 'vmap', 'pmap', or 'sequential'
    random_seed=42,            # For reproducibility
    progress_bar=True,
    verbosity=1                # 0=silent, 1=normal, 2=verbose
)
config.apply()

Working with Results

Accessing Causal Links

results = pcmci.run(tau_max=3)

# All significant links
for src, tgt, tau, stat, pval in results.significant_links:
    print(f"X{src}(t-{tau}) -> X{tgt}(t)")

# Get parents of a specific variable
parents = results.get_parents(variable=0)

# Check specific link
is_causal = results.is_significant(source=1, target=0, lag=2)
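PCMCI-style results are commonly stored as arrays indexed [source, target, lag], where entry [i, j, tau] describes the link X_i(t - tau) -> X_j(t). Tigramite uses this layout for its p_matrix and val_matrix, but this README does not confirm it for jax_pcmci. Assuming that layout, the (src, tgt, tau, stat, pval) tuples iterated above could be derived by thresholding the p-values, as in the NumPy sketch below. The helper significant_links_from_arrays is hypothetical.

```python
from typing import List, Tuple

import numpy as np


def significant_links_from_arrays(val_matrix: np.ndarray, p_matrix: np.ndarray,
                                  alpha: float = 0.05,
                                  include_lag_zero: bool = False
                                  ) -> List[Tuple[int, int, int, float, float]]:
    """Return (src, tgt, tau, stat, pval) for every entry with p <= alpha.

    Both arrays have shape (N, N, tau_max + 1); entry [i, j, tau] describes
    the link X_i(t - tau) -> X_j(t).
    """
    mask = p_matrix <= alpha
    if not include_lag_zero:
        mask[:, :, 0] = False  # skip contemporaneous entries
    links = [(int(i), int(j), int(tau),
              float(val_matrix[i, j, tau]), float(p_matrix[i, j, tau]))
             for i, j, tau in zip(*np.nonzero(mask))]
    # Strongest links (largest |stat|) first
    return sorted(links, key=lambda link: abs(link[3]), reverse=True)


N, tau_max = 3, 2
p = np.ones((N, N, tau_max + 1))
v = np.zeros((N, N, tau_max + 1))
p[0, 1, 1], v[0, 1, 1] = 1e-6, 0.55    # X0(t-1) -> X1(t)
p[1, 2, 2], v[1, 2, 2] = 0.003, -0.40  # X1(t-2) -> X2(t)
for src, tgt, tau, stat, pval in significant_links_from_arrays(v, p):
    print(f"X{src}(t-{tau}) -> X{tgt}(t): stat={stat:.2f}, p={pval:.3g}")
```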

Visualization

# Causal graph
fig = results.plot_graph(layout='circular', save_path='graph.png')

# Time series graph (shows temporal structure)
fig = results.plot_time_series_graph(save_path='ts_graph.png')

# Matrix heatmaps
fig = results.plot_matrix(matrix='val', save_path='values.png')
fig = results.plot_matrix(matrix='pval', save_path='pvalues.png')

Export

# To NetworkX
G = results.to_networkx()

# To dictionary (JSON-serializable)
data = results.to_dict()

# Save to file
import json
with open('results.json', 'w') as f:
    json.dump(data, f, indent=2)

Advanced Usage

Custom Independence Test

from jax_pcmci import PCMCI
from jax_pcmci.independence_tests import CondIndTest
import jax.numpy as jnp

class MyCustomTest(CondIndTest):
    name = "MyTest"
    measure = "custom_measure"
    
    def compute_statistic(self, X, Y, Z=None):
        # Your JAX-compatible computation here
        # Must return a scalar JAX array
        pass
    
    def compute_pvalue(self, statistic, n_samples, n_conditions):
        # Compute p-value from statistic
        pass

# Use with PCMCI
pcmci = PCMCI(datahandler, cond_ind_test=MyCustomTest())

Batch Processing for Large Datasets

# For very large datasets, use batch MCI
results = pcmci.run_batch_mci(tau_max=5)

Memory-Efficient Mode

config = PCMCIConfig(
    memory_efficient=True,  # Trades speed for memory
    batch_size=100          # Process tests in batches
)
config.apply()
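The trade-off behind batch_size works like this. Evaluating every test at once materializes one large intermediate array. Processing tests in chunks bounds peak memory at the cost of more, smaller steps. The NumPy sketch below computes lagged correlations for many (source, target, lag) candidates in fixed-size chunks, then checks that the result matches a single all-at-once pass. It is illustrative only and is not the library's batching code.

```python
import numpy as np


def lagged_corrs(data: np.ndarray, tests: np.ndarray, batch_size: int) -> np.ndarray:
    """Pearson correlation of X_src(t - tau) with X_tgt(t) for each row
    (src, tgt, tau) of `tests`, evaluated `batch_size` tests at a time."""
    T = data.shape[0]
    tau_max = int(tests[:, 2].max())
    rows = np.arange(tau_max, T)  # common sample t = tau_max, ..., T - 1
    out = np.empty(len(tests))
    for start in range(0, len(tests), batch_size):
        batch = tests[start:start + batch_size]
        # Gather (len(batch), T - tau_max) blocks; only this chunk is in memory
        src = data[rows[None, :] - batch[:, 2:3], batch[:, 0:1]]
        tgt = data[rows[None, :], batch[:, 1:2]]
        src = src - src.mean(axis=1, keepdims=True)
        tgt = tgt - tgt.mean(axis=1, keepdims=True)
        out[start:start + len(batch)] = (src * tgt).sum(axis=1) / np.sqrt(
            (src ** 2).sum(axis=1) * (tgt ** 2).sum(axis=1))
    return out


rng = np.random.default_rng(4)
data = rng.standard_normal((500, 4))
tests = np.array([(i, j, tau) for i in range(4) for j in range(4) for tau in range(1, 4)])
small = lagged_corrs(data, tests, batch_size=7)           # many small chunks
full = lagged_corrs(data, tests, batch_size=len(tests))   # one big chunk
print(np.allclose(small, full))  # True
```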

Algorithm Details

PCMCI

PCMCI (Peter and Clark Momentary Conditional Independence) is a two-phase algorithm:

  1. PC Phase: Iteratively removes spurious parent candidates using conditional independence tests with increasing conditioning set sizes.

  2. MCI Phase: Tests remaining links using Momentary Conditional Independence, conditioning on the parents of both source and target.
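The two phases can be sketched for linear dependencies as below. This is a simplified NumPy illustration with several assumptions. It uses partial correlation with Fisher-z p-values. In the PC phase it conditions on the strongest remaining candidates, similar in spirit to PCMCI's condition selection but simplified. It ignores contemporaneous links and skips multiple-testing correction. It is not jax_pcmci's implementation.

```python
import math
from typing import Callable, List, Tuple

import numpy as np


def parcorr(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> Tuple[float, float]:
    """Partial correlation of x, y given columns of z; two-sided Fisher-z p-value."""
    design = np.column_stack([np.ones(len(x)), z])  # intercept + conditions
    rx = x - design @ np.linalg.lstsq(design, x, rcond=None)[0]
    ry = y - design @ np.linalg.lstsq(design, y, rcond=None)[0]
    r = float(np.clip(rx @ ry / np.sqrt((rx @ rx) * (ry @ ry)), -0.999999, 0.999999))
    z_stat = math.atanh(r) * math.sqrt(len(x) - design.shape[1] - 2)
    return r, math.erfc(abs(z_stat) / math.sqrt(2))


def pc_phase(series: Callable[[int, int], np.ndarray], n_vars: int,
             target: int, tau_max: int, alpha: float) -> List[Tuple[int, int]]:
    """Prune lagged parent candidates of `target` with conditioning sets of
    growing size p, built from the p strongest other candidates."""
    strength = {(i, tau): math.inf for i in range(n_vars) for tau in range(1, tau_max + 1)}
    y = series(target, 0)
    p = 0
    while p < len(strength):
        for link in list(strength):
            conds = sorted((c for c in strength if c != link), key=lambda c: -strength[c])[:p]
            z = np.column_stack([series(*c) for c in conds]) if conds else np.empty((len(y), 0))
            r, pval = parcorr(series(*link), y, z)
            if pval > alpha:
                del strength[link]  # independent given conds: drop the candidate
            else:
                strength[link] = min(strength[link], abs(r))
        p += 1
    return sorted(strength, key=lambda c: -strength[c])


def pcmci_sketch(data: np.ndarray, tau_max: int = 2, alpha: float = 0.01):
    """Return (src, tgt, tau, r, pval) for lagged links significant in the MCI phase."""
    T, N = data.shape
    window = 2 * tau_max  # room for lagged parents of a lagged source

    def series(var: int, lag: int) -> np.ndarray:
        # X_var(t - lag) for the common sample t = window, ..., T - 1
        return data[window - lag:T - lag, var]

    parents = {j: pc_phase(series, N, j, tau_max, alpha) for j in range(N)}
    links = []
    for j in range(N):
        for i in range(N):
            for tau in range(1, tau_max + 1):
                # MCI: condition on the target's parents (minus the tested link)
                # and on the source's parents shifted back by tau
                conds = [c for c in parents[j] if c != (i, tau)]
                conds += [(k, lag + tau) for k, lag in parents[i]]
                conds = list(dict.fromkeys(conds))  # drop duplicates, keep order
                z = np.column_stack([series(*c) for c in conds]) if conds else np.empty((T - window, 0))
                r, pval = parcorr(series(i, tau), series(j, 0), z)
                if pval <= alpha:
                    links.append((i, j, tau, r, pval))
    return links


rng = np.random.default_rng(0)
T = 800
x = np.zeros((T, 3))
noise = rng.standard_normal((T, 3))
for t in range(2, T):
    x[t, 0] = 0.5 * x[t - 1, 0] + noise[t, 0]   # X0(t-1) -> X0(t)
    x[t, 1] = 0.6 * x[t - 1, 0] + noise[t, 1]   # X0(t-1) -> X1(t)
    x[t, 2] = 0.5 * x[t - 2, 1] + noise[t, 2]   # X1(t-2) -> X2(t)
for src, tgt, tau, r, pval in pcmci_sketch(x, tau_max=2, alpha=0.01):
    print(f"X{src}(t-{tau}) -> X{tgt}(t): r={r:.2f}, p={pval:.2g}")
```

Conditioning on the source's shifted parents in the MCI step is what accounts for the source's own autocorrelation. For example, the X1(t-2) -> X2(t) test also conditions on X0(t-3).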

PCMCI+

PCMCI+ extends PCMCI to handle contemporaneous (τ=0) causal links:

  1. Skeleton Discovery: Finds undirected edges, including contemporaneous ones
  2. Orientation: Uses time order and v-structure rules to orient edges
  3. MCI Testing: Final tests with full conditioning sets
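The v-structure step can be illustrated with the standard PC collider rule. For an unshielded triple X - Z - Y (X and Y not adjacent), orient X -> Z <- Y if Z is not in the separating set that made X and Y independent. PCMCI+ applies a variant of this rule to contemporaneous edges, with lagged conditioning sets and extra handling of conflicting orientations. The plain-Python sketch below shows only the basic rule on an undirected skeleton and is not the library's code.

```python
from itertools import combinations
from typing import Dict, FrozenSet, Set, Tuple


def orient_colliders(nodes: Set[str],
                     edges: Set[FrozenSet[str]],
                     sepsets: Dict[FrozenSet[str], Set[str]]) -> Set[Tuple[str, str]]:
    """Return directed edges (a, b), meaning a -> b, implied by unshielded colliders.

    edges: undirected skeleton as frozensets {a, b}.
    sepsets: for each non-adjacent pair, the conditioning set that separated it.
    """
    directed = set()
    for z in nodes:
        neighbors = sorted(n for n in nodes if frozenset((n, z)) in edges)
        for x, y in combinations(neighbors, 2):
            pair = frozenset((x, y))
            # Unshielded triple x - z - y whose separating set excludes z
            if pair not in edges and z not in sepsets.get(pair, set()):
                directed.add((x, z))
                directed.add((y, z))
    return directed


nodes = {"A", "B", "C", "D"}
edges = {frozenset(("A", "C")), frozenset(("B", "C")), frozenset(("C", "D"))}
sepsets = {
    frozenset(("A", "B")): set(),   # A and B independent without conditioning
    frozenset(("A", "D")): {"C"},   # separated by C, so no collider at C
    frozenset(("B", "D")): {"C"},
}
print(orient_colliders(nodes, edges, sepsets))  # A -> C <- B
```

Here the edge C - D stays undirected. Further rules, such as avoiding new colliders or cycles, would be needed to orient it.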

🧪 Comparison with Tigramite

| Feature | JAX-PCMCI | Tigramite |
|---|---|---|
| GPU/TPU Support | ✅ Native | ❌ CPU only |
| Parallelization | ✅ vmap/pmap | ⚠️ Limited |
| JIT Compilation | ✅ Full | ❌ No |
| Independence Tests | ParCorr, CMI, GPDC | Many |
| Speed (GPU) | 10-100x faster | Baseline |

📖 References

  1. Runge, J. et al. (2019). "Detecting and quantifying causal associations in large nonlinear time series datasets". Science Advances, 5(11), eaau4996.

  2. Runge, J. (2020). "Discovering contemporaneous and lagged causal relations in autocorrelated nonlinear time series datasets". UAI 2020.

  3. Spirtes, P., Glymour, C., & Scheines, R. (2000). "Causation, prediction, and search". MIT Press.

License

MIT License - see LICENSE for details.

📧 Contact

For questions or issues, please open a GitHub issue or contact me at gpgabriel25@gmail.com



Download files


Source Distribution

jax_pcmci-1.1.0.tar.gz (60.6 kB)

Uploaded Source

Built Distribution


jax_pcmci-1.1.0-py3-none-any.whl (66.1 kB)

Uploaded Python 3

File details

Details for the file jax_pcmci-1.1.0.tar.gz.

File metadata

  • Download URL: jax_pcmci-1.1.0.tar.gz
  • Size: 60.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_pcmci-1.1.0.tar.gz
Algorithm Hash digest
SHA256 0ce53250a50c8e2512697e3c1a2685d9438b3e0f792b7bc2f9049e1bf3190bdd
MD5 a558afe0ac23650d2c989ff189ba9f7d
BLAKE2b-256 1889faaf884e57d50efbcd5610a2431c797f7e7dbe35dda8c8c6881fb9c91727


Provenance

The following attestation bundles were made for jax_pcmci-1.1.0.tar.gz:

Publisher: python-publish.yml on Gpgabriel25/JAX-PCMCI

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jax_pcmci-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: jax_pcmci-1.1.0-py3-none-any.whl
  • Size: 66.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_pcmci-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8450666373b1d45b31947e810974981602550758d26e5d74845b7b7835406351
MD5 1fe1bd981948f8b24f887b26b4b6b03b
BLAKE2b-256 1541b00f9fb96bd09ce4412ede9aa4c905a3a9d6578a64f8a14b6cadf19f7d27


Provenance

The following attestation bundles were made for jax_pcmci-1.1.0-py3-none-any.whl:

Publisher: python-publish.yml on Gpgabriel25/JAX-PCMCI

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
