A lean reimplementation of Zorro: valid, sparse, stable explanations for PyTorch Geometric GNNs, with graph-level and regression support.

zorrito

A lean, modernized reimplementation of Zorro — valid, sparse, stable explanations for PyTorch Geometric GNNs.

zorrito keeps the greedy core of the paper and extends it with:

  • Modern Python and PyTorch Geometric support
  • Graph-level explanations, not just node-level
  • Regression support via a tolerance band, with optional one-sided (direction-aware) match functions
  • Configurable noise — whole-row sampling and dataset-wide pools, so perturbed inputs stay on the data manifold (matters a lot for categorical / one-hot features like atom types)

Install

pip install zorrito

zorrito depends on torch>=2.0 and torch-geometric>=2.4. Install those according to the official PyG instructions for your CUDA/PyTorch combination first.

For development:

git clone https://github.com/your-user/zorrito.git
cd zorrito
uv venv --python 3.11
uv pip install -e ".[dev]"
pytest tests/

See DEVELOP.md for the release / publishing checklist.

Quick start

Node classification (Cora-style)

import torch
from zorrito import Zorro

# `model` is any nn.Module that implements forward(x, edge_index)
explainer = Zorro(
    model=model,
    task="node",
    samples=100,
    top_k=10,
    seed=42,
)
explanations = explainer.explain(
    x=data.x,
    edge_index=data.edge_index,
    node_idx=10,
    fidelity_threshold=0.85,
)
expl = explanations[0]
print(expl.fidelity)                       # final RDT-Fidelity reached
print(expl.selected_node_indices())        # indices into the subgraph
print(expl.selected_feature_indices())     # indices into feature columns
print(expl.subgraph_nodes)                 # mapping back to the original graph

Graph classification (MUTAG-style)

from zorrito import Zorro

# `model` is any nn.Module that implements forward(x, edge_index, batch)
explainer = Zorro(
    model=model,
    task="graph",
    noise_mode="row",                # recommended for categorical features
    noise_pool=dataset_atom_pool,    # dataset-wide pool of valid atoms
    samples=100,
    seed=42,
)
explanations = explainer.explain(
    x=graph.x,
    edge_index=graph.edge_index,
    fidelity_threshold=0.95,
)

Regression with a one-sided direction

explainer = Zorro(
    model=model,
    task="graph",
    objective="regression",
    tolerance=0.4,                   # ε for the tolerance band
    direction="down",                # which atoms prevent the prediction
                                     # from FALLING below ref - ε
    noise_mode="row",
    noise_pool=dataset_atom_pool,
    seed=42,
)
explanations = explainer.explain(
    x=molecule.x,
    edge_index=molecule.edge_index,
    fidelity_threshold=0.95,
)

Key concepts

RDT-Fidelity. For a candidate explanation $\mathcal{S} = (V_s, F_s)$:

$$\mathcal{F}(\mathcal{S}) = \mathbb{E}_{Z\sim\mathcal{N}}\big[\,\mathrm{match}(\Phi(Y_\mathcal{S}), \Phi(X))\,\big]$$

where $Y_\mathcal{S}$ is the original input $X$ with the unselected entries replaced by random draws $Z$ from a noise distribution $\mathcal{N}$. zorrito estimates the expectation by averaging the match outcome over a number of Monte Carlo trials set by the samples parameter.
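
A minimal, framework-free sketch of that Monte Carlo estimate for a classifier, for illustration only (it is not zorrito's code). It assumes the model is any callable mapping a node-feature matrix (a list of rows) to the class scores being explained, leaves the graph structure out, and assumes an entry is kept only when both its node and its feature are selected; every other entry is drawn from its column's empirical values. The helper names perturb, argmax and rdt_fidelity are hypothetical.

```python
import random
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def perturb(x: Matrix, node_mask: Sequence[bool], feature_mask: Sequence[bool],
            rng: random.Random) -> Matrix:
    """Build Y_S: keep x[i][j] only if node i AND feature j are selected
    (an assumption of this sketch); otherwise draw a value from column j."""
    columns = list(zip(*x))  # empirical distribution of each feature column
    return [
        [x[i][j] if node_mask[i] and feature_mask[j] else rng.choice(columns[j])
         for j in range(len(x[i]))]
        for i in range(len(x))
    ]


def argmax(scores: Sequence[float]) -> int:
    return max(range(len(scores)), key=lambda k: scores[k])


def rdt_fidelity(model: Callable[[Matrix], Sequence[float]], x: Matrix,
                 node_mask: Sequence[bool], feature_mask: Sequence[bool],
                 samples: int = 100, seed: int = 0) -> float:
    """Fraction of noisy trials whose predicted class equals the clean prediction."""
    rng = random.Random(seed)
    reference = argmax(model(x))
    hits = sum(
        argmax(model(perturb(x, node_mask, feature_mask, rng))) == reference
        for _ in range(samples)
    )
    return hits / samples


# Toy "model" whose prediction depends only on node 0, feature 0.
def toy(m: Matrix) -> List[float]:
    return [m[0][0], 1.0 - m[0][0]]


x = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
print(rdt_fidelity(toy, x, [True, False, False], [True, False]))  # → 1.0
```

Selecting node 0 and feature 0 protects the only entry the toy model reads, so every trial matches; selecting nothing lets that entry be resampled and the estimate drops below 1.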

Match functions (the per-trial yes/no test):

| objective | direction | match condition |
|---|---|---|
| classification | n/a | argmax(Φ(Y_S)) == argmax(Φ(X)) |
| regression | both | abs(Φ(Y_S) − Φ(X)) ≤ tolerance |
| regression | up | Φ(Y_S) ≤ Φ(X) + tolerance |
| regression | down | Φ(Y_S) ≥ Φ(X) − tolerance |

For regression, direction="up" selects atoms that prevent the prediction from rising, and direction="down" selects atoms that prevent it from falling. When the atoms that push a prediction up differ from those that pull it down, a one-sided direction is often more informative than the symmetric tolerance band.
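
The match conditions in the table can be written as plain predicates. This is an illustrative sketch, not zorrito's internals; the function names are hypothetical, and regression outputs are assumed to be scalars.

```python
from typing import Sequence


def match_classification(pred: Sequence[float], ref: Sequence[float]) -> bool:
    """Match when the perturbed and original inputs predict the same class."""
    def argmax(v: Sequence[float]) -> int:
        return max(range(len(v)), key=lambda k: v[k])
    return argmax(pred) == argmax(ref)


def match_regression(pred: float, ref: float, tolerance: float = 0.0,
                     direction: str = "both") -> bool:
    """Tolerance-band match; 'up'/'down' only penalise one side of the band."""
    if direction == "both":
        return abs(pred - ref) <= tolerance
    if direction == "up":    # fails only if the prediction rises past ref + tol
        return pred <= ref + tolerance
    if direction == "down":  # fails only if the prediction falls below ref - tol
        return pred >= ref - tolerance
    raise ValueError(f"unknown direction: {direction!r}")


print(match_regression(0.0, 1.0, tolerance=0.4, direction="up"))    # → True
print(match_regression(0.0, 1.0, tolerance=0.4, direction="down"))  # → False
```

A large drop passes the "up" test but fails the "down" test, which is why direction="down" ends up selecting the atoms that keep the prediction from falling.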

Greedy algorithm. Starting from the empty selection, at each iteration zorrito evaluates the top-K candidate nodes and top-K candidate features (the "k" in top_k), adds the single element with the largest fidelity gain, and stops when fidelity exceeds fidelity_threshold. It can then optionally re-run the search to enumerate disjoint alternative explanations, up to max_explanations in total.
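
A simplified sketch of the greedy loop, assuming a fidelity callable that scores any set of selected elements. How the top-K candidates are chosen is an assumption here (candidates are ranked once by their fidelity when selected alone), the loop stops once fidelity reaches the threshold, and the search for disjoint alternative explanations is omitted. greedy_explain and the ("node", i) / ("feature", j) element encoding are hypothetical.

```python
from typing import Callable, FrozenSet, List, Tuple

Element = Tuple[str, int]  # ("node", i) or ("feature", j)


def greedy_explain(
    fidelity: Callable[[FrozenSet[Element]], float],
    num_nodes: int,
    num_features: int,
    threshold: float = 0.85,
    top_k: int = 10,
) -> Tuple[FrozenSet[Element], List[float]]:
    nodes = [("node", i) for i in range(num_nodes)]
    feats = [("feature", j) for j in range(num_features)]
    # Assumption: rank each kind once by the fidelity of the element alone
    # (stable sort, so ties keep index order).
    nodes.sort(key=lambda e: fidelity(frozenset([e])), reverse=True)
    feats.sort(key=lambda e: fidelity(frozenset([e])), reverse=True)

    selected: FrozenSet[Element] = frozenset()
    current = fidelity(selected)
    trace = [current]
    while current < threshold:
        # Only the top_k unselected nodes and top_k unselected features compete.
        pool = [e for e in nodes if e not in selected][:top_k]
        pool += [e for e in feats if e not in selected][:top_k]
        if not pool:
            break  # everything is selected; threshold unreachable
        best = max(pool, key=lambda e: fidelity(selected | {e}))
        selected = selected | {best}
        current = fidelity(selected)
        trace.append(current)
    return selected, trace


# Toy fidelity: node 2 and feature 1 each contribute half.
def toy_fidelity(s: FrozenSet[Element]) -> float:
    return ((("node", 2) in s) + (("feature", 1) in s)) / 2


print(greedy_explain(toy_fidelity, num_nodes=4, num_features=3, top_k=2)[1])
# → [0.0, 0.5, 1.0]
```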

Noise distribution. The original paper draws each noise cell independently from its column's empirical distribution. zorrito keeps that as the default for node tasks (continuous features), and switches to whole-row sampling by default for graph tasks, so categorical / one-hot features stay valid. The pool that noise is drawn from (by default, the x passed to explain()) can be set independently via noise_pool, for example to a dataset-wide pool of valid rows.
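
The difference between the two modes can be shown with a small sketch. It is not zorrito's implementation; draw_noise is a hypothetical helper whose mode and pool arguments only mirror the noise_mode and noise_pool options in spirit.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]


def draw_noise(x: Matrix, mode: str = "column", pool: Optional[Matrix] = None,
               rng: Optional[random.Random] = None) -> Matrix:
    """Return a noise matrix Z with one row per row of x."""
    if rng is None:
        rng = random.Random()
    if pool is None:
        pool = x  # default pool: the input itself
    if mode == "column":
        # Each cell is drawn independently from its column's empirical values,
        # so a one-hot row can come out with zero or several hot entries.
        cols = list(zip(*pool))
        return [[rng.choice(cols[j]) for j in range(len(row))] for row in x]
    if mode == "row":
        # Each node receives a complete row copied from the pool, so every
        # noise row is a row that actually occurs in the pool.
        return [list(rng.choice(pool)) for _ in x]
    raise ValueError(f"unknown mode: {mode!r}")


onehot = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
print(draw_noise(onehot, mode="row", pool=[[0, 0, 1]]))  # → [[0, 0, 1], [0, 0, 1], [0, 0, 1]]
```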

Configuration cheat sheet

Zorro(
    model,
    task               = "node" | "graph",
    objective          = "classification" | "regression",
    select             = "both" | "nodes_only" | "features_only",
    direction          = "both" | "up" | "down",     # regression only
    noise_pool         = None | torch.Tensor,        # default: x from explain()
    noise_mode         = "column" | "row",           # default: column (node), row (graph)
    device             = "cpu" | "cuda",
    samples            = 100,
    top_k              = 10,
    tolerance          = 0.0,
    num_hops           = 2,                          # node-task subgraph radius
    seed               = None,
    log                = False,
)
explainer.explain(
    x,
    edge_index,
    node_idx           = None,                       # required for task="node"
    batch              = None,
    fidelity_threshold = 0.85,
    max_explanations   = 1,
)

The return is a list[Explanation], where each Explanation exposes node_mask, feature_mask, fidelity, trace, and (for node tasks) subgraph_nodes. The masks are boolean tensors; use selected_node_indices() / selected_feature_indices() to convert.
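
For node tasks, selected nodes are indexed within the extracted subgraph. A plain-Python sketch of mapping them back to the original graph, assuming subgraph_nodes[i] holds the original-graph id of subgraph node i; both helpers are hypothetical. With PyTorch tensors the same two steps would be node_mask.nonzero(as_tuple=True)[0] and subgraph_nodes[node_mask].

```python
from typing import List, Sequence


def mask_to_indices(mask: Sequence[bool]) -> List[int]:
    """Positions where a boolean mask is True."""
    return [i for i, keep in enumerate(mask) if keep]


def selected_original_nodes(node_mask: Sequence[bool],
                            subgraph_nodes: Sequence[int]) -> List[int]:
    """Assumption: subgraph_nodes[i] is the original id of subgraph node i."""
    return [subgraph_nodes[i] for i in mask_to_indices(node_mask)]


print(selected_original_nodes([False, True, True, False], [10, 4, 7, 22]))  # → [4, 7]
```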

Notable departures from the original paper / reference implementation

  • fidelity_threshold=0.85 is the value $\tau$ from the paper; the original code expresses the same threshold as tau=0.15 (i.e. 1 − $\tau$).
  • max_explanations replaces recursion_depth.
  • Explanations are returned as a single Explanation dataclass with boolean masks, instead of nested numpy arrays.
  • The package targets modern Python (>=3.10) and PyG (>=2.4).

Citing the algorithm

The algorithm is from the original Zorro paper:

@article{funke2021zorro,
  title   = {Zorro: Valid, Sparse, and Stable Explanations in Graph Neural Networks},
  author  = {Funke, Thorben and Khosla, Megha and Rathee, Mandeep and Anand, Avishek},
  journal = {arXiv preprint arXiv:2105.08621},
  year    = {2021}
}

License

MIT — see LICENSE.

Download files

Source Distribution

zorrito-0.2.0.tar.gz (18.3 kB)

Uploaded Source

Built Distribution

zorrito-0.2.0-py3-none-any.whl (12.1 kB)

Uploaded Python 3

File details

Details for the file zorrito-0.2.0.tar.gz.

File metadata

  • Download URL: zorrito-0.2.0.tar.gz
  • Upload date:
  • Size: 18.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.4 (publish, Ubuntu 24.04)

File hashes

Hashes for zorrito-0.2.0.tar.gz
Algorithm Hash digest
SHA256 a1ccd24d48c1ca1b987039ea2cd4a7e611e8201d1f6213c258771b64394f1a5d
MD5 e2a1effa3926b53988413a9a3f122143
BLAKE2b-256 badc857d788c473780ebcb84e9d18e5bf348e970c0ba3958fb31474ccf6f0f93

File details

Details for the file zorrito-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: zorrito-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 12.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.4 (publish, Ubuntu 24.04)

File hashes

Hashes for zorrito-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 9604e8c36d0f1aaba1bedd567fd802f6af04e0196ea09b00484e04947ddde260
MD5 1732b47bb022da01efd194d39d5aee9f
BLAKE2b-256 f7229a655f07774430dc6ceb4d25b8344bd64539c6ffa82df3bb080129f4f8b8
