Graph-based Reconstruction of Artificial Intercellular Links for Cardiac Spatial Transcriptomics
GRAIL-Heart: Graph-based Reconstruction of Artificial Intercellular Links
A Graph Neural Network framework for analyzing cell-cell communication in cardiac spatial transcriptomics data, featuring both forward and inverse modelling capabilities.
Live Demo
Explore ligand-receptor interaction networks across all six cardiac regions with our interactive web application.
Overview
GRAIL-Heart is a deep learning model designed to discover and analyze ligand-receptor (L-R) interactions in spatial transcriptomics datasets, with a focus on cardiac tissue. The framework integrates:
- Gene expression encoding with neural networks
- Spatial information through positional embeddings
- Graph attention mechanisms for neighborhood analysis
- Multi-task learning for simultaneous prediction of L-R interactions, gene expression reconstruction, and cell type classification
- Inverse modelling for inferring causal L-R signals that drive cell differentiation
Forward vs Inverse Modelling
| Modelling Type | Input | Output | Question Answered |
|---|---|---|---|
| Forward | Expression + Spatial | L-R predictions | "Which L-R interactions are active?" |
| Inverse | Observed phenotype | Causal L-R signals | "What signals drove this differentiation?" |
The model is trained on the Heart Cell Atlas v2, comprising spatial transcriptomics data from six distinct cardiac regions (Apex, Left Atrium, Left Ventricle, Right Atrium, Right Ventricle, and Septum).
Key Features
- Multi-task learning framework balancing L-R prediction, reconstruction, and classification
- Edge-type aware Graph Attention Networks with spatial and L-R edge types
- OmniPath L-R database integration (22,000+ curated pairs from CellPhoneDB, CellChat, ICELLNET)
- Leave-One-Region-Out (LORO) cross-validation for robust generalization assessment
- Inverse modelling framework for causal L-R inference
- Mechanosensitive pathway analysis (YAP/TAZ, Integrin-FAK, Piezo, TGF-β)
- Comprehensive spatial visualization of cell-cell communication networks
- Mixed precision training for efficient GPU utilization
- Complete inference pipeline with cross-region analysis
Cross-Validation Results
GRAIL-Heart was evaluated using 6-fold Leave-One-Region-Out cross-validation:
| Metric | Mean | ± Std | Interpretation |
|---|---|---|---|
| Reconstruction R² | 0.885 | 0.114 | Excellent gene expression reconstruction |
| Pearson Correlation | 0.991 | 0.005 | Near-perfect correlation |
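| L-R AUROC | 0.786 | 0.202 | Good on average but region-dependent (0.408 to 0.984 across regions); ↑4.6% with inverse modelling |
| L-R AUPRC | 0.974 | 0.030 | High; read against the positive-class rate, which is the AUPRC of a random predictor |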
| Accuracy | 0.927 | 0.059 | Very high classification accuracy |
| F1 Score | 0.959 | 0.035 | Excellent balance |
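For context on the two L-R rows: AUROC is the probability that a randomly chosen positive pair is scored above a randomly chosen negative one, whereas the AUPRC of a random predictor equals the fraction of positives, so a high AUPRC is only informative relative to that rate. The class balance of GRAIL-Heart's L-R labels is not stated here. The toy sketch below (plain Python; `auroc` and `average_precision` are hypothetical helpers, the latter being the common step-wise AUPRC estimator) shows how a mostly-positive label set produces a high AUPRC even when the ranking is weak.

```python
def auroc(labels, scores):
    """Probability that a random positive is scored above a random negative (ties count 1/2)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels, scores):
    """Mean of precision@k taken at the rank of each positive (no special tie handling)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    hits, precisions = 0, []
    for k, i in enumerate(order, start=1):
        if labels[i] == 1:
            hits += 1
            precisions.append(hits / k)
    return sum(precisions) / len(precisions)


# Toy ranking with 90% positives: the single negative sits in the middle of the list
labels = [1, 1, 1, 1, 1, 0, 1, 1, 1, 1]
scores = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
print(f"AUROC={auroc(labels, scores):.3f}  AUPRC={average_precision(labels, scores):.3f}")
# prints AUROC=0.556  AUPRC=0.947
```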
Per-Region Performance
| Region | R² | AUROC | AUPRC | Top Causal L-R (causal score) |
|---|---|---|---|---|
| RV (Right Ventricle) | 0.968 | 0.984 | 0.999 | TIMP1→MMP2 (1.844) |
| LA (Left Atrium) | 0.960 | 0.938 | 0.993 | SERPING1→C1S (1.834) |
| AX (Apex) | 0.965 | 0.888 | 0.979 | TIMP1→MMP2 (1.869) |
| LV (Left Ventricle) | 0.969 | 0.861 | 0.962 | CFD→C3 (1.857) |
| RA (Right Atrium) | 0.727 | 0.408 | 0.914 | TIMP2→MMP2 (1.831) |
| SP (Septum) | 0.720 | 0.634 | 0.997 | THBS1→FN1 (1.818) |
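The Mean and ± Std columns of the Cross-Validation Results table can be recomputed from the per-region values above. A quick check with Python's `statistics` module suggests the reported spread is the population standard deviation (`pstdev`, dividing by n = 6); the sample version (`stdev`, dividing by n − 1) is larger, for example about 0.125 rather than 0.114 for R². The dictionary below simply copies the table.

```python
from statistics import mean, pstdev, stdev

# Per-region values copied from the table above (order: RV, LA, AX, LV, RA, SP)
per_region = {
    "R2": [0.968, 0.960, 0.965, 0.969, 0.727, 0.720],
    "AUROC": [0.984, 0.938, 0.888, 0.861, 0.408, 0.634],
    "AUPRC": [0.999, 0.993, 0.979, 0.962, 0.914, 0.997],
}

# (mean, population std, sample std) for each metric
summary = {name: (mean(v), pstdev(v), stdev(v)) for name, v in per_region.items()}

for name, (m, pop_sd, sample_sd) in summary.items():
    print(f"{name:6s} mean={m:.3f}  pop. std={pop_sd:.3f}  sample std={sample_sd:.3f}")
```

With only six folds the choice of divisor matters, so it is worth stating which one a reported "± Std" uses.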
Project Structure
GRAIL-Heart/
├── src/grail_heart/ # Main package
│ ├── data/ # Data loading and preprocessing
│ │ ├── datasets.py # Spatial transcriptomics dataset loader
│ │ ├── graph_builder.py # Spatial graph construction (kNN, radius, Delaunay)
│ │ ├── lr_database.py # Standard L-R database
│ │ ├── expanded_lr_database.py # Extended database with 500+ pairs
│ │ └── cellchat_database.py # OmniPath integration (22,000+ pairs)
│ ├── models/ # Neural network architectures
│ │ ├── encoders.py # Gene and spatial encoders
│ │ ├── gat_layers.py # Graph Attention layers with edge type awareness
│ │ ├── grail_heart.py # Main GRAIL-Heart model (forward + inverse)
│ │ ├── inverse_modelling.py # Inverse modelling components (NEW)
│ │ ├── predictors.py # Prediction heads for L-R, reconstruction, classification
│ │ └── reconstruction.py # Gene expression decoders
│ ├── training/ # Training utilities
│ │ ├── losses.py # Multi-task loss functions (including inverse losses)
│ │ ├── metrics.py # Evaluation metrics
│ │ ├── trainer.py # Training loop and checkpointing
│ │ └── contrastive.py # Contrastive learning modules
│ ├── utils/ # Utility functions
│ └── visualization/ # Spatial visualization tools
│   └── spatial_viz.py # Network and L-R visualization
├── configs/
│ ├── default.yaml # Default configuration file
│ └── cv.yaml # Cross-validation configuration
├── data/ # Data directory (not in repo)
│ └── HeartCellAtlasv2/ # Heart Cell Atlas v2 datasets
├── outputs/ # Training outputs
│ ├── checkpoints/ # Model checkpoints
│ ├── logs/ # TensorBoard logs
│ ├── analysis/ # Network analysis outputs
│ ├── enhanced_analysis/ # Enhanced inference results
│ └── inverse_analysis/ # Inverse modelling results (NEW)
├── docs/ # Documentation
│ ├── METHODOLOGY.md # Detailed methods (including inverse modelling)
│ └── RESULTS.md # Results and findings
├── notebooks/ # Jupyter notebooks
├── train.py # Standard training script
├── train_cv.py # Cross-validation training script
├── enhanced_inference.py # Enhanced inference pipeline
├── inverse_inference.py # Inverse modelling analysis (NEW)
├── evaluate_test.py # Model evaluation script
├── check_checkpoint.py # Checkpoint inspection utility
└── README.md # This file
Installation
Quick Install (Python Package)
pip install grail-heart
Or install from source:
git clone https://github.com/Tumo505/GRAIL-Heart.git
cd GRAIL-Heart
pip install -e .
Requirements
- Python 3.9+
- CUDA 11.0+ (for GPU acceleration)
- pip or conda
Full Setup (Development)
- Clone the repository:
git clone https://github.com/Tumo505/GRAIL-Heart.git
cd GRAIL-Heart
- Create a Python virtual environment:
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies:
pip install torch torchvision # Use the selector at pytorch.org to get the build matching your CUDA version
pip install torch-geometric
pip install omnipath # For L-R database
pip install -e ".[all]" # Install with all extras
Quick Start
Python API
from grail_heart import load_pretrained
# Load pretrained model
model = load_pretrained()
# Run forward modeling (Expression → L-R predictions)
results = model.predict("my_cardiac_data.h5ad", mode="forward")
print(results.top_lr_pairs.head(10))
# Run inverse modeling (Fate → Causal L-R signals)
results = model.predict("my_cardiac_data.h5ad", mode="inverse")
print(results.causal_scores.head(10))
# Export results
results.to_csv("lr_predictions.csv")
results.to_json("network.json")
Command Line Interface
# Run prediction
grail-heart predict my_data.h5ad --output results.csv
# Run inverse modeling
grail-heart predict my_data.h5ad --mode inverse --output causal_results.csv
# Show model info
grail-heart info
# Start web application
grail-heart app
Web Application
Start the interactive web app for uploading and analyzing your own data:
# Option 1: Using CLI
grail-heart app
# Option 2: Using Streamlit directly
streamlit run app/app.py
Then open http://localhost:8501 in your browser.
Features:
- Upload scRNA-seq data (h5ad, h5, CSV formats)
- Run forward or inverse modeling
- Interactive network visualization
- Download results as CSV/JSON
Model Capabilities
Forward Modeling
Expression → L-R Predictions
Given gene expression data from cardiac cells, the model predicts which ligand-receptor interactions are active. This uses:
- Graph neural networks to capture spatial context
- Multi-head attention over cell neighborhoods
- Expression correlation with curated L-R databases
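The spatial context comes from a cell graph. `graph_builder.py` offers kNN, radius and Delaunay construction; the NumPy sketch below shows only the kNN variant and is an illustration, not the package's code. The default `k=6` is an assumption (it matches the six immediate neighbours of an interior spot on the Visium hexagonal grid), and the edges follow PyTorch Geometric's `edge_index` convention of `[source, target]`.

```python
import numpy as np


def knn_edges(coords, k=6):
    """Return a (2, n*k) edge_index linking each spot to its k nearest neighbours."""
    coords = np.asarray(coords, dtype=float)
    n = coords.shape[0]
    k = min(k, n - 1)
    # Pairwise squared Euclidean distances between all spots
    diff = coords[:, None, :] - coords[None, :, :]
    d2 = (diff ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)  # exclude self-loops
    nbrs = np.argsort(d2, axis=1)[:, :k]  # indices of the k nearest spots, per row
    # Row 0: neighbour (message source); row 1: centre spot (message target)
    src = nbrs.reshape(-1)
    dst = np.repeat(np.arange(n), k)
    return np.stack([src, dst])


coords = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [10.0, 0.0]])
print(knn_edges(coords, k=1))
```

The dense distance matrix is O(n²) in memory; for whole Visium sections a KD-tree (for example `scipy.spatial.cKDTree`) is the usual choice.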
Inverse Modeling
Observed Fates → Causal L-R Signals
This is the key innovation of GRAIL-Heart: given observed cell differentiation patterns, the model identifies which L-R interactions are causally responsible for driving those fates. This approach:
- Goes beyond simple expression correlation
- Identifies mechanosensitive pathways
- Links molecular signaling to tissue patterning
Key dependencies:
- PyTorch 2.0+
- PyTorch Geometric
- OmniPath (L-R database access)
- Scanpy (single-cell analysis)
- AnnData (data container)
- Pandas, NumPy
- Matplotlib, Seaborn (visualization)
- PyYAML (configuration)
- TensorBoard (logging)
Data Preparation
Download Heart Cell Atlas v2
Download the Visium spatial transcriptomics files from the Heart Cell Atlas:
# Create data directory
mkdir -p data/HeartCellAtlasv2
# Download datasets (approximately 30GB total)
# Place .h5ad files in data/HeartCellAtlasv2/
# Expected files:
# - visium-OCT_AX_raw.h5ad (Apex)
# - visium-OCT_LA_raw.h5ad (Left Atrium)
# - visium-OCT_LV_raw.h5ad (Left Ventricle)
# - visium-OCT_RA_raw.h5ad (Right Atrium)
# - visium-OCT_RV_raw.h5ad (Right Ventricle)
# - visium-OCT_SP_raw.h5ad (Septum)
Data Format
Input data should be in AnnData format (.h5ad):
- adata.X: Expression matrix [n_cells × n_genes]
- adata.obsm['spatial']: Spatial coordinates [n_cells × 2]
- adata.obs: Cell metadata (including cell types, if available)
- adata.var: Gene names and features
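Before training or inference it can help to check these fields up front. The helper below, `validate_spatial_input`, is a hypothetical sketch and not part of `grail_heart`. It only touches `X` and `obsm`, so it works on any AnnData-like object; a `SimpleNamespace` stands in for AnnData in the toy example. With real data, load the file with `anndata.read_h5ad(...)` first.

```python
from types import SimpleNamespace

import numpy as np


def validate_spatial_input(adata):
    """Check the expression matrix and spatial coordinates (hypothetical helper)."""
    n_cells, n_genes = adata.X.shape  # works for dense arrays and scipy sparse matrices
    if "spatial" not in adata.obsm:
        raise ValueError("adata.obsm['spatial'] is missing")
    coords = np.asarray(adata.obsm["spatial"], dtype=float)
    if coords.shape != (n_cells, 2):
        raise ValueError(f"adata.obsm['spatial'] should be ({n_cells}, 2), got {coords.shape}")
    if not np.isfinite(coords).all():
        raise ValueError("spatial coordinates contain NaN or inf")
    return n_cells, n_genes


# Toy stand-in for an AnnData object: 5 spots x 3 genes with 2-D coordinates
toy = SimpleNamespace(X=np.zeros((5, 3)), obsm={"spatial": np.random.default_rng(0).random((5, 2))})
print(validate_spatial_input(toy))
```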
The framework will automatically:
- Select top 2,000 highly variable genes
- Normalize library sizes
- Apply log1p transformation
- Filter cells (min 200 genes) and genes (min 3 cells)
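These steps correspond to scanpy's `sc.pp.filter_cells(min_genes=200)`, `sc.pp.filter_genes(min_cells=3)`, `sc.pp.normalize_total`, `sc.pp.log1p` and `sc.pp.highly_variable_genes(n_top_genes=2000)`. To make the arithmetic visible, here is a simplified NumPy sketch. The step order, the `target_sum` of 10,000 and the variance-based gene ranking (scanpy's default flavour ranks by normalised dispersion instead) are assumptions, not GRAIL-Heart's exact implementation.

```python
import numpy as np


def preprocess_counts(X, min_genes=200, min_cells=3, n_top_genes=2000, target_sum=1e4):
    """Simplified sketch of the preprocessing steps above; the step order is an assumption."""
    X = np.asarray(X, dtype=float)
    # Filter cells: keep spots with at least `min_genes` detected genes
    cell_mask = (X > 0).sum(axis=1) >= min_genes
    X = X[cell_mask]
    # Filter genes: keep genes detected in at least `min_cells` spots
    gene_mask = (X > 0).sum(axis=0) >= min_cells
    X = X[:, gene_mask]
    # Library-size normalisation: scale every spot to `target_sum` total counts
    lib = X.sum(axis=1, keepdims=True)
    lib[lib == 0] = 1.0  # avoid division by zero for spots emptied by gene filtering
    X = X / lib * target_sum
    # log1p transform
    X = np.log1p(X)
    # Highly variable genes, simplified here to the highest-variance genes
    n_keep = min(n_top_genes, X.shape[1])
    hvg_idx = np.sort(np.argsort(X.var(axis=0))[::-1][:n_keep])
    return X[:, hvg_idx], cell_mask, gene_mask
```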
Usage
Training (Forward + Inverse Modelling)
Training jointly optimizes both forward modelling (expression to L-R predictions) and inverse modelling (inferring causal L-R signals that drive cell fates). By default, inverse modelling is enabled.
Cross-Validation (Recommended):
Run Leave-One-Region-Out cross-validation for robust evaluation:
# Full 6-fold CV with inverse modelling
python train_cv.py --config configs/cv.yaml
# Quick test (specific folds, fewer epochs)
python train_cv.py --config configs/cv.yaml --n_epochs 50 --folds "AX,LA,LV"
# Run specific regions only
python train_cv.py --config configs/cv.yaml --folds "RV,SP"
Standard Training
Train the GRAIL-Heart model with default configuration:
python train.py --config configs/default.yaml
Train with custom data directory:
python train.py --config configs/default.yaml --data_dir /path/to/data
Configuration Parameters
Key parameters in configs/cv.yaml:
model:
  hidden_dim: 256 # Embedding dimension
  n_gat_layers: 3 # Number of GAT layers
  n_heads: 8 # Attention heads
  n_edge_types: 2 # Spatial + L-R edges
  encoder_dims: [512, 256] # Gene encoder hidden dims
  dropout: 0.1 # Dropout rate
  decoder_type: residual # Expression decoder type
  # Inverse modelling
  use_inverse_modelling: true
  n_fates: null # Use n_cell_types if null
  n_pathways: 20 # Signaling pathways
  n_mechano_pathways: 8 # Mechanosensitive pathways
data:
  max_lr_pairs: 5000 # Limit L-R pairs (memory optimization)
training:
  n_epochs: 100 # Training epochs
  learning_rate: 0.0001 # Adam learning rate
  weight_decay: 0.01 # L2 regularization
  batch_size: 1 # Full graph per batch
  grad_clip: 1.0 # Gradient clipping threshold
  mixed_precision: true # Use AMP training
loss:
  lr_weight: 1.0 # L-R prediction weight
  recon_weight: 1.0 # Reconstruction weight
  cell_type_weight: 1.0 # Cell type classification weight
  contrastive_weight: 0.5 # Contrastive learning weight
  # Inverse modelling losses
  use_inverse_losses: true
  fate_weight: 0.5 # Cell fate prediction weight
  causal_weight: 0.3 # Causal sparsity regularization
  differentiation_weight: 0.2
  gene_target_weight: 0.3
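A minimal sketch of how the `loss:` weights could combine per-task losses. The forward terms follow the total-loss formula in the Multi-Task Learning section; adding the four inverse terms in the same way when `use_inverse_losses` is true is an assumption, and the names `LOSS_CFG` and `weighted_total` are hypothetical. `yaml.safe_load` on `configs/cv.yaml` would yield a nested dict whose `loss` entry looks like `LOSS_CFG`.

```python
# Values copied from the `loss:` section above
LOSS_CFG = {
    "lr_weight": 1.0, "recon_weight": 1.0, "cell_type_weight": 1.0, "contrastive_weight": 0.5,
    "use_inverse_losses": True,
    "fate_weight": 0.5, "causal_weight": 0.3, "differentiation_weight": 0.2, "gene_target_weight": 0.3,
}

# Map each task's loss name to the config key holding its weight
FORWARD_TERMS = {"lr": "lr_weight", "recon": "recon_weight",
                 "cell_type": "cell_type_weight", "contrastive": "contrastive_weight"}
INVERSE_TERMS = {"fate": "fate_weight", "causal": "causal_weight",
                 "differentiation": "differentiation_weight", "gene_target": "gene_target_weight"}


def weighted_total(losses, cfg=LOSS_CFG):
    """Weighted sum of per-task losses; tasks missing from `losses` are skipped."""
    terms = dict(FORWARD_TERMS)
    if cfg.get("use_inverse_losses", False):
        terms.update(INVERSE_TERMS)  # assumption: inverse losses are added like forward ones
    return sum(cfg[key] * losses[name] for name, key in terms.items() if name in losses)
```

Because it only uses `*` and `sum`, the same function works unchanged when the per-task losses are PyTorch scalar tensors.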
Inference
Run enhanced inference with OmniPath L-R database:
python enhanced_inference.py \
--checkpoint outputs/checkpoints/best.pt \
--data_dir data/HeartCellAtlasv2 \
--output_dir outputs/enhanced_analysis
This generates:
- L-R interaction scores (CSV tables)
- Spatial network visualizations (PNG figures)
- Interaction networks (JSON files)
- Cross-region comparison analysis
Inverse Modelling Analysis
Run inverse modelling to identify causal L-R signals driving cell fate:
python inverse_inference.py \
--checkpoint outputs/checkpoints/best.pt \
--data_dir data/HeartCellAtlasv2 \
--output_dir outputs/inverse_analysis
This generates:
- Causal L-R Rankings: Top L-R pairs driving each cell fate
- Mechanosensitive Pathway Activation: YAP/TAZ, Integrin, Piezo pathway scores
- Cell Fate Trajectories: Differentiation predictions per cell
- Network Visualizations: Causal signalling network graphs
Enhanced Inference Results:
Top causal L-R interactions identified across all cardiac regions:
| Region | Top Causal Interaction | Causal Score | Pathway |
|---|---|---|---|
| AX | TIMP1→MMP2 | 1.869 | ECM Regulator |
| LA | SERPING1→C1S | 1.834 | Complement |
| LV | CFD→C3 | 1.857 | Complement |
| RA | TIMP2→MMP2 | 1.831 | ECM Regulator |
| RV | TIMP1→MMP2 | 1.844 | ECM Regulator |
| SP | THBS1→FN1 | 1.818 | ECM |
# Programmatic inverse inference
from grail_heart.models import GRAILHeart
model = GRAILHeart.load_from_checkpoint("outputs/checkpoints/best.pt")
# `data` is the preprocessed input for one region; its construction is not shown here
# Run inverse modelling
results = model.infer_causal_signals(data, target_fate=0) # e.g., cardiomyocyte fate
print(results['causal_lr_rankings'][:10]) # Top 10 causal L-R pairs
print(results['mechano_pathway_activation']) # Mechanosensitive pathway scores
Evaluation
Evaluate model on test set:
python evaluate_test.py
Ligand-Receptor Database
OmniPath Integration
GRAIL-Heart integrates with OmniPath, providing access to multiple curated L-R databases:
| Source Database | Description |
|---|---|
| CellPhoneDB | Comprehensive L-R interactions |
| CellChat | Cell-cell communication database |
| ICELLNET | Intercellular communication |
| Ramilowski2015 | Literature-curated interactions |
| And more... | 10+ integrated sources |
Database Statistics:
- Raw interactions: 115,064
- Unique L-R pairs: 22,234
- Unique ligands: 2,284
- Unique receptors: 2,637
- Pathway categories: 16
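The gap between raw interactions and unique pairs arises because several source databases report the same ligand-receptor pair. Below is a plain-Python sketch of that collapsing step, plus restricting pairs to genes present in a dataset and capping their number as `max_lr_pairs` does in the config. The `(ligand, receptor, source)` tuple input and the rule of keeping the pairs supported by the most sources when capping are hypothetical; the package's own selection rule is not documented here.

```python
def summarise_lr_pairs(interactions, expressed_genes=None, max_lr_pairs=None):
    """Collapse raw (ligand, receptor, source) records into unique L-R pairs (illustrative sketch)."""
    pairs = {}
    for ligand, receptor, source in interactions:
        pairs.setdefault((ligand, receptor), set()).add(source)
    # Keep only pairs whose ligand and receptor are both measured in the dataset
    if expressed_genes is not None:
        genes = set(expressed_genes)
        pairs = {p: s for p, s in pairs.items() if p[0] in genes and p[1] in genes}
    # Hypothetical cap: rank by number of supporting databases, then alphabetically
    ranked = sorted(pairs, key=lambda p: (-len(pairs[p]), p))
    if max_lr_pairs is not None:
        ranked = ranked[:max_lr_pairs]
    n_ligands = len({lig for lig, _ in ranked})
    n_receptors = len({rec for _, rec in ranked})
    return ranked, n_ligands, n_receptors


raw = [("TIMP1", "MMP2", "CellPhoneDB"), ("TIMP1", "MMP2", "CellChat"),
       ("CFD", "C3", "CellChat"), ("THBS1", "FN1", "ICELLNET")]
print(summarise_lr_pairs(raw))
```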
Usage
from grail_heart.data.cellchat_database import get_omnipath_lr_database
# Load full database
lr_pairs = get_omnipath_lr_database()
print(f"Loaded {len(lr_pairs)} L-R pairs")
# With caching
lr_pairs = get_omnipath_lr_database(cache_path="data/lr_database_cache.csv")
Model Architecture
Overview
Input (Gene expression + Spatial coordinates)
|
v
Gene Expression Encoder [512 -> 256]
|
+-----------> Spatial Position Encoder [2D -> 64D]
|
v
Multi-Modal Encoder (concatenate + project)
|
v
Graph Attention Network Stack (3 layers, 8 heads)
|
v
Jumping Knowledge Concatenation
|
+---------> L-R Interaction Head (Bilinear)
+---------> Gene Expression Decoder (Residual)
+---------> Cell Type Classifier
+---------> Signaling Network Predictor
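The last two stages of the diagram can be illustrated in NumPy: jumping-knowledge concatenation stacks each GAT layer's node embeddings along the feature axis, and a bilinear head scores a sender-receiver edge as sigmoid(h_srcᵀ W h_dst + b). This is a generic sketch under stated assumptions: that the JK step concatenates the three layer outputs only, and that the L-R head uses one shared matrix W. Neither detail is confirmed here, and the helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def jumping_knowledge_concat(layer_outputs):
    """Concatenate per-layer node embeddings along the feature axis."""
    return np.concatenate(layer_outputs, axis=-1)


def bilinear_lr_scores(h, edge_index, W, b=0.0):
    """Score each directed edge (sender -> receiver) as sigmoid(h_src^T W h_dst + b)."""
    src, dst = edge_index
    logits = np.einsum("ed,df,ef->e", h[src], W, h[dst]) + b
    return 1.0 / (1.0 + np.exp(-logits))


# Toy sizes (the configured model uses hidden_dim=256 and 3 GAT layers)
n_cells, hidden, n_layers = 4, 8, 3
layers = [rng.normal(size=(n_cells, hidden)) for _ in range(n_layers)]
h = jumping_knowledge_concat(layers)  # shape (4, 24)
W = rng.normal(scale=0.1, size=(h.shape[1], h.shape[1]))
edges = np.array([[0, 1, 2], [1, 2, 3]])  # row 0: sender, row 1: receiver
scores = bilinear_lr_scores(h, edges, W)
print(scores)
```

A bilinear form is asymmetric unless W is symmetric, which suits L-R scoring because ligand-to-receptor direction matters.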
Multi-Task Learning
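Total loss function for the forward-modelling tasks (the inverse-modelling losses weighted in the loss section of the configuration are described in docs/METHODOLOGY.md):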
L_total = w_lr * L_lr + w_recon * L_recon + w_ct * L_ct + w_contr * L_contr
where:
- L_lr: Focal binary cross-entropy for L-R prediction
- L_recon: Combined MSE + Cosine + Correlation loss
- L_ct: Cross-entropy for cell type classification
- L_contr: InfoNCE contrastive learning loss
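Focal binary cross-entropy down-weights easy examples by a factor (1 − p_t)^γ, where p_t is the predicted probability of the true class. Below is a NumPy sketch of the standard form. The defaults γ = 2 and α = 0.25 are common choices, not values taken from GRAIL-Heart's `losses.py`. With γ = 0 and α = 0.5 it reduces to half the ordinary binary cross-entropy.

```python
import numpy as np


def focal_bce(probs, targets, gamma=2.0, alpha=0.25, eps=1e-7):
    """Mean binary focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    p = np.clip(np.asarray(probs, dtype=float), eps, 1.0 - eps)
    y = np.asarray(targets, dtype=float)
    p_t = np.where(y == 1, p, 1.0 - p)  # probability assigned to the true class
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)  # class-balancing weight
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))


p, y = [0.9, 0.2], [1, 0]
bce = float(np.mean(-np.log([0.9, 0.8])))  # ordinary BCE for comparison
print(focal_bce(p, y, gamma=0.0, alpha=0.5), 0.5 * bce)
```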
Outputs
Cross-Validation Outputs
outputs/cv_TIMESTAMP/
├── config.yaml # Configuration used
├── cv_results.yaml # Aggregated CV metrics
├── cv_results.json # JSON format results
├── fold_0_AX/ # Fold 0 (held out AX)
│ ├── checkpoints/
│ │ └── best.pt
│ └── val_metrics.yaml
├── fold_1_LA/ # Fold 1 (held out LA)
├── ...
└── fold_5_SP/ # Fold 5 (held out SP)
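The fold layout above follows the Leave-One-Region-Out scheme: each fold holds out one region and trains on the other five. The sketch below is not `train_cv.py`. The middle of the region order (LV, RA, RV) is inferred from the alphabetical file names, and keeping the original fold index when `--folds` selects a subset is an assumption.

```python
REGIONS = ["AX", "LA", "LV", "RA", "RV", "SP"]  # fold order inferred from fold_0_AX ... fold_5_SP


def loro_folds(regions=REGIONS, selected=None):
    """Yield (fold_dir, train_regions, held_out) for Leave-One-Region-Out CV."""
    # `selected` mimics a comma-separated --folds value such as "RV,SP"
    wanted = None if selected is None else {r.strip() for r in selected.split(",")}
    for i, held_out in enumerate(regions):
        if wanted is not None and held_out not in wanted:
            continue
        train = [r for r in regions if r != held_out]
        yield f"fold_{i}_{held_out}", train, held_out


for fold_dir, train, held_out in loro_folds(selected="RV,SP"):
    print(fold_dir, "train on", train, "validate on", held_out)
```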
Inference Outputs
outputs/enhanced_analysis/
├── analysis_report.txt # Summary with causal analysis
├── tables/
│ ├── AX_lr_scores.csv
│ ├── LA_lr_scores.csv
│ ├── ... (one per region)
│ └── cross_region_comparison.csv
├── figures/
│ ├── AX_spatial_network.png
│ ├── AX_lr_heatmap.png
│ ├── AX_pathway_activity.png
│ ├── ... (multiple per region, 56 interaction figures total)
│ ├── cross_region_lr_heatmap.png
│ ├── region_comparison_panels.png
│ └── network_summary_dashboard.png
├── causal_analysis/ # Inverse modelling outputs
│ ├── AX_causal_edges.csv # Per-edge causal scores
│ └── ... (one per region)
└── networks/
├── AX_network.json
└── ... (one per region)
Documentation
Detailed documentation is available in the docs/ directory:
- docs/METHODOLOGY.md: Comprehensive description of methods, model architecture, cross-validation strategy, and L-R database
- docs/RESULTS.md: Detailed results, CV metrics, biological findings, and figures
Troubleshooting
GPU Memory Issues
If you encounter out-of-memory errors:
- Use the CV config, which caps the number of L-R pairs at 5,000:
  python train_cv.py --config configs/cv.yaml
- Reduce the number of L-R pairs further in configs/cv.yaml:
  data:
    max_lr_pairs: 2000 # Fewer L-R pairs, lower memory use
- Reduce the model hidden dimension (model.hidden_dim) or the number of GAT layers (model.n_gat_layers)
Missing Data Files
Ensure all required .h5ad files are in data/HeartCellAtlasv2/ directory with correct naming:
- visium-OCT_AX_raw.h5ad
- visium-OCT_LA_raw.h5ad
- visium-OCT_LV_raw.h5ad
- visium-OCT_RA_raw.h5ad
- visium-OCT_RV_raw.h5ad
- visium-OCT_SP_raw.h5ad
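A quick way to see which of these files are absent is a standard-library check with `pathlib`. The helper name `missing_region_files` is hypothetical; the file names and default directory are the ones listed above.

```python
from pathlib import Path

REGIONS = ["AX", "LA", "LV", "RA", "RV", "SP"]


def missing_region_files(data_dir="data/HeartCellAtlasv2"):
    """Return the expected Visium files that are not present in `data_dir`."""
    root = Path(data_dir)
    expected = [f"visium-OCT_{region}_raw.h5ad" for region in REGIONS]
    return [name for name in expected if not (root / name).is_file()]


missing = missing_region_files()
if missing:
    print("Missing files:", ", ".join(missing))
```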
CUDA Errors
If CUDA is not detected:
# Verify CUDA installation
python -c "import torch; print(torch.cuda.is_available())"
# Force CPU training (slower)
# Edit configs/cv.yaml: hardware.device: cpu
Acknowledgments
- Heart Cell Atlas v2 dataset
- OmniPath database (CellPhoneDB, CellChat, ICELLNET)
- PyTorch Geometric framework
- Scanpy and AnnData communities
Citation
If you use GRAIL-Heart in your research, please cite:
@software{grail_heart,
author = {Kgabeng, Tumo and Wang, Lulu and Ngwangwa, Harry and Pandelani, Thanyani},
title = {GRAIL-Heart: Graph-based Reconstruction of Artificial Intercellular Links},
year = {2026},
url = {https://github.com/Tumo505/GRAIL-Heart}
}
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
Download files
File details
Details for the file grail_heart-1.0.0.tar.gz.
File metadata
- Download URL: grail_heart-1.0.0.tar.gz
- Upload date:
- Size: 130.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a7bb3e3beb5a059a59848df1f0475e53491189b17e802805b502f3ae951ebc93 |
| MD5 | bf8ed7c155893b8f97da90a9a96bd558 |
| BLAKE2b-256 | 6bbdde12bd9bc2beb46234c4026ff0902e9fa423cf2fe73a243438c3f97351a5 |
File details
Details for the file grail_heart-1.0.0-py3-none-any.whl.
File metadata
- Download URL: grail_heart-1.0.0-py3-none-any.whl
- Upload date:
- Size: 94.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3119517a78687940e3d5f9072989a711058b0e1d792d2c72e2d99031967b916b |
| MD5 | 22484b2f40c1214ecdd96f5f67c3abdb |
| BLAKE2b-256 | 5d80c2f333c9664a545a57407f4ab94396b57edda53dec77a43784e868b8c9e5 |