# InterGNN — Interpretable GNN-Based Framework for Drug Discovery and Candidate Screening
An interpretable Graph Neural Network framework combining state-of-the-art molecular property prediction with inherent and post-hoc explainability methods. Designed for drug discovery workflows requiring trust, transparency, and scientific insight.
## Architecture

```text
SMILES  → Standardize → Featurize → MolecularGNNEncoder ──┐
                                                          ├─ CrossAttention → TaskHead → Prediction
Protein → ProteinGraphBuilder     → TargetGNNEncoder ─────┘
                                          │
                          ┌───────────────┼───────────────┐
                          ▼               ▼               ▼
                   PrototypeLayer     MotifHead     ConceptWhitening
                    (case-based)   (substructure)    (axis-aligned)
```
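The fusion step attends each ligand atom over all target residues. As a rough, framework-agnostic sketch (plain NumPy, single head; the real `CrossAttention` module's API is not shown here), the computation is a scaled dot-product attention between atom and residue embeddings:

```python
import numpy as np

def cross_attention(atom_h, res_h):
    """Single-head cross-attention: ligand atoms attend over protein residues.

    atom_h: (n_atoms, d) atom embeddings (queries)
    res_h:  (n_res, d)   residue embeddings (keys and values)
    Returns (n_atoms, d) context vectors and (n_atoms, n_res) attention weights.
    """
    d = atom_h.shape[1]
    scores = atom_h @ res_h.T / np.sqrt(d)            # (n_atoms, n_res)
    # Numerically stable row-wise softmax over residues
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    context = weights @ res_h                         # (n_atoms, d)
    return context, weights

# Toy example: 3 atoms attending over 5 residues, 8-dim embeddings
rng = np.random.default_rng(0)
ctx, w = cross_attention(rng.normal(size=(3, 8)), rng.normal(size=(5, 8)))
```

The attention weight matrix `w` is itself interpretable: row *i* shows which residues atom *i* attends to when the affinity prediction is made.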
## Key Features
| Feature | Description |
|---|---|
| Molecular Encoder | GINEConv with edge-aware message passing and chirality features |
| Target Encoder | Multi-head GATConv for residue-level protein graphs |
| Cross-Attention Fusion | Atom-residue interaction for drug-target affinity |
| PAGE Prototypes | Case-based classification via learned prototypes |
| MAGE Motifs | Differentiable motif mask generation with Gumbel-sigmoid |
| Concept Whitening | ZCA whitening + axis-aligned concept interpretability |
| CF-GNNExplainer | Counterfactual minimal perturbation explanations |
| T-GNNExplainer | Sufficient subgraph identification |
| CIDER Diagnostics | Causal invariance testing across environments |
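The concept-whitening row refers to ZCA whitening, which decorrelates latent activations so that individual axes can be aligned with chemical concepts. A minimal NumPy sketch of the ZCA transform itself (the concept-alignment step and the package's `concept_whitening` API are not shown):

```python
import numpy as np

def zca_whiten(X, eps=1e-5):
    """ZCA-whiten activations X of shape (n_samples, n_features).

    Computes W = C^{-1/2} from the eigendecomposition of the feature
    covariance C, so the whitened data has (approximately) identity covariance.
    """
    Xc = X - X.mean(axis=0)
    C = Xc.T @ Xc / (X.shape[0] - 1)
    evals, evecs = np.linalg.eigh(C)
    W = evecs @ np.diag(1.0 / np.sqrt(evals + eps)) @ evecs.T
    return Xc @ W, W

# Correlated toy activations: 500 samples, 4 latent features
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 4)) @ rng.normal(size=(4, 4))
Xw, W = zca_whiten(X)
cov = np.cov(Xw, rowvar=False)  # ≈ identity after whitening
```

Unlike PCA whitening, ZCA keeps the whitened axes maximally close to the original ones, which is what makes per-axis concept alignment meaningful.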
## Installation

```bash
# Clone the repository
git clone https://github.com/your-org/Inter_gnn.git
cd Inter_gnn

# Install with all dependencies
pip install -e ".[vis,dev]"
```
### Requirements
- Python ≥ 3.9
- PyTorch ≥ 2.0
- PyTorch Geometric ≥ 2.4
- RDKit ≥ 2023.03
- NumPy, SciPy, Pandas, scikit-learn, matplotlib
## Quick Start

### 1. Create a Configuration
```yaml
# config.yaml
data:
  dataset_name: tox21
  split_method: scaffold
  batch_size: 32
  detect_cliffs: true
  compute_concepts: true

model:
  hidden_dim: 256
  num_mol_layers: 4
  task_type: classification
  num_tasks: 12

interpretability:
  use_prototypes: true
  num_prototypes_per_class: 5
  use_motifs: true
  num_motifs: 8
  use_concept_whitening: true

training:
  pretrain_epochs: 50
  finetune_epochs: 100
  learning_rate: 0.001
```
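The module overview below notes that `config.py` maps YAML onto a dataclass hierarchy. As an illustration of that pattern (class and field names here are hypothetical, not the actual `InterGNNConfig` API), each YAML section becomes one nested dataclass with typed defaults:

```python
from dataclasses import dataclass, field

# Hypothetical mirror of two of the YAML sections above; the real
# config class lives in inter_gnn.training.config.
@dataclass
class DataConfig:
    dataset_name: str = "tox21"
    split_method: str = "scaffold"
    batch_size: int = 32
    detect_cliffs: bool = True
    compute_concepts: bool = True

@dataclass
class ModelConfig:
    hidden_dim: int = 256
    num_mol_layers: int = 4
    task_type: str = "classification"
    num_tasks: int = 12

@dataclass
class Config:
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)

    @classmethod
    def from_dict(cls, raw):
        """Build a Config from a parsed-YAML dict, one section at a time."""
        return cls(data=DataConfig(**raw.get("data", {})),
                   model=ModelConfig(**raw.get("model", {})))

# Unspecified fields fall back to the dataclass defaults
cfg = Config.from_dict({"data": {"batch_size": 64}, "model": {"num_tasks": 2}})
```

This style gives typed attribute access (`cfg.model.hidden_dim`) and fails loudly on unknown keys, which is useful when configs evolve across releases.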
### 2. Train

```bash
inter-gnn train --config config.yaml
```
### 3. Evaluate

```bash
inter-gnn evaluate --config config.yaml --checkpoint checkpoints/finetune_best.pt
```
### 4. Generate Explanations

```bash
inter-gnn explain --config config.yaml --checkpoint model.pt --smiles "CC(=O)Oc1ccccc1C(=O)O"
```
### 5. Dashboard

```bash
inter-gnn dashboard --config config.yaml --checkpoint model.pt --output report/
```
## Python API
```python
import torch

from inter_gnn.training.config import InterGNNConfig
from inter_gnn.training.trainer import InterGNNTrainer
from inter_gnn.data.datamodule import InterGNNDataModule
from inter_gnn.data.featurize import smiles_to_graph

# Load config
config = InterGNNConfig.from_yaml("config.yaml")

# Build data
dm = InterGNNDataModule(config)
dm.prepare_data()
dm.setup()

# Train (two-phase: pretrain → finetune)
trainer = InterGNNTrainer(config)
history = trainer.fit(dm.train_dataloader(), dm.val_dataloader())

# Explain a molecule (aspirin); a single graph forms a batch of size 1
graph = smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")
batch = torch.zeros(graph.x.shape[0], dtype=torch.long)
output = trainer.model(graph.x, graph.edge_index, graph.edge_attr, batch)
importance = trainer.model.get_node_importance(
    graph.x, graph.edge_index, graph.edge_attr, batch
)
```
## Module Overview

```text
inter_gnn/
├── data/                      # Data & preprocessing
│   ├── standardize.py         # Molecule standardization (tautomer, charge, stereo)
│   ├── featurize.py           # SMILES → molecular graph (~78-dim atom, ~14-dim bond)
│   ├── protein.py             # Protein sequence → k-NN / contact graph
│   ├── concepts.py            # SMARTS concept library (~30 patterns)
│   ├── cliffs.py              # Activity cliff detection
│   ├── splits.py              # Scaffold, cold-target, temporal splits
│   ├── datasets.py            # 9 benchmark dataset loaders
│   └── datamodule.py          # DataModule wrapper
├── models/                    # Core model
│   ├── encoders.py            # GINEConv (molecule) + GATConv (protein) encoders
│   ├── attention.py           # Cross-attention fusion + bilinear alternative
│   ├── task_heads.py          # Classification + regression heads
│   └── core_model.py          # Unified InterGNN model
├── interpretability/          # Intrinsic interpretability
│   ├── prototypes.py          # PAGE-inspired prototype layer
│   ├── motifs.py              # MAGE-inspired motif generator
│   ├── concept_whitening.py   # ZCA whitening + concept alignment
│   └── stability.py           # Explanation stability regularizer
├── explainers/                # Post-hoc explanations
│   ├── cf_explainer.py        # CF-GNNExplainer (counterfactual)
│   ├── t_explainer.py         # T-GNNExplainer (sufficient subgraph)
│   └── cider.py               # CIDER causal invariance diagnostics
├── training/                  # Training pipeline
│   ├── losses.py              # Combined multi-objective loss
│   ├── trainer.py             # Two-phase trainer (pretrain + finetune)
│   ├── callbacks.py           # EarlyStopping, checkpointing, monitoring
│   └── config.py              # YAML config with dataclass hierarchy
├── evaluation/                # Evaluation metrics
│   ├── predictive.py          # ROC-AUC, PR-AUC, RMSE, CI, etc.
│   ├── faithfulness.py        # Deletion/insertion AUC, sufficiency/necessity
│   ├── stability_metrics.py   # Jaccard stability, cliff consistency
│   ├── chemical_validity.py   # Valence checks, SMARTS match rates
│   ├── causal.py              # Invariance violation, environment alignment
│   └── statistical.py         # Paired bootstrap, randomization tests
├── visualization/             # Visualization tools
│   ├── molecule_viz.py        # Atom/bond saliency rendering
│   ├── prototype_viz.py       # Prototype gallery
│   ├── motif_viz.py           # Motif activation heatmaps
│   ├── concept_viz.py         # Concept activation bars
│   ├── counterfactual_viz.py  # Counterfactual edit display
│   └── dashboard.py           # HTML batch-export dashboard
└── cli.py                     # Command-line interface
```
## Supported Datasets
| Dataset | Type | Tasks | Source |
|---|---|---|---|
| MUTAG | Classification | 1 | TUDataset |
| Tox21 | Classification | 12 | MoleculeNet |
| ClinTox | Classification | 2 | MoleculeNet |
| QM9 | Regression | 19 | MoleculeNet |
| Davis | DTA Regression | 1 | TDC |
| KIBA | DTA Regression | 1 | TDC |
| BindingDB | DTA Regression | 1 | TDC |
| SIDER | Classification | 27 | MoleculeNet |
| SynLethDB | Classification | 1 | Custom |
## Two-Phase Training

- **Pre-training** — trains the encoders and task head with the prediction loss only
- **Joint fine-tuning** — attaches the interpretability modules and trains all losses:
  - `L_pred`: task prediction (BCE/MSE)
  - `L_pull` / `L_push` / `L_div`: prototype losses
  - `L_sparsity` / `L_conn`: motif losses
  - `L_align` / `L_decorr`: concept-whitening losses
  - `L_stability`: explanation stability
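Conceptually, the fine-tuning objective is a weighted sum of these terms. The sketch below shows that combination in plain Python; the weight values and names are illustrative, not the defaults in `losses.py`:

```python
# Illustrative loss weights — not the package's actual defaults.
DEFAULT_WEIGHTS = {
    "pred": 1.0, "pull": 0.1, "push": 0.1, "div": 0.05,
    "sparsity": 0.01, "conn": 0.01, "align": 0.1,
    "decorr": 0.05, "stability": 0.1,
}

def combined_loss(terms, weights=DEFAULT_WEIGHTS):
    """Combine per-objective loss values into one scalar.

    terms: {name: value}. Absent terms contribute nothing, so the same
    function covers pre-training (prediction loss only) and fine-tuning
    (all interpretability losses attached).
    """
    return sum(weights[name] * value for name, value in terms.items())

pretrain_loss = combined_loss({"pred": 0.7})  # phase 1: prediction only
finetune_loss = combined_loss({"pred": 0.7, "pull": 0.3, "sparsity": 2.0})
```

Keeping `L_pred`'s weight dominant preserves predictive accuracy while the smaller regularization weights shape the prototypes, motifs, and concept axes.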
## Citation

```bibtex
@software{inter_gnn2025,
  title={InterGNN: Interpretable Graph Neural Network for Drug Discovery},
  year={2025},
}
```
## License
MIT License
## File details

### inter_gnn-0.1.0.tar.gz (source distribution)

- Size: 69.9 kB
- Uploaded via: twine/6.2.0, CPython/3.11.3 (Trusted Publishing: no)

| Algorithm | Hash digest |
|---|---|
| SHA256 | `560d7a08f282c14b9f3df2d103e1865bf90edf3f748467228fbe936552a637f0` |
| MD5 | `463e8c7e4c2446076b3c939e1be144fc` |
| BLAKE2b-256 | `902d55c6d08bf2c9782db2e17ee67d14b23f8bb251d9be1f6cc2ad25d4a0ef28` |

### inter_gnn-0.1.0-py3-none-any.whl (built distribution, Python 3)

- Size: 87.6 kB
- Uploaded via: twine/6.2.0, CPython/3.11.3 (Trusted Publishing: no)

| Algorithm | Hash digest |
|---|---|
| SHA256 | `9953c2c8923b35cb07cb3cc2d04faff290983033789587a830d381d2f111fa97` |
| MD5 | `adaace7731c967054ac71e6b9df1e9a3` |
| BLAKE2b-256 | `e17db2d872972bc389971bc39ff53d78b2124c313f62003a8a738ca3e26aa932` |