
Collection of methods and tools to learn causal graphs from grouped data.


gRESIT

License: AGPL v3. Code style: ruff.

This repo aims to learn and represent causal graphs from grouped data. The theoretical details are presented in the paper:

@misc{goebler2025,
  title={Nonlinear Causal Discovery for Grouped Data},
  author={Konstantin G\"obler and Tobias Windisch and Mathias Drton},
  year={2025},
  eprint={2506.05120},
  archivePrefix={arXiv},
  primaryClass={stat.ML},
  url={https://arxiv.org/abs/2506.05120},
}

Authors

Maintainer: Martin Roth (Bosch)


Documentation

The documentation can be found here.

How to install

The package can be installed with:

pip install gresit

How to build

Using the Makefile, the package can be installed in editable mode:

make sync-venv

To use the pre-commit hooks, enable them in the virtual environment:

pre-commit install

These hooks are then executed before every commit. You can also run the hooks on all files manually:

pre-commit run --all-files

or, to skip the pip-compile hook, which takes some time:

SKIP=pip-compile pre-commit run --all-files

or, equivalently:

make pre-commit

How to use

Consider the following example. See the documentation for more detailed information.

Generating Synthetic Data

We first generate synthetic data from an Erdős–Rényi random graph model over groups of variables, specifying the number of nodes, the group size, and the edge density.

from gresit.synthetic_data import GenERData

data_gen = GenERData(
    number_of_nodes=10,
    group_size=2,
    edge_density=0.2,
)

data_dict, _ = data_gen.generate_data(num_samples=1000)

The output data_dict is a dictionary where each key corresponds to a group, and the values are the observed samples.
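To make the shape of data_dict concrete, here is a minimal numpy sketch of one way to simulate grouped data from an Erdős–Rényi DAG. It is not gresit's GenERData: the function name generate_grouped_data, the group keys "X0", "X1", ..., reading number_of_nodes as the number of groups, and the tanh mechanism are all assumptions made for illustration.

```python
import numpy as np


def generate_grouped_data(number_of_nodes, group_size, edge_density, num_samples, seed=0):
    """Hypothetical simulator (not gresit's API): returns (data_dict, adjacency).

    adjacency[i, j] == 1 encodes an edge from group i to group j.
    """
    rng = np.random.default_rng(seed)
    # A random causal order of the groups; edges may only point forward in it,
    # which guarantees the graph is acyclic.
    order = rng.permutation(number_of_nodes).tolist()
    adjacency = np.zeros((number_of_nodes, number_of_nodes), dtype=int)
    for a in range(number_of_nodes):
        for b in range(a + 1, number_of_nodes):
            # Erdős–Rényi: each allowed edge is present independently.
            if rng.random() < edge_density:
                adjacency[order[a], order[b]] = 1

    samples = {}
    for node in order:  # parents are always simulated before their children
        x = rng.normal(size=(num_samples, group_size))  # additive noise
        for parent in np.flatnonzero(adjacency[:, node]).tolist():
            # Assumed nonlinear mechanism: random linear map followed by tanh.
            weights = rng.normal(size=(group_size, group_size))
            x = x + np.tanh(samples[parent] @ weights)
        samples[node] = x

    # One (num_samples, group_size) array per group, keyed by group name.
    data_dict = {f"X{i}": samples[i] for i in range(number_of_nodes)}
    return data_dict, adjacency


data_dict, true_adjacency = generate_grouped_data(
    number_of_nodes=10, group_size=2, edge_density=0.2, num_samples=1000
)
```

Returning the true adjacency matrix alongside the data is convenient for comparing it with a learned graph later.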

Fitting a Graph Model

We now fit a gRESIT model using Multioutcome_MLP as the regressor, HSIC as the independence test, and "murgs" as the pruning method.

from gresit.group_resit import GroupResit
from gresit.independence_tests import HSIC
from gresit.torch_models import Multioutcome_MLP

model = GroupResit(
    regressor=Multioutcome_MLP(),
    test=HSIC,
    pruning_method="murgs",
)
learned_dag = model.learn_graph(data_dict=data_dict)

# Show the learned graph:
learned_dag.show()
# or show interactive mode:
model.show_interactive()
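The name gRESIT refers to RESIT (regression with subsequent independence test), extended to groups of variables. The numpy sketch below illustrates the RESIT ordering idea as we understand it: regress each remaining group on all other remaining groups, declare the group whose residuals look most independent of its regressors to be a sink, remove it, and repeat. It is not the package's implementation: resit_order, hsic_statistic and the other names are hypothetical, a cubic least-squares fit replaces Multioutcome_MLP, the raw HSIC statistic replaces a calibrated HSIC test, and no pruning step (such as "murgs") is performed, so the sketch yields only an ordering, not a graph. Gram matrices are O(n^2) in memory, so keep the sample size small.

```python
import numpy as np


def gaussian_gram(x):
    """Gaussian kernel Gram matrix; bandwidth from the median squared distance."""
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    positive = sq_dists[sq_dists > 0]
    bandwidth = np.median(positive) if positive.size else 1.0
    return np.exp(-sq_dists / (2.0 * bandwidth))


def hsic_statistic(x, y):
    """Biased empirical HSIC, trace(K H L H) / n**2; larger means more dependent."""
    n = len(x)
    centering = np.eye(n) - 1.0 / n
    k_x, k_y = gaussian_gram(x), gaussian_gram(y)
    return float(np.trace(k_x @ centering @ k_y @ centering)) / n**2


def regression_residuals(target, predictors):
    """Residuals of a least-squares fit on per-column cubic features.

    Stand-in for a flexible regressor such as Multioutcome_MLP.
    """
    design = np.column_stack(
        [np.ones(len(target)), predictors, predictors**2, predictors**3]
    )
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return target - design @ coef


def resit_order(data_dict):
    """Return group names ordered from sources to sinks, RESIT style."""
    remaining = list(data_dict)
    sinks = []
    while len(remaining) > 1:
        scores = {}
        for name in remaining:
            others = np.hstack([data_dict[o] for o in remaining if o != name])
            residuals = regression_residuals(data_dict[name], others)
            # A sink's residuals should be (close to) independent of the rest.
            scores[name] = hsic_statistic(residuals, others)
        sink = min(scores, key=scores.get)
        sinks.insert(0, sink)
        remaining.remove(sink)
    return remaining + sinks
```

For two groups where B = A**3 + noise, the sketch should tend to place "A" before "B", although with few samples or a weak nonlinearity the ranking can go either way.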

Accessing the Learned Graph

The learned adjacency matrix representing the estimated group-level graph and a causal ordering can be accessed via:

model.adjacency_matrix
model.causal_ordering
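The adjacency matrix and the causal ordering should agree: every edge must point from an earlier to a later group. A minimal sketch for checking this, and for deriving an ordering from an adjacency matrix with Kahn's algorithm, follows. It assumes that adjacency[i, j] != 0 encodes an edge from group i to group j and that the ordering lists integer group indices from sources to sinks; the convention used by gresit may differ, and the function names are hypothetical.

```python
from collections import deque

import numpy as np


def topological_order(adjacency):
    """Kahn's algorithm; assumes adjacency[i, j] != 0 encodes an edge i -> j."""
    adj = np.asarray(adjacency) != 0
    in_degree = adj.sum(axis=0).astype(int)
    queue = deque(int(i) for i in np.flatnonzero(in_degree == 0))
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in np.flatnonzero(adj[node]):
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(int(child))
    if len(order) != len(adj):
        raise ValueError("adjacency matrix contains a cycle")
    return order


def is_consistent(adjacency, ordering):
    """True if every edge i -> j points forward in the given ordering."""
    position = {int(node): k for k, node in enumerate(ordering)}
    rows, cols = np.nonzero(np.asarray(adjacency))
    return all(position[int(i)] < position[int(j)] for i, j in zip(rows, cols))
```

A DAG generally admits several valid orderings, so topological_order need not reproduce model.causal_ordering exactly; is_consistent(model.adjacency_matrix, model.causal_ordering) checks compatibility instead, provided both use the same group indexing.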

How to test

We use pytest; the test suite can be run locally with:

python -m pytest
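By default, pytest collects functions whose names start with test from files named test_*.py or *_test.py. As a hypothetical illustration (not a file from this repository), a test checking that an estimated adjacency matrix is acyclic could look like the following; a single file can be run with python -m pytest path/to/file.py, and tests can be filtered by name with -k.

```python
# Hypothetical file: tests/test_acyclicity.py
import numpy as np


def is_acyclic(adjacency):
    """A directed graph on n nodes is acyclic iff it has no walk of length n."""
    a = (np.asarray(adjacency) != 0).astype(np.int64)
    n = len(a)
    walks = np.eye(n, dtype=np.int64)
    for _ in range(n):
        # Boolean matrix product: walks[i, j] == 1 iff a walk of the current length exists.
        walks = ((walks @ a) > 0).astype(np.int64)
    return not walks.any()


def test_dag_is_acyclic():
    dag = np.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]])
    assert is_acyclic(dag)


def test_cycle_is_detected():
    cycle = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
    assert not is_acyclic(cycle)
```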

GitHub Actions

Documentation with mkdocs

We use mkdocs to build the documentation, and a corresponding workflow handles the build.

Automated issue workflow

This workflow automatically adds newly created issues to our MFD2 project.

Pre-commit

This workflow runs the pre-commit rules specified in .pre-commit-config.yaml.

To use pre-commit locally, run:

pre-commit install

Testing

This workflow runs the test suite.

Third-Party Licenses

Runtime dependencies

Name          License               Type
numpy         BSD-3-Clause License  Dependency
pandas        BSD-3-Clause License  Dependency
scikit-learn  BSD-3-Clause License  Dependency
statsmodels   BSD-3-Clause License  Dependency
plotly        MIT License           Dependency
xgboost       Apache License 2.0    Dependency
torch         BSD-3-Clause License  Dependency
seaborn       BSD-3-Clause License  Dependency
pyspark       Apache License 2.0    Dependency
scikit-misc   BSD-3-Clause License  Dependency
gadjid        MIT License           Dependency
tqdm          MIT License           Dependency
dcor          MIT License           Dependency
llvmlite      BSD-2-Clause License  Dependency
causal-learn  MIT License           Dependency
gcastle       Apache License 2.0    Dependency
gpytorch      MIT License           Dependency

Development dependencies

Name             License               Type
mike             BSD-3-Clause License  Optional
mkdocs           BSD-2-Clause License  Optional
mkdocs-material  MIT License           Optional
mkdocstrings     ISC License           Optional
pip-licenses     MIT License           Optional
pip-tools        BSD-3-Clause License  Optional
pre-commit       MIT License           Optional
pytest           MIT License           Optional
pytest-cov       MIT License           Optional
ruff             MIT License           Optional
uv               MIT License           Optional



Download files


Source Distribution

gresit-1.0.0.tar.gz (81.3 kB)


Built Distribution


gresit-1.0.0-py3-none-any.whl (81.1 kB)


File details

Details for the file gresit-1.0.0.tar.gz.

File metadata

  • Download URL: gresit-1.0.0.tar.gz
  • Size: 81.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.12

File hashes

Hashes for gresit-1.0.0.tar.gz
Algorithm Hash digest
SHA256 6b849f8dfc1e1c88a6b041b7a0048f27f360d09e588ba650828074da16654070
MD5 96558c1a37164d7dac39b6a98b1b0ee8
BLAKE2b-256 c75ae7f9cb0c0b0f7fb26059f55e2150cc322854b9a81e910a2fb041a81f3c52


File details

Details for the file gresit-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: gresit-1.0.0-py3-none-any.whl
  • Size: 81.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.12

File hashes

Hashes for gresit-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c0f110a4ccc2a515ba29542c63058feb63adf010467f476c6d6f2d662656f3a1
MD5 f1352b382a0f834500b532315cfe6b6b
BLAKE2b-256 cac7cb120ea081e501f09ac25b6abb867f4ca8c16c7cb32a05d1e1f51bcc3738

