Collection of methods and tools to learn causal graphs from grouped data.
gRESIT
This repository provides methods for learning and representing causal graphs from grouped data. Theoretical details are presented in the paper:
@misc{goebler2025,
title={Nonlinear Causal Discovery for Grouped Data},
author={Konstantin G\"obler and Tobias Windisch and Mathias Drton},
year={2025},
eprint={2506.05120},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/abs/2506.05120},
}
Authors
Maintainer: Martin Roth (Bosch)
Documentation
The documentation can be found here.
How to install
The package can be installed with
pip install gresit
How to build
Using the Makefile, the package can be installed in editable mode:
make sync-venv
To use the pre-commit hooks, enable them in the virtual environment:
pre-commit install
The hooks are then executed before every commit. You can also run them on all files manually:
pre-commit run --all-files
or, to skip the pip-compile hook, which takes some time:
SKIP=pip-compile pre-commit run --all-files
or, equivalently:
make pre-commit
How to use
The following example walks through a typical workflow. See the documentation for more detailed information.
Generating Synthetic Data
We first generate synthetic data using an Erdős–Rényi random graph model. Each group of variables is defined with a specified size and edge density.
from gresit.synthetic_data import GenERData
data_gen = GenERData(
    number_of_nodes=10,
    group_size=2,
    edge_density=0.2,
)
data_dict, _ = data_gen.generate_data(num_samples=1000)
The output data_dict is a dictionary where each key corresponds to a group, and the values are the observed samples.
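To make the Erdős–Rényi step concrete, here is a minimal standard-library sketch of sampling a random DAG over groups. It is not the implementation behind GenERData: the helper `sample_er_dag` is hypothetical, and it assumes that the edge density is the probability of including each admissible edge and that `adj[i][j] == 1` encodes an edge from group i to group j.

```python
import random


def sample_er_dag(num_nodes: int, edge_density: float, seed: int = 0) -> list[list[int]]:
    """Sample a random DAG as an adjacency matrix (adj[i][j] == 1 means i -> j).

    Hypothetical helper for illustration only; not part of gresit.
    """
    rng = random.Random(seed)
    # A random permutation fixes a causal order. Edges only point forward
    # in that order, which guarantees that the graph is acyclic.
    order = list(range(num_nodes))
    rng.shuffle(order)
    adj = [[0] * num_nodes for _ in range(num_nodes)]
    for pos, parent in enumerate(order):
        for child in order[pos + 1:]:
            # Erdős–Rényi rule: include each admissible edge independently.
            if rng.random() < edge_density:
                adj[parent][child] = 1
    return adj


adj = sample_er_dag(num_nodes=10, edge_density=0.2)
num_edges = sum(map(sum, adj))
print(f"{num_edges} edges out of {10 * 9 // 2} possible")
```

Drawing the order first and then the edges keeps the sketch short and makes acyclicity hold by construction, so no separate cycle check is needed.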
Fitting a Graph Model
We now fit a gRESIT model using Multioutcome_MLP as the regressor and HSIC as the independence test.
from gresit.group_resit import GroupResit
from gresit.independence_tests import HSIC
from gresit.torch_models import Multioutcome_MLP
model = GroupResit(
    regressor=Multioutcome_MLP(),
    test=HSIC,
    pruning_method="murgs",
)
learned_dag = model.learn_graph(data_dict=data_dict)
# Show the learned graph:
learned_dag.show()
# or show interactive mode:
model.show_interactive()
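For intuition about the independence test, the sketch below computes a biased empirical HSIC statistic with Gaussian kernels in plain NumPy. It is an illustration under assumptions, not the HSIC class shipped with gresit: the function names, the fixed bandwidth of 1.0, and the 1/n² normalisation are choices made for this example only.

```python
import numpy as np


def rbf_kernel(x: np.ndarray, bandwidth: float) -> np.ndarray:
    """Gaussian kernel matrix for the samples stored in the rows of x."""
    sq_norms = np.sum(x**2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T
    # Clip tiny negative values caused by floating-point round-off.
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * bandwidth**2))


def hsic_statistic(x: np.ndarray, y: np.ndarray, bandwidth: float = 1.0) -> float:
    """Biased empirical HSIC: trace(K H L H) / n^2 (illustrative, not gresit's code)."""
    n = x.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    kernel_x = rbf_kernel(x, bandwidth)
    kernel_y = rbf_kernel(y, bandwidth)
    return float(np.trace(kernel_x @ centering @ kernel_y @ centering) / n**2)


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2))                       # one group with two variables
independent = rng.normal(size=(200, 2))             # unrelated group
dependent = x**2 + 0.1 * rng.normal(size=(200, 2))  # nonlinear function of x

print("independent:", hsic_statistic(x, independent))
print("dependent:  ", hsic_statistic(x, dependent))
```

The statistic should be clearly larger for the dependent pair. A full test would additionally calibrate a rejection threshold, for example by permuting the rows of one sample.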
Accessing the Learned Graph
The learned adjacency matrix representing the estimated group-level graph and a causal ordering can be accessed via:
model.adjacency_matrix
model.causal_ordering
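A quick sanity check on these two outputs is that the ordering is compatible with the graph. The sketch below uses plain Python lists and a hypothetical helper, `respects_ordering`. It assumes that a nonzero entry `adjacency[i][j]` encodes an edge from group i to group j and that the ordering lists group indices with causes first. The actual return types of `model.adjacency_matrix` and `model.causal_ordering` may differ, so convert them as needed.

```python
def respects_ordering(adjacency: list[list[int]], ordering: list[int]) -> bool:
    """Return True if every edge i -> j has i placed before j in the ordering.

    Hypothetical helper for illustration only; not part of gresit.
    """
    position = {node: idx for idx, node in enumerate(ordering)}
    return all(
        position[i] < position[j]
        for i, row in enumerate(adjacency)
        for j, entry in enumerate(row)
        if entry
    )


# Hypothetical 3-group example: group 0 -> group 1 -> group 2, plus 0 -> 2.
adjacency = [
    [0, 1, 1],
    [0, 0, 1],
    [0, 0, 0],
]
print(respects_ordering(adjacency, [0, 1, 2]))  # True
print(respects_ordering(adjacency, [1, 0, 2]))  # False
```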
How to test
We use pytest; the test suite can be run locally via
python -m pytest
GitHub Actions
Documentation with mkdocs
We use mkdocs to build the documentation; this is the corresponding workflow.
Automated issue workflow
With this workflow, newly created issues are automatically added to our MFD2 project.
Pre-commit
This workflow runs the pre-commit rules specified in .pre-commit-config.yaml.
To use pre-commit locally, run
pre-commit install
Testing
This workflow runs the test suite.
Third-Party Licenses
Runtime dependencies
| Name | License | Type |
|---|---|---|
| numpy | BSD-3-Clause License | Dependency |
| pandas | BSD-3-Clause License | Dependency |
| scikit-learn | BSD-3-Clause License | Dependency |
| statsmodels | BSD-3-Clause License | Dependency |
| plotly | MIT License | Dependency |
| xgboost | Apache License 2.0 | Dependency |
| torch | BSD-3-Clause License | Dependency |
| seaborn | BSD-3-Clause License | Dependency |
| pyspark | Apache License 2.0 | Dependency |
| scikit-misc | BSD-3-Clause License | Dependency |
| gadjid | MIT License | Dependency |
| tqdm | MIT License | Dependency |
| dcor | MIT License | Dependency |
| llvmlite | BSD-2-Clause License | Dependency |
| causal-learn | MIT License | Dependency |
| gcastle | Apache License 2.0 | Dependency |
| gpytorch | MIT License | Dependency |
Development dependencies
| Name | License | Type |
|---|---|---|
| mike | BSD-3-Clause License | Optional |
| mkdocs | BSD-2-Clause License | Optional |
| mkdocs-material | MIT License | Optional |
| mkdocstrings | ISC License | Optional |
| pip-licenses | MIT License | Optional |
| pip-tools | BSD-3-Clause License | Optional |
| pre-commit | MIT License | Optional |
| pytest | MIT License | Optional |
| pytest-cov | MIT License | Optional |
| ruff | MIT License | Optional |
| uv | MIT License | Optional |