
A PyTorch implementation of the NOTEARS algorithm for causal discovery.

Project description

Notears PyTorch

A PyTorch implementation of the NOTEARS algorithm (Non-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning) for causal discovery. This package provides a continuous optimization approach to learning DAGs (Directed Acyclic Graphs) from data.
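For reference, the linear formulation in Zheng et al. (2018) learns a weighted adjacency matrix $W \in \mathbb{R}^{d \times d}$ by solving an equality-constrained continuous program. It is an assumption, based on the paper rather than this package's source, that the package optimizes exactly this objective:

```latex
\min_{W \in \mathbb{R}^{d \times d}} \; F(W) = \frac{1}{2n}\,\lVert X - XW \rVert_F^2 + \lambda_1 \lVert W \rVert_1
\quad \text{subject to} \quad
h(W) = \operatorname{tr}\!\left(e^{W \circ W}\right) - d = 0

% Augmented Lagrangian subproblem, followed by dual ascent on \alpha:
\min_{W} \; F(W) + \frac{\rho}{2}\,h(W)^2 + \alpha\,h(W),
\qquad \alpha \leftarrow \alpha + \rho\,h(W)
```

Here $\circ$ is the elementwise product, and $h(W) = 0$ holds exactly when $W$ is the weighted adjacency matrix of a DAG. The penalty parameter $\rho$ is increased between subproblems until the constraint is satisfied to tolerance.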

Installation

You can install this package from source by running the following in the repository root:

pip install .

Usage

import numpy as np
from notears_pytorch import notears_linear

# 1. Generate or load data (n_samples x n_features)
n, d = 100, 5
X = np.random.randn(n, d)

# 2. Run optimization
# Returns a binary adjacency matrix where B[i, j] = 1 implies i -> j
adj_matrix = notears_linear(X, lambda1=0.1, w_threshold=0.3)

print("Estimated Adjacency Matrix:")
print(adj_matrix)
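Because the columns of `np.random.randn(n, d)` are independent, the true graph in the example above has no edges, so any edge it reports is spurious. To get data with a known ground truth, one option is to sample from a linear structural equation model X = XW + Z. The sketch below does that with NumPy and lists edges using the same B[i, j] = 1 means i -> j convention; `simulate_linear_sem` and `edge_list` are hypothetical helpers, not part of `notears_pytorch`.

```python
import numpy as np


def simulate_linear_sem(W_true: np.ndarray, n: int, seed: int = 0) -> np.ndarray:
    """Draw n samples from X = X @ W_true + Z with standard Gaussian noise Z.

    W_true must be the weighted adjacency matrix of a DAG (W_true[i, j] != 0 means i -> j).
    Rearranging X (I - W) = Z gives X = Z (I - W)^{-1}; for a DAG, W is nilpotent,
    so (I - W) is always invertible.
    """
    rng = np.random.default_rng(seed)
    d = W_true.shape[0]
    Z = rng.standard_normal((n, d))
    return Z @ np.linalg.inv(np.eye(d) - W_true)


def edge_list(B: np.ndarray) -> list[tuple[int, int]]:
    """Return the directed edges (i, j), read as i -> j, for every nonzero B[i, j]."""
    rows, cols = np.nonzero(B)
    return [(int(i), int(j)) for i, j in zip(rows, cols)]


# Ground-truth chain 0 -> 1 -> 2 on d = 5 variables (nodes 3 and 4 are isolated).
d = 5
W_true = np.zeros((d, d))
W_true[0, 1] = 2.0
W_true[1, 2] = -1.5
X = simulate_linear_sem(W_true, n=1000)
B_true = (W_true != 0).astype(int)

print(X.shape)            # (1000, 5)
print(edge_list(B_true))  # [(0, 1), (1, 2)]
```

Passing this `X` to `notears_linear` and applying `edge_list` to the result gives a direct comparison with `edge_list(B_true)`, assuming the returned matrix is a NumPy array indexed as described in the comment above.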

API

notears_linear(X, lambda1=0.1, ...)

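Solves the linear NOTEARS optimization problem (an L1-penalized least-squares fit subject to a smooth acyclicity constraint) and returns the estimated DAG structure as a binary adjacency matrix after thresholding with w_threshold.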

X: np.ndarray of shape (n, d). The data matrix.

lambda1: float. L1 penalty parameter (sparsity).

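rho_init: float. Initial value of the augmented Lagrangian penalty parameter (rho) applied to the acyclicity constraint; this is separate from the L1 penalty lambda1.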

w_threshold: float. Edges whose estimated weight has an absolute value below this threshold are pruned.

use_gpu: bool. If True and CUDA is available, computations run on GPU.
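To show how lambda1, rho_init and w_threshold interact, here is a minimal NumPy/SciPy sketch of linear NOTEARS following the augmented Lagrangian scheme from the paper. It is not this package's PyTorch implementation and does not support use_gpu: the L-BFGS-B inner solver with the split W = W_plus - W_minus (which turns the L1 term into a linear one), and the names `notears_linear_sketch`, `acyclicity`, `max_iter`, `h_tol` and `rho_max`, are assumptions for illustration.

```python
import numpy as np
import scipy.linalg as sla
import scipy.optimize as sopt


def acyclicity(W):
    """Return h(W) = tr(exp(W * W)) - d and its gradient exp(W * W)^T * 2W."""
    E = sla.expm(W * W)  # W * W is the elementwise (Hadamard) square
    return np.trace(E) - W.shape[0], E.T * W * 2


def notears_linear_sketch(X, lambda1=0.1, rho_init=1.0, w_threshold=0.3,
                          max_iter=100, h_tol=1e-8, rho_max=1e16):
    """Hypothetical NumPy/SciPy linear NOTEARS; returns a binary adjacency matrix."""
    n, d = X.shape
    X = X - X.mean(axis=0)  # centre columns so the linear model needs no intercept

    def unpack(w):
        # W = W_plus - W_minus with both halves >= 0, so ||W||_1 <= sum(w) is linear.
        return (w[: d * d] - w[d * d:]).reshape(d, d)

    def objective(w, rho, alpha):
        W = unpack(w)
        R = X - X @ W                       # residuals of the linear SEM
        loss = 0.5 / n * (R ** 2).sum()     # least-squares loss
        G_loss = -1.0 / n * X.T @ R
        h, G_h = acyclicity(W)
        obj = loss + 0.5 * rho * h * h + alpha * h + lambda1 * w.sum()
        G_smooth = G_loss + (rho * h + alpha) * G_h
        grad = np.concatenate([(G_smooth + lambda1).ravel(),
                               (-G_smooth + lambda1).ravel()])
        return obj, grad

    # Both halves are non-negative; diagonal entries are pinned to 0 (no self-loops).
    bounds = [(0, 0) if i == j else (0, None)
              for _ in range(2) for i in range(d) for j in range(d)]
    w_est, rho, alpha, h = np.zeros(2 * d * d), rho_init, 0.0, np.inf
    for _ in range(max_iter):
        while rho < rho_max:
            sol = sopt.minimize(objective, w_est, args=(rho, alpha),
                                method="L-BFGS-B", jac=True, bounds=bounds)
            h_new, _ = acyclicity(unpack(sol.x))
            if h_new > 0.25 * h:
                rho *= 10                   # constraint barely improved: raise the penalty
            else:
                break
        w_est, h = sol.x, h_new
        alpha += rho * h                    # dual ascent on the Lagrange multiplier
        if h <= h_tol or rho >= rho_max:
            break
    W_est = unpack(w_est)
    W_est[np.abs(W_est) < w_threshold] = 0  # prune weak edges
    return (W_est != 0).astype(int)
```

The schedule above (multiply rho by 10 whenever h(W) does not fall below a quarter of its previous value) matches the original authors' reference code; this package may use a different schedule or optimizer. The result is a d x d 0/1 matrix read the same way as the package's output, B[i, j] = 1 meaning i -> j.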

Citation

If you use this method, please cite the original paper: Zheng, X., Aragam, B., Ravikumar, P. K., & Xing, E. P. (2018). DAGs with NO TEARS: Continuous Optimization for Structure Learning. Advances in Neural Information Processing Systems.



Download files


Source Distribution

notears_pytorch-0.1.0.tar.gz (4.0 kB)

Uploaded Source

Built Distribution


notears_pytorch-0.1.0-py3-none-any.whl (4.0 kB)

Uploaded Python 3

File details

Details for the file notears_pytorch-0.1.0.tar.gz.

File metadata

  • Download URL: notears_pytorch-0.1.0.tar.gz
  • Upload date:
  • Size: 4.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for notears_pytorch-0.1.0.tar.gz
Algorithm Hash digest
SHA256 8d6649d6d9d9fc7cd060a46055300a9e192a7d35fd51dba287e461d1d4f8c7ca
MD5 bfe43c5f91a100fb4c8221406400e466
BLAKE2b-256 350a8ed12a6fabbf734ee8e2346993592867d096cdfa758acf0c344a897fb2b9


File details

Details for the file notears_pytorch-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for notears_pytorch-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d548873e7bd751d7df64266ea46af99f37b0c60155598af2eb7c58b0beb302cf
MD5 5f4bceee7405939b804fccee0ceb6eec
BLAKE2b-256 2bdccee139319278d4c030bbac915b9c52a6c393c74a695950bcbd78a47cc830

