A lean reimplementation of Zorro: valid, sparse, stable explanations for PyTorch Geometric GNNs, with graph-level and regression support.

⭐ Zorrito — Stable, Sparse, and Valid Explanations for Modern PyTorch Geometric

A lean, modernized reimplementation of Zorro — valid, sparse, stable explanations for PyTorch Geometric GNNs.

zorrito keeps the greedy core idea of the paper and extends it with:

  • Support for modern Python and PyTorch Geometric
  • Graph-level explanations, not just node-level
  • Regression support via a tolerance band, with optional one-sided (direction-aware) match functions
  • Structure-only explanations, configurable as an alternative to feature-based explanations
  • Configurable noise: whole-row sampling and dataset-wide pools, so perturbed inputs stay on the data manifold

📦 Install

pip install zorrito

zorrito depends on torch>=2.0 and torch-geometric>=2.4. Install those according to the official PyG instructions for your CUDA/PyTorch combination first.

For development:

git clone https://github.com/your-user/zorrito.git
cd zorrito
uv venv --python 3.11
uv pip install -e ".[dev]"
pytest tests/

🚀 Quick start

Node classification (Cora-style)

import torch
from zorrito import Zorro

# `model` is any nn.Module that implements forward(x, edge_index)
explainer = Zorro(
    model=model,
    task="node",
    samples=100,
    top_k=10,
    seed=42,
)
explanations = explainer.explain(
    x=data.x,
    edge_index=data.edge_index,
    node_idx=10,
    fidelity_threshold=0.85,
)
expl = explanations[0]
print(expl.fidelity)                       # final RDT-Fidelity reached
print(expl.selected_node_indices())        # indices into the subgraph
print(expl.selected_feature_indices())     # indices into feature columns
print(expl.subgraph_nodes)                 # mapping back to the original graph

Graph classification (MUTAG-style)

from zorrito import Zorro

# `model` is any nn.Module that implements forward(x, edge_index, batch)
explainer = Zorro(
    model=model,
    task="graph",
    noise_mode="row",                # recommended for categorical features
    noise_pool=dataset_atom_pool,    # dataset-wide pool of valid atoms
    samples=100,
    seed=42,
)
explanations = explainer.explain(
    x=graph.x,
    edge_index=graph.edge_index,
    fidelity_threshold=0.95,
)

Regression with a one-sided direction

explainer = Zorro(
    model=model,
    task="graph",
    objective="regression",
    tolerance=0.4,                   # ε for the tolerance band
    direction="down",                # which atoms prevent the prediction
                                     # from FALLING below ref - ε
    noise_mode="row",
    noise_pool=dataset_atom_pool,
    seed=42,
)
explanations = explainer.explain(
    x=molecule.x,
    edge_index=molecule.edge_index,
    fidelity_threshold=0.95,
)

🧠 Key concepts

RDT-Fidelity. For a candidate explanation $\mathcal{S} = (V_s, F_s)$:

$$\mathcal{F}(\mathcal{S}) = \mathbb{E}_{Z\sim\mathcal{N}}\big[\,\mathrm{match}(\Phi(Y_\mathcal{S}), \Phi(X))\,\big]$$

where $Y_\mathcal{S}$ is the original input with the unselected entries replaced by random draws $Z$ from a noise distribution. zorrito estimates this expectation by Monte Carlo, averaging the match outcome over a number of trials set by the samples argument.
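The estimate can be sketched in plain Python. This is an illustration under stated assumptions, not zorrito's implementation: estimate_fidelity and the toy phi model are hypothetical names, the feature matrix is a list of lists rather than a tensor, an entry is assumed to be kept only when both its node and its feature are selected, and noise is drawn per cell from the same column of the input.

```python
import random


def estimate_fidelity(phi, x, node_mask, feature_mask, samples=100, seed=0):
    """Fraction of noisy trials whose prediction matches phi(x).

    Hypothetical helper for illustration. Entry (i, j) is kept when
    node_mask[i] and feature_mask[j] are both True; every other entry
    is replaced by a value drawn from column j of x.
    """
    rng = random.Random(seed)
    reference = phi(x)
    hits = 0
    for _ in range(samples):
        perturbed = [
            [
                x[i][j]
                if (node_mask[i] and feature_mask[j])
                else x[rng.randrange(len(x))][j]
                for j in range(len(x[0]))
            ]
            for i in range(len(x))
        ]
        # Classification-style match: same predicted label as the reference.
        hits += phi(perturbed) == reference
    return hits / samples


def phi(x):
    """Toy model: class 1 if feature 0 of node 0 is positive, else class 0."""
    return int(x[0][0] > 0)


x = [[1.0, 5.0], [-1.0, 2.0], [-2.0, 3.0]]

# Keeping the one entry the toy model reads makes every trial match.
full = estimate_fidelity(phi, x, [True, False, False], [True, False])

# Keeping nothing leaves the prediction at the mercy of the noise.
empty = estimate_fidelity(phi, x, [False, False, False], [False, False])
```

In this toy setup the first selection reaches fidelity 1.0 because the only entry the model depends on is never perturbed, while the empty selection scores lower.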

Match functions (the per-trial yes/no test):

objective       direction  match condition
classification  n/a        argmax(Φ(Y_S)) == argmax(Φ(X))
regression      both       |Φ(Y_S) − Φ(X)| ≤ tolerance
regression      up         Φ(Y_S) ≤ Φ(X) + tolerance
regression      down       Φ(Y_S) ≥ Φ(X) − tolerance

For regression, direction="up" selects atoms that prevent the prediction from rising, and direction="down" selects atoms that prevent it from falling. This is often more informative than the symmetric tolerance band when the two sides have different chemistry.
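The per-trial test can be written as a single function. The sketch below is illustrative only: matches and argmax are hypothetical names, not part of zorrito's API, and the symmetric band for direction="both" is assumed to be |Φ(Y_S) − Φ(X)| ≤ tolerance, which is the intersection of the two one-sided conditions.

```python
def argmax(scores):
    """Index of the largest score (first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def matches(objective, direction, perturbed, reference, tolerance=0.0):
    """Return True if a perturbed prediction counts as a match.

    Hypothetical helper. For classification, `perturbed` and `reference`
    are lists of class scores; for regression they are floats.
    """
    if objective == "classification":
        return argmax(perturbed) == argmax(reference)
    if direction == "both":
        # Assumed symmetric band around the reference prediction.
        return abs(perturbed - reference) <= tolerance
    if direction == "up":
        # Only a rise above the band breaks the match.
        return perturbed <= reference + tolerance
    if direction == "down":
        # Only a fall below the band breaks the match.
        return perturbed >= reference - tolerance
    raise ValueError(f"unknown direction: {direction!r}")


# A prediction that dropped from 1.0 to 0.5 with tolerance 0.4:
fell_down = matches("regression", "down", 0.5, 1.0, tolerance=0.4)
fell_up = matches("regression", "up", 0.5, 1.0, tolerance=0.4)
```

Here the drop violates the "down" condition but still satisfies the "up" condition, which is why the one-sided variants can single out different parts of the input.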

Greedy algorithm. Starting from the empty selection, at each iteration zorrito evaluates the top-K candidate nodes and top-K candidate features (K is set by top_k), adds the single element with the largest fidelity gain, and stops once fidelity exceeds fidelity_threshold. It can then optionally re-run to enumerate disjoint alternative explanations (max_explanations).
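The loop described above might look roughly like the following sketch. It is not zorrito's code: greedy_select and toy_fidelity are hypothetical, the fidelity function is a deterministic stand-in for the Monte Carlo estimate, and candidates are assumed to arrive already ranked so that "top-K" is simply the first K remaining entries.

```python
def greedy_select(fidelity, nodes, features, top_k=10, threshold=0.85):
    """Add one node or feature per step until fidelity exceeds threshold.

    Hypothetical helper. `fidelity(node_set, feature_set)` returns a float;
    `nodes` and `features` are assumed to be pre-ranked candidate lists.
    """
    selected_nodes, selected_features = set(), set()
    remaining_nodes, remaining_features = list(nodes), list(features)
    current = fidelity(selected_nodes, selected_features)

    while current <= threshold and (remaining_nodes or remaining_features):
        best = None  # (score, kind, element)
        for n in remaining_nodes[:top_k]:
            score = fidelity(selected_nodes | {n}, selected_features)
            if best is None or score > best[0]:
                best = (score, "node", n)
        for f in remaining_features[:top_k]:
            score = fidelity(selected_nodes, selected_features | {f})
            if best is None or score > best[0]:
                best = (score, "feature", f)

        # Commit the single element with the highest resulting fidelity.
        current, kind, element = best
        if kind == "node":
            selected_nodes.add(element)
            remaining_nodes.remove(element)
        else:
            selected_features.add(element)
            remaining_features.remove(element)

    return selected_nodes, selected_features, current


def toy_fidelity(node_set, feature_set):
    """Stand-in score: nodes 0 and 1 and feature 2 carry all the signal."""
    return 0.3 * (0 in node_set) + 0.3 * (1 in node_set) + 0.4 * (2 in feature_set)


chosen_nodes, chosen_features, reached = greedy_select(
    toy_fidelity, nodes=[0, 1, 2, 3], features=[0, 1, 2], threshold=0.85
)
```

With this toy score the loop picks feature 2 first, then nodes 0 and 1, and stops once the score passes the threshold, leaving the irrelevant nodes and features unselected.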

Noise distribution. The original paper draws each noise cell independently from its column's empirical distribution. zorrito keeps that as the default for node tasks (continuous features), and switches to whole-row sampling by default for graph tasks (categorical / one-hot features stay valid). The pool the rows are drawn from can be set independently via noise_pool.
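The difference between the two noise modes is easiest to see on one-hot features. The following is a minimal sketch, assuming a hypothetical sample_noise helper and a pool stored as a list of rows; it is not zorrito's sampler.

```python
import random


def sample_noise(pool, num_rows, mode="column", seed=0):
    """Draw a noise matrix with `num_rows` rows from `pool`.

    Hypothetical helper. "column" draws every cell independently from its
    column of the pool; "row" copies whole rows of the pool.
    """
    rng = random.Random(seed)
    num_cols = len(pool[0])
    if mode == "row":
        return [list(rng.choice(pool)) for _ in range(num_rows)]
    if mode == "column":
        return [
            [rng.choice(pool)[j] for j in range(num_cols)]
            for _ in range(num_rows)
        ]
    raise ValueError(f"unknown noise mode: {mode!r}")


# A pool of one-hot encoded atom types.
one_hot_pool = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

row_noise = sample_noise(one_hot_pool, num_rows=50, mode="row")
column_noise = sample_noise(one_hot_pool, num_rows=50, mode="column")
```

Every row of row_noise is a valid one-hot vector because it is copied from the pool, whereas rows of column_noise can contain zero or several ones, which is the off-manifold input that whole-row sampling avoids.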

⚙️ Configuration cheat sheet

Zorro(
    model,
    task               = "node" | "graph",
    objective          = "classification" | "regression",
    select             = "both" | "nodes_only" | "features_only",
    direction          = "both" | "up" | "down",     # regression only
    noise_pool         = None | torch.Tensor,        # default: x from explain()
    noise_mode         = "column" | "row",           # default: column (node), row (graph)
    device             = "cpu" | "cuda",
    samples            = 100,
    top_k              = 10,
    tolerance          = 0.0,
    num_hops           = 2,                          # node-task subgraph radius
    seed               = None,
    log                = False,
)
explainer.explain(
    x,
    edge_index,
    node_idx           = None,                       # required for task="node"
    batch              = None,
    fidelity_threshold = 0.85,
    max_explanations   = 1,
)

explain() returns a list[Explanation], where each Explanation exposes node_mask, feature_mask, fidelity, trace, and (for node tasks) subgraph_nodes. The masks are boolean tensors; use selected_node_indices() / selected_feature_indices() to convert them to index lists.
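For node tasks, selected indices refer to the extracted subgraph, so they usually need to be translated back to the original graph. The sketch below shows that bookkeeping with plain lists standing in for tensors; selected_indices is a hypothetical helper, and it assumes that subgraph_nodes[i] holds the original ID of local node i.

```python
def selected_indices(mask):
    """Positions where a boolean mask is True (hypothetical helper)."""
    return [i for i, keep in enumerate(mask) if keep]


# Stand-ins for expl.node_mask and expl.subgraph_nodes.
node_mask = [False, True, False, True]
subgraph_nodes = [10, 3, 27, 8]  # assumed: original ID of each local node

local_ids = selected_indices(node_mask)
original_ids = [subgraph_nodes[i] for i in local_ids]
```

In this example the selected local nodes 1 and 3 correspond to original nodes 3 and 8.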

🔀 Notable departures from the original paper / reference implementation

  • fidelity_threshold corresponds to the paper's $\tau$ (default 0.85). The original code instead passes tau=0.15, which is 1 − τ_paper.
  • max_explanations replaces recursion_depth.
  • Each explanation is returned as an Explanation dataclass with boolean masks, instead of nested numpy arrays.
  • The package targets modern Python (>=3.10) and PyG (>=2.4).

📚 Citing the algorithm

The algorithm is from the original Zorro paper:

@article{funke2021zorro,
  title   = {Zorro: Valid, Sparse, and Stable Explanations in Graph Neural Networks},
  author  = {Funke, Thorben and Khosla, Megha and Rathee, Mandeep and Anand, Avishek},
  journal = {arXiv preprint arXiv:2105.08621},
  year    = {2021}
}

📄 License

MIT — see LICENSE.
