
Graph-based Reconstruction of Artificial Intercellular Links for Cardiac Spatial Transcriptomics


GRAIL-Heart: Graph-based Reconstruction of Artificial Intercellular Links

A Graph Neural Network framework for analyzing cell-cell communication in cardiac spatial transcriptomics data, featuring both forward and inverse modelling capabilities.

Live Demo

Interactive Network Explorer

Explore ligand-receptor interaction networks across all six cardiac regions with our interactive web application.

Overview

GRAIL-Heart is a deep learning model designed to discover and analyze ligand-receptor (L-R) interactions in spatial transcriptomics datasets, with a focus on cardiac tissue. The framework integrates:

  • Gene expression encoding with neural networks
  • Spatial information through positional embeddings
  • Graph attention mechanisms for neighborhood analysis
  • Multi-task learning for simultaneous prediction of L-R interactions, gene expression reconstruction, and cell type classification
  • Inverse modelling for inferring causal L-R signals that drive cell differentiation

Forward vs Inverse Modelling

| Modelling type | Input | Output | Question answered |
|---|---|---|---|
| Forward | Expression + spatial coordinates | L-R predictions | "Which L-R interactions are active?" |
| Inverse | Observed phenotype | Causal L-R signals | "What signals drove this differentiation?" |

The model is trained on the Heart Cell Atlas v2, comprising spatial transcriptomics data from six distinct cardiac regions (Apex, Left Atrium, Left Ventricle, Right Atrium, Right Ventricle, and Septum).

Key Features

  • Multi-task learning framework balancing L-R prediction, reconstruction, and classification
  • Edge-type aware Graph Attention Networks with spatial and L-R edge types
  • OmniPath L-R database integration (22,000+ curated pairs from CellPhoneDB, CellChat, ICELLNET)
  • Leave-One-Region-Out (LORO) cross-validation for robust generalization assessment
  • Inverse modelling framework for causal L-R inference
  • Mechanosensitive pathway analysis (YAP/TAZ, Integrin-FAK, Piezo, TGF-β)
  • Comprehensive spatial visualization of cell-cell communication networks
  • Mixed precision training for efficient GPU utilization
  • Complete inference pipeline with cross-region analysis

Cross-Validation Results

GRAIL-Heart was evaluated using 6-fold Leave-One-Region-Out cross-validation:

| Metric | Mean ± Std | Interpretation |
|---|---|---|
| Reconstruction R² | 0.885 ± 0.114 | Excellent gene expression reconstruction |
| Pearson correlation | 0.991 ± 0.005 | Near-perfect correlation |
| L-R AUROC | 0.786 ± 0.202 | Good on average but region-dependent (0.408 in RA to 0.984 in RV); ↑4.6% with inverse modelling |
| L-R AUPRC | 0.974 ± 0.030 | Excellent precision-recall |
| Accuracy | 0.927 ± 0.059 | Very high classification accuracy |
| F1 score | 0.959 ± 0.035 | Excellent balance of precision and recall |

Per-Region Performance

| Region | Reconstruction R² | L-R AUROC | L-R AUPRC | Top causal L-R (score) |
|---|---|---|---|---|
| RV (Right Ventricle) | 0.968 | 0.984 | 0.999 | TIMP1→MMP2 (1.844) |
| LA (Left Atrium) | 0.960 | 0.938 | 0.993 | SERPING1→C1S (1.834) |
| AX (Apex) | 0.965 | 0.888 | 0.979 | TIMP1→MMP2 (1.869) |
| LV (Left Ventricle) | 0.969 | 0.861 | 0.962 | CFD→C3 (1.857) |
| RA (Right Atrium) | 0.727 | 0.408 | 0.914 | TIMP2→MMP2 (1.831) |
| SP (Septum) | 0.720 | 0.634 | 0.997 | THBS1→FN1 (1.818) |
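The summary rows of the cross-validation table can be recovered from these per-region values: averaging each column over the six held-out regions and taking the population standard deviation (dividing by n rather than n - 1) reproduces the reported R², AUROC and AUPRC mean ± std to within rounding. A small standard-library check (the helper name `summarize` is ours, not part of the package):

```python
from statistics import mean, pstdev

# Per-region LORO values copied from the table above (held-out region -> metrics)
PER_REGION = {
    "RV": {"r2": 0.968, "auroc": 0.984, "auprc": 0.999},
    "LA": {"r2": 0.960, "auroc": 0.938, "auprc": 0.993},
    "AX": {"r2": 0.965, "auroc": 0.888, "auprc": 0.979},
    "LV": {"r2": 0.969, "auroc": 0.861, "auprc": 0.962},
    "RA": {"r2": 0.727, "auroc": 0.408, "auprc": 0.914},
    "SP": {"r2": 0.720, "auroc": 0.634, "auprc": 0.997},
}


def summarize(per_region):
    """Return {metric: (mean, population std)} across the CV folds."""
    metrics = next(iter(per_region.values())).keys()
    summary = {}
    for metric in metrics:
        values = [fold[metric] for fold in per_region.values()]
        # pstdev divides by n (like NumPy's default ddof=0), not n - 1
        summary[metric] = (mean(values), pstdev(values))
    return summary


for metric, (mu, sd) in summarize(PER_REGION).items():
    print(f"{metric:>6}: {mu:.3f} ± {sd:.3f}")
```

Using the sample standard deviation (`statistics.stdev`) instead would give a noticeably larger AUROC spread (about 0.22), so the choice of estimator matters when comparing against other reports.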

Project Structure

GRAIL-Heart/
├── src/grail_heart/                  # Main package
│   ├── data/                         # Data loading and preprocessing
│   │   ├── datasets.py              # Spatial transcriptomics dataset loader
│   │   ├── graph_builder.py         # Spatial graph construction (kNN, radius, Delaunay)
│   │   ├── lr_database.py           # Standard L-R database
│   │   ├── expanded_lr_database.py  # Extended database with 500+ pairs
│   │   └── cellchat_database.py     # OmniPath integration (22,000+ pairs)
│   ├── models/                       # Neural network architectures
│   │   ├── encoders.py              # Gene and spatial encoders
│   │   ├── gat_layers.py            # Graph Attention layers with edge type awareness
│   │   ├── grail_heart.py           # Main GRAIL-Heart model (forward + inverse)
│   │   ├── inverse_modelling.py     # Inverse modelling components (NEW)
│   │   ├── predictors.py            # Prediction heads for L-R, reconstruction, classification
│   │   └── reconstruction.py        # Gene expression decoders
│   ├── training/                     # Training utilities
│   │   ├── losses.py                # Multi-task loss functions (including inverse losses)
│   │   ├── metrics.py               # Evaluation metrics
│   │   ├── trainer.py               # Training loop and checkpointing
│   │   └── contrastive.py           # Contrastive learning modules
│   ├── utils/                        # Utility functions
│   └── visualization/                # Spatial visualization tools
│       └── spatial_viz.py           # Network and L-R visualization
├── configs/
│   ├── default.yaml                  # Default configuration file
│   └── cv.yaml                       # Cross-validation configuration
├── data/                             # Data directory (not in repo)
│   └── HeartCellAtlasv2/            # Heart Cell Atlas v2 datasets
├── outputs/                          # Training outputs
│   ├── checkpoints/                 # Model checkpoints
│   ├── logs/                        # TensorBoard logs
│   ├── analysis/                    # Network analysis outputs
│   ├── enhanced_analysis/           # Enhanced inference results
│   └── inverse_analysis/            # Inverse modelling results (NEW)
├── docs/                             # Documentation
│   ├── METHODOLOGY.md               # Detailed methods (including inverse modelling)
│   └── RESULTS.md                   # Results and findings
├── notebooks/                        # Jupyter notebooks
├── train.py                          # Standard training script
├── train_cv.py                       # Cross-validation training script
├── enhanced_inference.py             # Enhanced inference pipeline
├── inverse_inference.py              # Inverse modelling analysis (NEW)
├── evaluate_test.py                  # Model evaluation script
├── check_checkpoint.py               # Checkpoint inspection utility
└── README.md                         # This file

Installation

Quick Install (Python Package)

pip install grail-heart

Or install from source:

git clone https://github.com/Tumo505/GRAIL-Heart.git
cd GRAIL-Heart
pip install -e .

Requirements

  • Python 3.9+
  • CUDA 11.0+ (for GPU acceleration)
  • pip or conda

Full Setup (Development)

  1. Clone the repository:
git clone https://github.com/Tumo505/GRAIL-Heart.git
cd GRAIL-Heart
  2. Create a Python virtual environment:
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
  3. Install dependencies:
pip install torch torchvision  # PyTorch; pick the CUDA build matching your driver (see pytorch.org)
pip install torch-geometric
pip install omnipath           # For L-R database
pip install -e ".[all]"        # Install with all extras

Quick Start

Python API

from grail_heart import load_pretrained

# Load the pretrained model
model = load_pretrained()

# Forward modelling (expression → L-R predictions)
forward_results = model.predict("my_cardiac_data.h5ad", mode="forward")
print(forward_results.top_lr_pairs.head(10))

# Inverse modelling (fate → causal L-R signals)
inverse_results = model.predict("my_cardiac_data.h5ad", mode="inverse")
print(inverse_results.causal_scores.head(10))

# Export the forward results (distinct variable names keep the
# inverse run from overwriting the L-R predictions)
forward_results.to_csv("lr_predictions.csv")
forward_results.to_json("network.json")

Command Line Interface

# Run prediction
grail-heart predict my_data.h5ad --output results.csv

# Run inverse modeling
grail-heart predict my_data.h5ad --mode inverse --output causal_results.csv

# Show model info
grail-heart info

# Start web application
grail-heart app

Web Application

Start the interactive web app for uploading and analyzing your own data:

# Option 1: Using CLI
grail-heart app

# Option 2: Using Streamlit directly
streamlit run app/app.py

Then open http://localhost:8501 in your browser.

Features:

  • Upload scRNA-seq data (h5ad, h5, CSV formats)
  • Run forward or inverse modeling
  • Interactive network visualization
  • Download results as CSV/JSON

Model Capabilities

Forward Modeling

Expression → L-R Predictions

Given gene expression data from cardiac cells, the model predicts which ligand-receptor interactions are active. This uses:

  • Graph neural networks to capture spatial context
  • Multi-head attention over cell neighborhoods
  • Expression correlation with curated L-R databases
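As a rough illustration of the spatial-context and expression ingredients listed above, the sketch below builds a directed k-nearest-neighbour graph from spot coordinates (kNN is one of the options named for graph_builder.py; radius or Delaunay graphs would be built analogously) and scores one L-R pair per edge as ligand expression in the sender times receptor expression in the receiver. This product score is a common expression-based baseline, not GRAIL-Heart's learned attention/bilinear scoring; the function names, `k`, and the toy data are assumptions.

```python
import numpy as np


def knn_edges(coords, k=6):
    """Directed kNN graph as a [2, n*k] edge index (row 0 = sender, row 1 = receiver).

    Brute-force O(n^2) distances: fine for a sketch, too slow for full Visium sections.
    Assumes k < number of spots.
    """
    coords = np.asarray(coords, dtype=float)
    n = coords.shape[0]
    # Pairwise squared Euclidean distances between all spots
    d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)  # exclude self-loops
    nbrs = np.argsort(d2, axis=1, kind="stable")[:, :k]  # k closest spots per spot
    src = np.repeat(np.arange(n), k)
    dst = nbrs.reshape(-1)
    return np.stack([src, dst])


def lr_edge_scores(expr, edge_index, ligand_idx, receptor_idx):
    """Baseline co-expression score: ligand in sender x receptor in receiver."""
    src, dst = edge_index
    return expr[src, ligand_idx] * expr[dst, receptor_idx]


# Toy example: 4 spots on a line, 2 genes (gene 0 = ligand, gene 1 = receptor)
coords = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [6.0, 0.0]])
expr = np.array([[2.0, 0.0], [0.0, 1.0], [1.0, 3.0], [0.0, 0.5]])
edges = knn_edges(coords, k=1)
print(edges)
print(lr_edge_scores(expr, edges, ligand_idx=0, receptor_idx=1))
```

The [2, E] edge-index layout follows the PyTorch Geometric convention, which makes the output easy to hand to a GAT layer.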

Inverse Modeling

Observed Fates → Causal L-R Signals

Inverse modelling is the key innovation of GRAIL-Heart: given observed cell differentiation patterns, the model identifies which L-R interactions are causally responsible for driving those fates. This approach:

  • Goes beyond simple expression correlation
  • Identifies mechanosensitive pathways
  • Links molecular signaling to tissue patterning

Key dependencies:

  • PyTorch 2.0+
  • PyTorch Geometric
  • OmniPath (L-R database access)
  • Scanpy (single-cell analysis)
  • AnnData (data container)
  • Pandas, NumPy
  • Matplotlib, Seaborn (visualization)
  • PyYAML (configuration)
  • TensorBoard (logging)

Data Preparation

Download Heart Cell Atlas v2

Download the Visium spatial transcriptomics files from the Heart Cell Atlas:

# Create data directory
mkdir -p data/HeartCellAtlasv2

# Download datasets (approximately 30GB total)
# Place .h5ad files in data/HeartCellAtlasv2/
# Expected files:
# - visium-OCT_AX_raw.h5ad  (Apex)
# - visium-OCT_LA_raw.h5ad  (Left Atrium)
# - visium-OCT_LV_raw.h5ad  (Left Ventricle)
# - visium-OCT_RA_raw.h5ad  (Right Atrium)
# - visium-OCT_RV_raw.h5ad  (Right Ventricle)
# - visium-OCT_SP_raw.h5ad  (Septum)
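Before training, it can help to confirm that all six files are present under the expected names. A small standard-library sketch; the helper name `missing_region_files` is hypothetical, while the file names are the ones listed above:

```python
from pathlib import Path

REGIONS = ["AX", "LA", "LV", "RA", "RV", "SP"]


def missing_region_files(data_dir="data/HeartCellAtlasv2"):
    """Return the expected visium-OCT_<REGION>_raw.h5ad files not found in data_dir."""
    data_dir = Path(data_dir)
    expected = [data_dir / f"visium-OCT_{region}_raw.h5ad" for region in REGIONS]
    return [path.name for path in expected if not path.is_file()]


if __name__ == "__main__":
    missing = missing_region_files()
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All six region files found.")
```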

Data Format

Input data should be in AnnData format (.h5ad):

  • adata.X: Expression matrix [n_cells × n_genes]
  • adata.obsm['spatial']: Spatial coordinates [n_cells × 2]
  • adata.obs: Cell metadata (including cell types if available)
  • adata.var: Gene names and features

The framework will automatically:

  • Select top 2,000 highly variable genes
  • Normalize library sizes
  • Apply log1p transformation
  • Filter cells (min 200 genes) and genes (min 3 cells)
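These steps correspond to standard Scanpy operations (`sc.pp.filter_cells`, `sc.pp.filter_genes`, `sc.pp.normalize_total`, `sc.pp.log1p`, `sc.pp.highly_variable_genes`). To make one plausible order of operations explicit without requiring Scanpy, here is a NumPy-only sketch. The order (filter, normalise, log-transform, then select genes), the target sum of 1e4, and ranking genes by plain variance (a simplification of Scanpy's dispersion-based selection) are assumptions, not details taken from the package.

```python
import numpy as np


def preprocess(counts, min_genes=200, min_cells=3, target_sum=1e4, n_top_genes=2000):
    """Filter, normalise, log1p and keep the most variable genes.

    Returns (X, kept_cells, kept_genes); the index arrays refer to rows/columns
    of the original count matrix.
    """
    X = np.asarray(counts, dtype=float)

    # 1. Drop cells with fewer than `min_genes` detected genes
    cell_mask = (X > 0).sum(axis=1) >= min_genes
    X = X[cell_mask]

    # 2. Drop genes detected in fewer than `min_cells` of the remaining cells
    gene_mask = (X > 0).sum(axis=0) >= min_cells
    X = X[:, gene_mask]

    # 3. Library-size normalisation: each cell sums to `target_sum`
    #    (guard against a cell whose only genes were just filtered out)
    lib = np.maximum(X.sum(axis=1, keepdims=True), 1e-12)
    X = X / lib * target_sum

    # 4. log(1 + x) transform
    X = np.log1p(X)

    # 5. Keep the `n_top_genes` highest-variance genes (simplified HVG selection)
    n_keep = min(n_top_genes, X.shape[1])
    top = np.sort(np.argsort(X.var(axis=0))[::-1][:n_keep])

    kept_cells = np.flatnonzero(cell_mask)
    kept_genes = np.flatnonzero(gene_mask)[top]
    return X[:, top], kept_cells, kept_genes
```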

Usage

Training (Forward + Inverse Modelling)

Training jointly optimizes both forward modelling (expression to L-R predictions) and inverse modelling (inferring causal L-R signals that drive cell fates). By default, inverse modelling is enabled.


Cross-Validation (Recommended):

Run Leave-One-Region-Out cross-validation for robust evaluation:

# Full 6-fold CV with inverse modelling
python train_cv.py --config configs/cv.yaml

# Quick test (specific folds, fewer epochs)
python train_cv.py --config configs/cv.yaml --n_epochs 50 --folds "AX,LA,LV"

# Run specific regions only
python train_cv.py --config configs/cv.yaml --folds "RV,SP"
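Leave-One-Region-Out CV trains on five regions and validates on the sixth, once per region. The sketch below enumerates folds with the same naming as the output directories (fold_0_AX, fold_1_LA, …, fold_5_SP) and mimics the comma-separated `--folds` filter. The function is hypothetical, not part of train_cv.py, and the full region order (alphabetical) is an assumption consistent with the fold names shown in the outputs tree.

```python
REGIONS = ["AX", "LA", "LV", "RA", "RV", "SP"]  # assumed fold order


def loro_folds(regions=REGIONS, folds=None):
    """Yield (fold_name, train_regions, val_region) for Leave-One-Region-Out CV.

    `folds` mimics the --folds flag, e.g. "RV,SP"; None means every region.
    Fold indices are kept from the full list, so "SP" is always fold 5.
    """
    selected = None if folds is None else {f.strip().upper() for f in folds.split(",")}
    for i, held_out in enumerate(regions):
        if selected is not None and held_out not in selected:
            continue
        train = [r for r in regions if r != held_out]
        yield f"fold_{i}_{held_out}", train, held_out


for name, train, val in loro_folds(folds="RV,SP"):
    print(name, "train:", train, "val:", val)
```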

Standard Training

Train the GRAIL-Heart model with default configuration:

python train.py --config configs/default.yaml

Train with custom data directory:

python train.py --config configs/default.yaml --data_dir /path/to/data

Configuration Parameters

Key parameters in configs/cv.yaml:

model:
  hidden_dim: 256           # Embedding dimension
  n_gat_layers: 3          # Number of GAT layers
  n_heads: 8               # Attention heads
  n_edge_types: 2          # Spatial + L-R edges
  encoder_dims: [512, 256] # Gene encoder hidden dims
  dropout: 0.1             # Dropout rate
  decoder_type: residual   # Expression decoder type
  # Inverse modelling
  use_inverse_modelling: true
  n_fates: null            # Use n_cell_types if null
  n_pathways: 20           # Signaling pathways
  n_mechano_pathways: 8    # Mechanosensitive pathways

data:
  max_lr_pairs: 5000       # Limit L-R pairs (memory optimization)

training:
  n_epochs: 100            # Training epochs
  learning_rate: 0.0001    # Adam learning rate
  weight_decay: 0.01       # L2 regularization
  batch_size: 1            # Full graph per batch
  grad_clip: 1.0           # Gradient clipping threshold
  mixed_precision: true    # Use AMP training

loss:
  lr_weight: 1.0           # L-R prediction weight
  recon_weight: 1.0        # Reconstruction weight
  cell_type_weight: 1.0    # Cell type classification weight
  contrastive_weight: 0.5  # Contrastive learning weight
  # Inverse modelling losses
  use_inverse_losses: true
  fate_weight: 0.5         # Cell fate prediction weight
  causal_weight: 0.3       # Causal sparsity regularization
  differentiation_weight: 0.2
  gene_target_weight: 0.3

Inference

Run enhanced inference with OmniPath L-R database:

python enhanced_inference.py \
  --checkpoint outputs/checkpoints/best.pt \
  --data_dir data/HeartCellAtlasv2 \
  --output_dir outputs/enhanced_analysis

This generates:

  • L-R interaction scores (CSV tables)
  • Spatial network visualizations (PNG figures)
  • Interaction networks (JSON files)
  • Cross-region comparison analysis

Inverse Modelling Analysis

Run inverse modelling to identify causal L-R signals driving cell fate:

python inverse_inference.py \
  --checkpoint outputs/checkpoints/best.pt \
  --data_dir data/HeartCellAtlasv2 \
  --output_dir outputs/inverse_analysis

This generates:

  • Causal L-R Rankings: Top L-R pairs driving each cell fate
  • Mechanosensitive Pathway Activation: YAP/TAZ, Integrin, Piezo pathway scores
  • Cell Fate Trajectories: Differentiation predictions per cell
  • Network Visualizations: Causal signalling network graphs

Enhanced Inference Results:

Top causal L-R interactions identified across all cardiac regions:

| Region | Top causal interaction | Causal score | Pathway |
|---|---|---|---|
| AX | TIMP1→MMP2 | 1.869 | ECM Regulator |
| LA | SERPING1→C1S | 1.834 | Complement |
| LV | CFD→C3 | 1.857 | Complement |
| RA | TIMP2→MMP2 | 1.831 | ECM Regulator |
| RV | TIMP1→MMP2 | 1.844 | ECM Regulator |
| SP | THBS1→FN1 | 1.818 | ECM |

# Programmatic inverse inference
from grail_heart.models import GRAILHeart

model = GRAILHeart.load_from_checkpoint("outputs/checkpoints/best.pt")

# `data` is a preprocessed spatial graph for one region
# (see src/grail_heart/data/, e.g. datasets.py and graph_builder.py)
results = model.infer_causal_signals(data, target_fate=0)  # index of the target fate, e.g. cardiomyocyte
print(results['causal_lr_rankings'][:10])     # top 10 causal L-R pairs
print(results['mechano_pathway_activation'])  # mechanosensitive pathway scores

Evaluation

Evaluate model on test set:

python evaluate_test.py

Ligand-Receptor Database

OmniPath Integration

GRAIL-Heart integrates with OmniPath, providing access to multiple curated L-R databases:

| Source database | Description |
|---|---|
| CellPhoneDB | Comprehensive L-R interactions |
| CellChat | Cell-cell communication database |
| ICELLNET | Intercellular communication |
| Ramilowski2015 | Literature-curated interactions |
| And more… | 10+ integrated sources |

Database Statistics:

  • Raw interactions: 115,064
  • Unique L-R pairs: 22,234
  • Unique ligands: 2,284
  • Unique receptors: 2,637
  • Pathway categories: 16
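The gap between raw interactions and unique L-R pairs presumably arises largely because the same ligand-receptor pair is reported by several source databases. A minimal sketch of that collapsing step, using hypothetical `(ligand, receptor, source)` records and toy data rather than OmniPath's actual schema:

```python
from collections import defaultdict


def collapse_lr_records(records):
    """Collapse raw (ligand, receptor, source) records into unique L-R pairs.

    Returns {(ligand, receptor): sorted list of supporting sources}.
    """
    pairs = defaultdict(set)
    for ligand, receptor, source in records:
        pairs[(ligand, receptor)].add(source)
    return {pair: sorted(sources) for pair, sources in pairs.items()}


def summarize_lr(pairs):
    """Counts analogous to the 'Database Statistics' list above."""
    return {
        "unique_pairs": len(pairs),
        "unique_ligands": len({ligand for ligand, _ in pairs}),
        "unique_receptors": len({receptor for _, receptor in pairs}),
    }


# Toy records (illustrative only): TIMP1->MMP2 is reported by two sources
raw = [
    ("TIMP1", "MMP2", "CellPhoneDB"),
    ("TIMP1", "MMP2", "CellChat"),
    ("CFD", "C3", "CellChat"),
    ("THBS1", "FN1", "ICELLNET"),
]
pairs = collapse_lr_records(raw)
print(summarize_lr(pairs))
```

Keeping the list of supporting sources per pair also gives a simple confidence signal: pairs backed by several curated databases can be prioritised when `max_lr_pairs` forces a cut.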

Usage

from grail_heart.data.cellchat_database import get_omnipath_lr_database

# Load full database
lr_pairs = get_omnipath_lr_database()
print(f"Loaded {len(lr_pairs)} L-R pairs")

# With caching
lr_pairs = get_omnipath_lr_database(cache_path="data/lr_database_cache.csv")

Model Architecture

Overview

Input (Gene expression + Spatial coordinates)
    |
    v
Gene Expression Encoder [512 -> 256]
    |
    +-----------> Spatial Position Encoder [2D -> 64D]
    |
    v
Multi-Modal Encoder (concatenate + project)
    |
    v
Graph Attention Network Stack (3 layers, 8 heads)
    |
    v
Jumping Knowledge Concatenation
    |
    +---------> L-R Interaction Head (Bilinear)
    +---------> Gene Expression Decoder (Residual)
    +---------> Cell Type Classifier
    +---------> Signaling Network Predictor
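The last two stages of the diagram can be illustrated compactly: jumping-knowledge concatenation places each GAT layer's node embeddings side by side, and a bilinear head scores a sender/receiver pair as sigmoid(h_iᵀ W h_j). The NumPy sketch below assumes a single shared weight matrix `W`, that every layer outputs `hidden_dim` features, and random embeddings; GRAIL-Heart's actual L-R head may differ in detail.

```python
import numpy as np


def jumping_knowledge_concat(layer_outputs):
    """Concatenate per-layer node embeddings: n_layers x [n, d] -> [n, d * n_layers]."""
    return np.concatenate(layer_outputs, axis=-1)


def bilinear_lr_scores(h, edge_index, W):
    """Score each directed edge (i -> j) as sigmoid(h_i^T W h_j)."""
    src, dst = edge_index
    logits = np.einsum("ed,df,ef->e", h[src], W, h[dst])
    return 1.0 / (1.0 + np.exp(-logits))


rng = np.random.default_rng(0)
n_cells, hidden_dim, n_layers = 5, 256, 3  # hidden_dim / n_gat_layers as in the config
layers = [rng.normal(size=(n_cells, hidden_dim)) for _ in range(n_layers)]
h = jumping_knowledge_concat(layers)  # shape (5, 768)
d = h.shape[1]
W = rng.normal(scale=1.0 / d, size=(d, d))  # small init keeps logits near unit scale
edges = np.array([[0, 1, 2], [1, 2, 3]])    # sender row, receiver row
scores = bilinear_lr_scores(h, edges, W)
print(h.shape, scores.shape)
```

A bilinear form is not symmetric in general, so the score for i → j can differ from j → i, which suits directional ligand-to-receptor signalling.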

Multi-Task Learning

Total loss function:

L_total = w_lr * L_lr + w_recon * L_recon + w_ct * L_ct + w_contr * L_contr

where:
- L_lr: Focal binary cross-entropy for L-R prediction
- L_recon: Combined MSE + Cosine + Correlation loss
- L_ct: Cross-entropy for cell type classification
- L_contr: InfoNCE contrastive learning loss
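A minimal NumPy sketch of this weighted sum, with the standard binary focal loss, FL = -α_t (1 - p_t)^γ log p_t, standing in for L_lr. The defaults γ = 2 and α = 0.25 are common choices, not values taken from GRAIL-Heart, and the other terms are passed in as precomputed numbers. The weights mirror the forward-loss entries of configs/cv.yaml; the inverse-modelling terms (fate_weight, causal_weight, and so on) are assumed to enter the sum the same way when use_inverse_losses is true.

```python
import numpy as np


def focal_bce(probs, targets, gamma=2.0, alpha=0.25, eps=1e-7):
    """Binary focal loss, averaged over elements.

    With gamma=0 and alpha=None it reduces to ordinary binary cross-entropy.
    """
    p = np.clip(np.asarray(probs, dtype=float), eps, 1.0 - eps)
    y = np.asarray(targets, dtype=float)
    p_t = np.where(y == 1, p, 1.0 - p)  # probability assigned to the true class
    loss = -((1.0 - p_t) ** gamma) * np.log(p_t)
    if alpha is not None:
        loss = np.where(y == 1, alpha, 1.0 - alpha) * loss
    return float(loss.mean())


# Forward-loss weights from the `loss:` section of configs/cv.yaml
DEFAULT_WEIGHTS = {"lr": 1.0, "recon": 1.0, "cell_type": 1.0, "contrastive": 0.5}


def total_loss(terms, weights=None):
    """L_total = sum_k w_k * L_k over the loss terms present in `terms`."""
    weights = DEFAULT_WEIGHTS if weights is None else weights
    return sum(weights[name] * value for name, value in terms.items())


probs = np.array([0.9, 0.2, 0.7, 0.1])
labels = np.array([1, 0, 1, 0])
terms = {"lr": focal_bce(probs, labels), "recon": 0.30, "cell_type": 0.45, "contrastive": 1.20}
print(total_loss(terms))
```

The (1 - p_t)^γ factor down-weights edges the model already classifies confidently, which helps when positive L-R edges are rare relative to candidate pairs.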

Outputs

Cross-Validation Outputs

outputs/cv_TIMESTAMP/
├── config.yaml              # Configuration used
├── cv_results.yaml          # Aggregated CV metrics
├── cv_results.json          # JSON format results
├── fold_0_AX/               # Fold 0 (held out AX)
│   ├── checkpoints/
│   │   └── best.pt
│   └── val_metrics.yaml
├── fold_1_LA/               # Fold 1 (held out LA)
├── ...
└── fold_5_SP/               # Fold 5 (held out SP)

Inference Outputs

outputs/enhanced_analysis/
├── analysis_report.txt      # Summary with causal analysis
├── tables/
│   ├── AX_lr_scores.csv
│   ├── LA_lr_scores.csv
│   ├── ... (one per region)
│   └── cross_region_comparison.csv
├── figures/
│   ├── AX_spatial_network.png
│   ├── AX_lr_heatmap.png
│   ├── AX_pathway_activity.png
│   ├── ... (multiple per region, 56 interaction figures total)
│   ├── cross_region_lr_heatmap.png
│   ├── region_comparison_panels.png
│   └── network_summary_dashboard.png
├── causal_analysis/         # Inverse modelling outputs
│   ├── AX_causal_edges.csv  # Per-edge causal scores
│   └── ... (one per region)
└── networks/
    ├── AX_network.json
    └── ... (one per region)
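The per-edge causal tables (for example AX_causal_edges.csv) can be rolled up into a per-pair ranking like the ones reported above. The column names used here (`ligand`, `receptor`, `causal_score`) are assumptions, since the file schema is not documented in this README; adjust them to the actual header. Averaging per pair is also a choice (max or sum are reasonable alternatives). Standard library only:

```python
import csv
from collections import defaultdict
from statistics import mean


def rank_lr_pairs(rows, top_n=10):
    """Aggregate per-edge causal scores into a ranking of L-R pairs.

    Returns [(ligand, receptor, mean_score, n_edges), ...], highest mean first.
    """
    scores = defaultdict(list)
    for row in rows:
        scores[(row["ligand"], row["receptor"])].append(float(row["causal_score"]))
    ranked = [(lig, rec, mean(vals), len(vals)) for (lig, rec), vals in scores.items()]
    ranked.sort(key=lambda item: item[2], reverse=True)
    return ranked[:top_n]


def rank_from_csv(path, top_n=10):
    """Read a per-edge CSV (hypothetical columns) and rank its L-R pairs."""
    with open(path, newline="") as fh:
        return rank_lr_pairs(csv.DictReader(fh), top_n=top_n)


# Usage (hypothetical column names):
# for lig, rec, score, n in rank_from_csv(
#         "outputs/enhanced_analysis/causal_analysis/AX_causal_edges.csv"):
#     print(f"{lig}->{rec}: {score:.3f} over {n} edges")
```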

Documentation

Detailed documentation is available in the docs/ directory:

  • docs/METHODOLOGY.md: Comprehensive description of methods, model architecture, cross-validation strategy, and L-R database
  • docs/RESULTS.md: Detailed results, CV metrics, biological findings, and figures

Troubleshooting

GPU Memory Issues

If you encounter out-of-memory errors:

  1. Use CV config with limited L-R pairs:

    python train_cv.py --config configs/cv.yaml  # Uses 5000 L-R pairs
    
  2. Reduce L-R pairs further in configs/cv.yaml:

    data:
      max_lr_pairs: 2000  # Fewer candidate L-R pairs, lower memory use
    
  3. Reduce the model hidden dimension or the number of GAT layers (model.hidden_dim, model.n_gat_layers)

Missing Data Files

Ensure all required .h5ad files are in data/HeartCellAtlasv2/ directory with correct naming:

  • visium-OCT_AX_raw.h5ad
  • visium-OCT_LA_raw.h5ad
  • visium-OCT_LV_raw.h5ad
  • visium-OCT_RA_raw.h5ad
  • visium-OCT_RV_raw.h5ad
  • visium-OCT_SP_raw.h5ad

CUDA Errors

If CUDA is not detected:

# Verify CUDA installation
python -c "import torch; print(torch.cuda.is_available())"

# Force CPU training (slower)
# Edit configs/cv.yaml: hardware.device: cpu

Acknowledgments

  • Heart Cell Atlas v2 dataset
  • OmniPath database (CellPhoneDB, CellChat, ICELLNET)
  • PyTorch Geometric framework
  • Scanpy and AnnData communities

Citation

If you use GRAIL-Heart in your research, please cite:

@software{grail_heart,
  author = {Tumo Kgabeng and Lulu Wang and Harry Ngwangwa and Thanyani Pandelani},
  title = {GRAIL-Heart: Graph-based Reconstruction of Artificial Intercellular Links},
  year = {2026},
  url = {https://github.com/Tumo505/GRAIL-Heart}
}

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.




Download files


Source Distribution

grail_heart-1.0.0.tar.gz (130.8 kB)

Uploaded Source

Built Distribution


grail_heart-1.0.0-py3-none-any.whl (94.6 kB)

Uploaded Python 3

File details

Details for the file grail_heart-1.0.0.tar.gz.

File metadata

  • Download URL: grail_heart-1.0.0.tar.gz
  • Size: 130.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for grail_heart-1.0.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | a7bb3e3beb5a059a59848df1f0475e53491189b17e802805b502f3ae951ebc93 |
| MD5 | bf8ed7c155893b8f97da90a9a96bd558 |
| BLAKE2b-256 | 6bbdde12bd9bc2beb46234c4026ff0902e9fa423cf2fe73a243438c3f97351a5 |


File details

Details for the file grail_heart-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: grail_heart-1.0.0-py3-none-any.whl
  • Size: 94.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for grail_heart-1.0.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3119517a78687940e3d5f9072989a711058b0e1d792d2c72e2d99031967b916b |
| MD5 | 22484b2f40c1214ecdd96f5f67c3abdb |
| BLAKE2b-256 | 5d80c2f333c9664a545a57407f4ab94396b57edda53dec77a43784e868b8c9e5 |

