
Smoothing the Landscape: Causal Structure Learning via Diffusion Denoising Objectives

Project description

Causal Structure Learning via Diffusion Denoising Objectives

Understanding causal dependencies in observational data is critical for informing decision-making. These relationships are often modeled as Bayesian Networks (BNs), whose structure is a Directed Acyclic Graph (DAG). Existing methods, such as NOTEARS and DAG-GNN, often suffer from poor scalability and unstable training on high-dimensional data, especially when there is a feature-sample imbalance. Here, we show that the denoising score matching objective of diffusion models can smooth the gradients, leading to faster and more stable convergence. We also propose an adaptive k-hop acyclicity constraint that improves runtime over existing solutions that require matrix inversion. We name this framework Denoising Diffusion Causal Discovery (DDCD). Unlike generative diffusion models, DDCD uses the reverse denoising process to infer a parameterized causal structure rather than to generate data. We demonstrate the competitive performance of DDCD on synthetic benchmark data, and we show that the method is practically useful through qualitative analyses of two real-world examples.
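The description does not give the exact form of the k-hop acyclicity constraint. As a rough illustration only, the NumPy sketch below penalizes weighted closed walks of length 1 to k using repeated matrix products, in the spirit of trace-based penalties such as NOTEARS' tr(exp(W∘W)) - d but without a matrix exponential or inverse. The function name `k_hop_acyclicity` and the formula are assumptions, not DDCD's implementation, and the "adaptive" choice of k is not modeled.

```python
import numpy as np


def k_hop_acyclicity(W: np.ndarray, k: int) -> float:
    """Hypothetical k-hop acyclicity penalty: sum_{i=1..k} tr((W * W)^i).

    Assumption for illustration, not DDCD's actual constraint.
    Returns 0 iff the weighted graph has no directed cycle of length <= k.
    Uses only matrix products (no inversion, no matrix exponential).
    """
    A = W * W                   # elementwise square -> nonnegative edge weights
    P = np.eye(W.shape[0])      # running power A^i, starting from A^0 = I
    total = 0.0
    for _ in range(k):
        P = P @ A               # advance to the next power: P = A^i
        total += np.trace(P)    # total weight of closed walks of length i
    return float(total)
```

Because every directed cycle in a d-node graph contains a simple cycle of length at most d, choosing k = d makes this penalty an exact acyclicity test; a smaller k needs fewer matrix products but can miss longer cycles.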

Get started

Installation

pip install ddcd

Example

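import numpy as np
import torch
import ddcd
from castle.datasets import IIDSimulation, DAG

# Work around an unwanted side effect of castle: reset PyTorch's default dtype to float32
torch.set_default_dtype(torch.float)

# Generate synthetic data: a scale-free DAG with 100 nodes and 1000 edges
dag_adj = DAG.scale_free(
    n_nodes=100, n_edges=1000,
    weight_range=(0.5, 1.5), seed=42
)

# Sample 2000 observations from a linear SEM with Gaussian noise
X = IIDSimulation(
    W=dag_adj,
    n=2000, method='linear',
    sem_type='gauss', noise_scale=1
).X

# Training
model = ddcd.DDCD_Linear_Trainer(X, device='cuda')
model.train(5000)

# Threshold the learned weights to obtain a binary adjacency matrix
w = model.get_adj()
A = (np.abs(w) > 0.3)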
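The example ends with a thresholded adjacency matrix `A`. Assuming `get_adj()` follows the same convention as castle's `dag_adj` (entry [i, j] nonzero means an edge i -> j), the hypothetical helper `edge_metrics` below compares `A` with the ground truth using only NumPy. It is a sketch for quick checks, not part of the ddcd package.

```python
import numpy as np


def edge_metrics(A_est, A_true):
    """Compare two directed adjacency matrices (entry [i, j] != 0 means i -> j).

    Returns true/false positive counts, true positive rate (TPR), false
    discovery rate (FDR), and a structural Hamming distance (SHD) in which
    a missing, extra, or reversed edge each costs 1.
    """
    est = np.asarray(A_est) != 0
    true = np.asarray(A_true) != 0
    tp = int(np.sum(est & true))    # edges found with the correct direction
    fp = int(np.sum(est & ~true))   # predicted edges absent from the truth (includes reversals)
    fn = int(np.sum(~est & true))   # true edges not predicted in that direction
    # A node pair counts once toward SHD if its edge status differs in either direction.
    pair_diff = (est != true) | (est.T != true.T)
    shd = int(np.sum(np.triu(pair_diff, k=1)))
    tpr = tp / max(int(true.sum()), 1)
    fdr = fp / max(int(est.sum()), 1)
    return {"tp": tp, "fp": fp, "fn": fn, "tpr": tpr, "fdr": fdr, "shd": shd}
```

Usage after training would be `edge_metrics(A, dag_adj)`; if `w` is a PyTorch tensor rather than a NumPy array, convert it to NumPy before thresholding.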

Download files

Download the file for your platform.

Source Distribution

ddcd-0.0.1.tar.gz (24.4 kB)

Uploaded Source

Built Distribution

ddcd-0.0.1-py2.py3-none-any.whl (39.7 kB)

Uploaded Python 2, Python 3

File details

Details for the file ddcd-0.0.1.tar.gz.

File metadata

  • Download URL: ddcd-0.0.1.tar.gz
  • Upload date:
  • Size: 24.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.32.3

File hashes

Hashes for ddcd-0.0.1.tar.gz
Algorithm     Hash digest
SHA256        78065fc7f6b614918303b53577515673454967ab8d0ba2fab78971c5e58f4ffd
MD5           0c01806dd1c3d39b741c6ff6ef8b05b6
BLAKE2b-256   eb04d13b2595a4431258945a4f946b6c7cff4f99ad00905e0625688576309861
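To check a downloaded archive against the SHA256 digest listed above, a small sketch using Python's standard-library `hashlib` is shown below. The helper name `sha256_of` and the local file path are assumptions; point the path at wherever the file was saved.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# Published SHA256 digest for ddcd-0.0.1.tar.gz
EXPECTED = "78065fc7f6b614918303b53577515673454967ab8d0ba2fab78971c5e58f4ffd"

if __name__ == "__main__":
    path = Path("ddcd-0.0.1.tar.gz")  # assumed download location
    if path.exists():
        print("OK" if sha256_of(path) == EXPECTED else "MISMATCH")
```

Alternatively, pip's hash-checking mode can enforce this at install time: list `ddcd==0.0.1` with a `--hash=sha256:...` option for each file pip may select in a requirements file, then run `pip install --require-hashes -r requirements.txt`.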

File details

Details for the file ddcd-0.0.1-py2.py3-none-any.whl.

File metadata

  • Download URL: ddcd-0.0.1-py2.py3-none-any.whl
  • Upload date:
  • Size: 39.7 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.32.3

File hashes

Hashes for ddcd-0.0.1-py2.py3-none-any.whl
Algorithm     Hash digest
SHA256        5f1ef7bc00898a87ea2ac5558d6ca35098ee56a42e35ee0f1ee07d60c6896af6
MD5           ef0b24efe9c0092900bd9d172b1cda0e
BLAKE2b-256   699d99c6df5ec5bc4bfb99c33991a57cf0b83c4d57fe6a65a0c931fa7b8cb155
