InterGNN — Interpretable GNN-Based Framework for Drug Discovery and Candidate Screening

Python 3.9+ · PyTorch · License: MIT

An interpretable Graph Neural Network framework combining state-of-the-art molecular property prediction with intrinsic and post-hoc explainability methods. Designed for drug-discovery workflows that require trust, transparency, and scientific insight.


Architecture

SMILES → Standardize → Featurize → MolecularGNNEncoder ──┐
                                                          ├─ CrossAttention → TaskHead → Prediction
Protein → ProteinGraphBuilder → TargetGNNEncoder ─────────┘
                                    │
                    ┌───────────────┼───────────────┐
                    ▼               ▼               ▼
              PrototypeLayer   MotifHead    ConceptWhitening
              (case-based)    (substructure)  (axis-aligned)
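The fusion stage is scaled dot-product attention from atoms to residues: each atom embedding queries the residue embeddings and receives a context vector. A minimal NumPy sketch (single head, no learned projections; an illustration of the mechanism, not the actual `attention.py` implementation):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(atom_h, res_h):
    """Atoms attend over protein residues (single head, no projections)."""
    d = atom_h.shape[-1]
    scores = atom_h @ res_h.T / np.sqrt(d)   # (num_atoms, num_residues)
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ res_h, weights          # per-atom context vectors

rng = np.random.default_rng(0)
atoms = rng.normal(size=(5, 8))       # 5 atoms, hidden dim 8
residues = rng.normal(size=(12, 8))   # 12 residues, hidden dim 8
ctx, w = cross_attention(atoms, residues)
```

The attention weights `w` double as an atom-residue interaction map, which is what makes this fusion step directly inspectable for drug-target affinity.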

Key Features

| Feature | Description |
|---------|-------------|
| Molecular Encoder | GINEConv with edge-aware message passing and chirality features |
| Target Encoder | Multi-head GATConv for residue-level protein graphs |
| Cross-Attention Fusion | Atom-residue interaction for drug-target affinity |
| PAGE Prototypes | Case-based classification via learned prototypes |
| MAGE Motifs | Differentiable motif mask generation with Gumbel-sigmoid |
| Concept Whitening | ZCA whitening + axis-aligned concept interpretability |
| CF-GNNExplainer | Counterfactual minimal perturbation explanations |
| T-GNNExplainer | Sufficient subgraph identification |
| CIDER Diagnostics | Causal invariance testing across environments |
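The MAGE-style motif head keeps its node masks differentiable by relaxing Bernoulli sampling with Gumbel-sigmoid (binary-Concrete) noise. A hedged NumPy sketch of just the sampling step (the `tau` value and per-atom logits are illustrative, not taken from `motifs.py`):

```python
import numpy as np

def gumbel_sigmoid(logits, tau=0.5, rng=None):
    """Differentiable approximately-binary mask via the binary-Concrete relaxation."""
    rng = rng or np.random.default_rng()
    u = rng.uniform(1e-6, 1 - 1e-6, size=logits.shape)
    noise = np.log(u) - np.log1p(-u)                 # Logistic(0, 1) sample
    return 1.0 / (1.0 + np.exp(-(logits + noise) / tau))

rng = np.random.default_rng(0)
node_logits = rng.normal(size=10)                    # one logit per atom
mask = gumbel_sigmoid(node_logits, tau=0.5, rng=rng)
```

Lower temperatures push the mask values toward 0/1, so at inference the soft mask can be thresholded into a discrete substructure while training remains end-to-end differentiable.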

Installation

# Clone the repository
git clone https://github.com/your-org/Inter_gnn.git
cd Inter_gnn

# Install with all dependencies
pip install -e ".[vis,dev]"

Requirements

  • Python ≥ 3.9
  • PyTorch ≥ 2.0
  • PyTorch Geometric ≥ 2.4
  • RDKit ≥ 2023.03
  • NumPy, SciPy, Pandas, scikit-learn, matplotlib

Quick Start

1. Create a Configuration

# config.yaml
data:
  dataset_name: tox21
  split_method: scaffold
  batch_size: 32
  detect_cliffs: true
  compute_concepts: true

model:
  hidden_dim: 256
  num_mol_layers: 4
  task_type: classification
  num_tasks: 12

interpretability:
  use_prototypes: true
  num_prototypes_per_class: 5
  use_motifs: true
  num_motifs: 8
  use_concept_whitening: true

training:
  pretrain_epochs: 50
  finetune_epochs: 100
  learning_rate: 0.001

2. Train

inter-gnn train --config config.yaml

3. Evaluate

inter-gnn evaluate --config config.yaml --checkpoint checkpoints/finetune_best.pt

4. Generate Explanations

inter-gnn explain --config config.yaml --checkpoint model.pt --smiles "CC(=O)Oc1ccccc1C(=O)O"

5. Dashboard

inter-gnn dashboard --config config.yaml --checkpoint model.pt --output report/

Python API

from inter_gnn.training.config import InterGNNConfig
from inter_gnn.training.trainer import InterGNNTrainer
from inter_gnn.data.datamodule import InterGNNDataModule

# Load config
config = InterGNNConfig.from_yaml("config.yaml")

# Build data
dm = InterGNNDataModule(config)
dm.prepare_data()
dm.setup()

# Train (two-phase: pretrain → finetune)
trainer = InterGNNTrainer(config)
history = trainer.fit(dm.train_dataloader(), dm.val_dataloader())

# Explain a molecule
from inter_gnn.data.featurize import smiles_to_graph
import torch

graph = smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
batch = torch.zeros(graph.x.shape[0], dtype=torch.long)  # single-graph batch: all atoms in graph 0
output = trainer.model(graph.x, graph.edge_index, graph.edge_attr, batch)

importance = trainer.model.get_node_importance(
    graph.x, graph.edge_index, graph.edge_attr, batch
)

Module Overview

inter_gnn/
├── data/                    # Data & Preprocessing
│   ├── standardize.py       #   Molecule standardization (tautomer, charge, stereo)
│   ├── featurize.py         #   SMILES → molecular graph (~78-dim atom, ~14-dim bond)
│   ├── protein.py           #   Protein sequence → k-NN / contact graph
│   ├── concepts.py          #   SMARTS concept library (~30 patterns)
│   ├── cliffs.py            #   Activity cliff detection
│   ├── splits.py            #   Scaffold, cold-target, temporal splits
│   ├── datasets.py          #   9 benchmark dataset loaders
│   └── datamodule.py        #   DataModule wrapper
├── models/                  # Core Model
│   ├── encoders.py          #   GINEConv (molecule) + GATConv (protein) encoders
│   ├── attention.py         #   Cross-attention fusion + bilinear alternative
│   ├── task_heads.py        #   Classification + regression heads
│   └── core_model.py        #   Unified InterGNN model
├── interpretability/        # Intrinsic Interpretability
│   ├── prototypes.py        #   PAGE-inspired prototype layer
│   ├── motifs.py            #   MAGE-inspired motif generator
│   ├── concept_whitening.py #   ZCA whitening + concept alignment
│   └── stability.py         #   Explanation stability regularizer
├── explainers/              # Post-hoc Explanations
│   ├── cf_explainer.py      #   CF-GNNExplainer (counterfactual)
│   ├── t_explainer.py       #   T-GNNExplainer (sufficient subgraph)
│   └── cider.py             #   CIDER causal invariance diagnostics
├── training/                # Training Pipeline
│   ├── losses.py            #   Combined multi-objective loss
│   ├── trainer.py           #   Two-phase trainer (pretrain + finetune)
│   ├── callbacks.py         #   EarlyStopping, checkpointing, monitoring
│   └── config.py            #   YAML config with dataclass hierarchy
├── evaluation/              # Evaluation Metrics
│   ├── predictive.py        #   ROC-AUC, PR-AUC, RMSE, CI, etc.
│   ├── faithfulness.py      #   Deletion/Insertion AUC, sufficiency/necessity
│   ├── stability_metrics.py #   Jaccard stability, cliff consistency
│   ├── chemical_validity.py #   Valence checks, SMARTS match rates
│   ├── causal.py            #   Invariance violation, environment alignment
│   └── statistical.py       #   Paired bootstrap, randomization tests
├── visualization/           # Visualization Tools
│   ├── molecule_viz.py      #   Atom/bond saliency rendering
│   ├── prototype_viz.py     #   Prototype gallery
│   ├── motif_viz.py         #   Motif activation heatmaps
│   ├── concept_viz.py       #   Concept activation bars
│   ├── counterfactual_viz.py#   Counterfactual edit display
│   └── dashboard.py         #   HTML batch-export dashboard
└── cli.py                   # Command-line interface
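The scaffold split implemented in `splits.py` keeps entire scaffold groups inside one split, so no scaffold leaks between train, validation, and test. A simplified pure-Python sketch of the grouping logic, assuming Bemis-Murcko scaffold strings have already been computed (the real implementation would derive them with RDKit):

```python
from collections import defaultdict

def scaffold_split(scaffolds, frac_train=0.8, frac_val=0.1):
    """Group molecule indices by scaffold; assign largest groups to train first."""
    groups = defaultdict(list)
    for idx, s in enumerate(scaffolds):
        groups[s].append(idx)
    # Placing the largest scaffold groups first keeps the splits disjoint
    # while filling the train fraction as closely as possible.
    ordered = sorted(groups.values(), key=len, reverse=True)
    n = len(scaffolds)
    train, val, test = [], [], []
    for g in ordered:
        if len(train) + len(g) <= frac_train * n:
            train += g
        elif len(val) + len(g) <= frac_val * n:
            val += g
        else:
            test += g
    return train, val, test

# Toy example: strings stand in for Bemis–Murcko scaffold SMILES.
scaffolds = ["c1ccccc1"] * 6 + ["C1CCCCC1"] * 3 + ["c1ccncc1"] * 1
train, val, test = scaffold_split(scaffolds)
```

Because whole groups move together, molecules sharing a scaffold never appear on both sides of a split, which is what makes scaffold splits a harder, more realistic generalization test than random splits.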

Supported Datasets

| Dataset | Type | Tasks | Source |
|---------|------|-------|--------|
| MUTAG | Classification | 1 | TUDataset |
| Tox21 | Classification | 12 | MoleculeNet |
| ClinTox | Classification | 2 | MoleculeNet |
| QM9 | Regression | 19 | MoleculeNet |
| Davis | DTA regression | 1 | TDC |
| KIBA | DTA regression | 1 | TDC |
| BindingDB | DTA regression | 1 | TDC |
| SIDER | Classification | 27 | MoleculeNet |
| SynLethDB | Classification | 1 | Custom |

Two-Phase Training

  1. Pre-training — Trains encoders + task head with prediction loss only
  2. Joint Fine-tuning — Attaches interpretability modules, trains all losses:
    • L_pred: Task prediction (BCE/MSE)
    • L_pull/push/div: Prototype losses
    • L_sparsity/conn: Motif losses
    • L_align/decorr: Concept whitening losses
    • L_stability: Explanation stability
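The fine-tuning objective combines these terms as a weighted sum. A minimal sketch with purely illustrative loss values and coefficients (the actual terms and weights live in `losses.py` and the YAML config):

```python
def combined_loss(losses, weights):
    """Weighted sum of the fine-tuning objectives (coefficients are illustrative)."""
    return sum(weights.get(name, 0.0) * value for name, value in losses.items())

# Hypothetical per-term values and weights for one training step.
losses = {"pred": 0.42, "pull": 0.10, "push": 0.05, "div": 0.02,
          "sparsity": 0.08, "conn": 0.03, "align": 0.06,
          "decorr": 0.04, "stability": 0.01}
weights = {"pred": 1.0, "pull": 0.1, "push": 0.1, "div": 0.05,
           "sparsity": 0.01, "conn": 0.01, "align": 0.1,
           "decorr": 0.05, "stability": 0.1}
total = combined_loss(losses, weights)
```

Keeping the prediction weight dominant (here 1.0) while the interpretability regularizers contribute smaller terms is the usual way to trade off accuracy against explanation quality in multi-objective training.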

Citation

@software{inter_gnn2025,
  title={InterGNN: Interpretable Graph Neural Network for Drug Discovery},
  year={2025},
}

License

MIT License
