A robust SHAP explainer wrapper for PyTorch Geometric models.

PyG-Captum-SHAP

A robust wrapper bridging the Euclidean sampling mechanisms of Captum with the non-Euclidean batching of PyTorch Geometric (PyG).

Designed specifically for complex molecular QSAR modelling, this library resolves the dimensional-mismatch errors (e.g. "Expected size 2, got 10") that arise when Captum is applied to Graph Neural Networks, whilst enabling simultaneous attribution across multiple graph modalities.

Beyond Native Explainability: The Multi-Modal Bottleneck

While PyTorch Geometric provides native explainability utilities such as to_captum_model and the torch_geometric.explain module, these tools are built around node- and edge-level masks over the graph topology and have no slot for attributing graph-level inputs.

Native PyG explainers rely on fixed mask types (mask_type='node', 'edge', or 'node_and_edge' in to_captum_model) and pass any auxiliary graph-level tensors through additional_forward_args. Captum computes attributions only for the tensors in its primary inputs tuple; additional_forward_args are forwarded unchanged on every perturbed call. As a result, high-dimensional global descriptors (e.g. MolFormer embeddings, RDKit descriptors, or topological signatures) receive no attribution at all when native PyG utilities are used.
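Why tensors outside the attributed inputs never receive credit can be seen in a toy permutation-sampling Shapley estimator. This is a simplified plain-Python stand-in in the spirit of Captum's ShapleyValueSampling, not Captum's implementation; `permutation_shapley`, `toy_model`, and the `global_descriptor` key are hypothetical names used only for illustration.

```python
import random
from typing import Callable, Sequence


def permutation_shapley(
    forward: Callable[[list, dict], float],
    inputs: Sequence[float],
    baseline: Sequence[float],
    extra_args: dict,
    n_samples: int = 25,
    seed: int = 0,
) -> list:
    """Monte Carlo Shapley estimate for every entry of `inputs`.

    `extra_args` is passed to every forward call unchanged (like Captum's
    additional_forward_args), so it is never perturbed and never attributed.
    """
    rng = random.Random(seed)
    totals = [0.0] * len(inputs)
    for _ in range(n_samples):
        order = list(range(len(inputs)))
        rng.shuffle(order)                     # random feature ordering
        current = list(baseline)               # start from the baseline
        prev_out = forward(current, extra_args)
        for i in order:
            current[i] = inputs[i]             # reveal feature i
            out = forward(current, extra_args)
            totals[i] += out - prev_out        # marginal contribution of feature i
            prev_out = out
    return [t / n_samples for t in totals]


def toy_model(features, extra):
    # Depends on two attributed features AND on a global descriptor that is
    # passed outside the attributed inputs.
    return 2.0 * features[0] + 1.0 * features[1] + 3.0 * extra["global_descriptor"]


phi = permutation_shapley(toy_model, [1.0, 1.0], [0.0, 0.0], {"global_descriptor": 1.0})
print(phi)  # [2.0, 1.0]; the descriptor's 3.0 is absorbed into the baseline output
```

The descriptor clearly drives the output, yet it never appears in `phi`: its effect is hidden inside the baseline prediction because it is never switched off.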

PyG-Captum-SHAP removes this limitation. It dynamically packs node features, edge features, and global features into the primary attribution tuple, whilst keeping the edge_index tensor shielded in a dictionary outside that tuple so it is never perturbed. This bypasses the native PyG constraints and enables simultaneous multi-modal SHAP attribution for architectures that consume global descriptors alongside the graph.
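A minimal sketch of the shielding idea, assuming a closure-over-dictionary design (the package's internals are not shown on this page, so `make_shielded_forward` and `toy_gnn` are hypothetical). Only node features, edge features, and the global feature are arguments of the returned function, so an attribution loop can perturb them freely while edge_index stays fixed. Python lists stand in for tensors.

```python
from typing import Callable, Sequence

Edge = tuple[int, int]


def make_shielded_forward(model: Callable, edge_index: Sequence[Edge]) -> Callable:
    """Return forward(x, edge_attr, u) with edge_index held out of reach.

    Hypothetical sketch: the graph structure lives in a dictionary captured
    by the closure, so perturbing the positional inputs cannot modify it.
    """
    shielded = {"edge_index": tuple(edge_index)}  # immutable copy of the structure

    def forward(x, edge_attr, u):
        return model(x, shielded["edge_index"], edge_attr, u)

    return forward


def toy_gnn(x, edge_index, edge_attr, u):
    # One round of "message passing": every edge sends x[src] * weight,
    # messages are summed, and the global feature u is added.
    messages = sum(x[src] * w for (src, _dst), w in zip(edge_index, edge_attr))
    return messages + u


fwd = make_shielded_forward(toy_gnn, [(0, 1), (1, 2)])
inputs = ([1.0, 2.0, 3.0], [0.5, 0.5], 1.0)      # (nodes, edges, global): all attributable
print(fwd(*inputs))                               # 2.5
print(fwd([0.0, 0.0, 0.0], [0.5, 0.5], 1.0))      # 1.0: nodes ablated, structure intact
```

With real tensors, the three-element tuple would be handed to a Captum method as its inputs, e.g. `ShapleyValueSampling(fwd).attribute(inputs=(x, edge_attr, u), ...)`; Captum unpacks the tuple into positional arguments, which is why the forward signature lists exactly those three.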

Key Features

  1. Dictionary-Shielded Wrapper: Protects structural tensors (edge_index) from corruption during Captum's internal feature perturbation and sampling phases.
  2. Multi-Input Support: Generates mathematically consistent attributions for Nodes (Atoms), Edges (Bonds), and Global Molecular Features simultaneously.
  3. Automatic Reconstruction: Performs on-the-fly reconstruction of block-diagonal graph batches for Captum's internal forward passes.
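Block-diagonal batching is the layout PyG uses when merging graphs: node rows are concatenated, each graph's edge indices are shifted by the number of nodes that precede it, and a batch vector maps every node to its graph. When Captum evaluates several perturbed copies of one molecule in a single forward pass (for example with perturbations_per_eval > 1), a reconstruction step has to redo this offset arithmetic for every copy. The sketch below shows that arithmetic in plain Python; `collate_graphs` is a hypothetical helper, not the package's API.

```python
def collate_graphs(graphs):
    """Merge (x, edge_index) graphs into one block-diagonal batch.

    Returns concatenated node features, offset edges, and a batch vector
    giving the graph id of every node.
    """
    x_all, edges_all, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        x_all.extend(x)
        # Shift this graph's node ids past all nodes already in the batch.
        edges_all.extend((src + offset, dst + offset) for src, dst in edge_index)
        batch.extend([graph_id] * len(x))
        offset += len(x)
    return x_all, edges_all, batch


g1 = ([1.0, 2.0], [(0, 1), (1, 0)])       # 2-node graph
g2 = ([3.0, 4.0, 5.0], [(0, 2)])          # 3-node graph
x, edges, batch = collate_graphs([g1, g2])
print(edges)  # [(0, 1), (1, 0), (2, 4)]
print(batch)  # [0, 0, 1, 1, 1]
```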

Installation

pip install pyg-captum-shap

Quick Start

from pyg_captum_shap import compute_shap_values

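# Extract attributions for a specific molecule and task
results = compute_shap_values(
    model=your_trained_model,          # placeholder: your trained PyG model
    target_graph=molecule_graph_data,  # placeholder: the single molecular graph to explain
    target_task=0,                     # index of the model output (task) to explain
    n_samples=25                       # sampling budget; more samples give more stable estimates
)

# Access node, edge, and global SHAP values
node_importance = results['nodes']          # Shape: [N, F] (atoms x node features)
edge_importance = results.get('edges')      # Optional: None if not returned
global_importance = results.get('global')   # Optional: None if not returned

# node_importance now holds an importance score for every feature of every atom in the graph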
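Because results['nodes'] is laid out as [N, F], a common next step is to collapse it into one score per atom. Shapley attributions sum to the difference between the model output and the baseline output, so summing an atom's feature attributions gives that atom's share of that difference. A sketch, assuming the matrix has been converted to nested Python lists (for a torch tensor, `node_importance.tolist()` does this); `rank_atoms` is a hypothetical helper, not part of the package.

```python
def rank_atoms(node_importance, top_k=3):
    """Collapse an [N, F] attribution matrix into one score per atom and rank them."""
    scores = [sum(row) for row in node_importance]          # one signed score per atom
    ranked = sorted(range(len(scores)), key=lambda i: abs(scores[i]), reverse=True)
    return [(atom, scores[atom]) for atom in ranked[:top_k]]


# Hypothetical 3-atom x 2-feature attribution matrix
example = [[0.25, -0.125], [0.5, 0.25], [-0.5, -0.125]]
print(rank_atoms(example, top_k=2))  # [(1, 0.75), (2, -0.625)]
```

Ranking by absolute value keeps atoms that strongly push the prediction down, which a plain descending sort would bury.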

License

Distributed under the MIT License. Built on top of Captum, the model interpretability library for PyTorch.

Download files

Source Distribution

pyg_captum_shap-0.1.6.tar.gz (5.0 kB)

Uploaded Source

Built Distribution

pyg_captum_shap-0.1.6-py3-none-any.whl (5.7 kB)

Uploaded Python 3

File details

Details for the file pyg_captum_shap-0.1.6.tar.gz.

File metadata

  • Download URL: pyg_captum_shap-0.1.6.tar.gz
  • Upload date:
  • Size: 5.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.7

File hashes

Hashes for pyg_captum_shap-0.1.6.tar.gz
Algorithm Hash digest
SHA256 ccb7bec8dadf999beab94b53f80672fb69f9a505d4a3066a9e5fa439dbf66a5e
MD5 0ee2ce0f1f1b1dd5b18fa614ffe04473
BLAKE2b-256 b68e823b3912f92202e91b346e3e5d90a93e38a87da74171b07772825a58690e

File details

Details for the file pyg_captum_shap-0.1.6-py3-none-any.whl.

File metadata

File hashes

Hashes for pyg_captum_shap-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 ec0cbd5b5bea9328d2aa2539bae66e9821a3b3d1a22959b31e2f9d60eb110587
MD5 4d4faef444ef598cea5ed25090debd0c
BLAKE2b-256 0202f48cef89cd4dc0f115d466c66fad1d28843fbb9139c4a74bf6da88bfd41c
