A Graph-Based ODE-VAE Enhances Clustering for Single-Cell Data

GNODEVAE: A Graph-Based ODE-VAE Enhances Clustering of Single-Cell Data

Graphical Abstract

[Figure: GNODEVAE graphical abstract]

Introduction

GNODEVAE is an innovative computational framework that integrates Graph Attention Networks (GAT), Neural Ordinary Differential Equations (NODE), and Variational Autoencoders (VAE). It addresses three critical challenges in single-cell RNA sequencing data analysis:

  1. Capturing complex topological relationships between cells
  2. Modeling continuous dynamic processes of cell differentiation
  3. Handling high levels of technical noise and biological variation

Combining these three components improves the identification of cell subpopulations, the reconstruction of developmental trajectories, and the characterization of cellular heterogeneity.

Key Contributions

1. Dynamic Attention Weighting for Biological Significance

The GAT's attention mechanism adaptively weights gene expression profiles, prioritizing meaningful biological relationships while down-weighting technical noise. This is particularly valuable for heterogeneous cell populations.
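To make the attention idea concrete, here is a minimal NumPy sketch of single-head GAT-style attention for one cell over its graph neighbors. It is an illustration of the general technique, not the package's implementation. The function name `gat_attention`, the random weights `W` and `a`, and the toy neighbor list are all assumptions made for the example.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_attention(h, W, a, i, neighbors):
    """Attention weights of cell i over its neighbors (single head, illustrative)."""
    z = h @ W                                          # project features: (n_cells, d_out)
    z_i = np.repeat(z[i][None, :], len(neighbors), axis=0)
    pairs = np.concatenate([z_i, z[neighbors]], axis=1)  # [z_i || z_j] per neighbor
    scores = leaky_relu(pairs @ a)                     # unnormalized attention scores
    scores = scores - scores.max()                     # stabilize the softmax
    weights = np.exp(scores) / np.exp(scores).sum()    # softmax over neighbors
    return weights, weights @ z[neighbors]             # weights and aggregated feature

rng = np.random.default_rng(0)
h = rng.normal(size=(5, 4))        # 5 cells, 4 input features (toy data)
W = rng.normal(size=(4, 3))        # hypothetical projection matrix
a = rng.normal(size=6)             # hypothetical attention vector (2 * d_out)
neighbors = np.array([1, 2, 4])    # hypothetical neighbors of cell 0

weights, h0_new = gat_attention(h, W, a, 0, neighbors)
```

The weights are positive and sum to one, so a neighbor with a low score contributes little to the aggregated representation of the cell.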

2. Continuous-Time Developmental Modeling via Neural ODEs

Integration of neural ordinary differential equations transforms static representations into dynamic systems, with time variables providing natural parameterization of developmental processes and enabling predictions at any point in cellular development.
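The sketch below shows what "predictions at any point" means in practice: a latent state is integrated through time under a dynamics function, and the trajectory can be read off at any time on the grid. It uses a hand-written fixed-step RK4 solver in NumPy rather than torchdiffeq, and the toy `latent_dynamics` network with random weights is a stand-in for a learned ODE function. None of these names come from the package.

```python
import numpy as np

def rk4_integrate(f, z0, t_grid):
    """Integrate dz/dt = f(t, z) over t_grid with classic fixed-step RK4."""
    zs = [np.asarray(z0, dtype=float)]
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h, z = t1 - t0, zs[-1]
        k1 = f(t0, z)
        k2 = f(t0 + h / 2, z + h / 2 * k1)
        k3 = f(t0 + h / 2, z + h / 2 * k2)
        k4 = f(t1, z + h * k3)
        zs.append(z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4))
    return np.stack(zs)            # one row per time point

rng = np.random.default_rng(0)
W1 = rng.normal(scale=0.5, size=(10, 16))   # hypothetical weights, not learned
W2 = rng.normal(scale=0.5, size=(16, 10))

def latent_dynamics(t, z):
    # Toy stand-in for a neural ODE function: a one-hidden-layer tanh network.
    return np.tanh(z @ W1) @ W2

z0 = rng.normal(size=10)                    # initial latent state of one cell
t_grid = np.linspace(0.0, 1.0, 21)          # pseudo-time grid in [0, 1]
trajectory = rk4_integrate(latent_dynamics, z0, t_grid)

# Sanity check of the solver on dz/dt = -z, whose exact solution is exp(-t).
decay = rk4_integrate(lambda t, z: -z, np.array([1.0]), t_grid)
```

In a trained model the dynamics function would be learned jointly with the encoder and decoder, and an adaptive solver would typically replace the fixed-step loop.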

3. Biologically Consistent Latent Space Representations

The model's latent space effectively captures biological phenomena like varying rates of cell differentiation, while attention weights align with established developmental relationships between cell types.

4. Comprehensive Benchmark Leadership

When compared with six advanced single-cell analysis methods (scVI, DIP-VAE, TC-VAE, β-VAE, Info-VAE, and scTour), GNODEVAE ranked first across all 13 test datasets, demonstrating exceptional versatility across diverse biological contexts.

5. Superior Gene Trend Analysis Performance

Quantitative evaluation shows that GNODEVAE outperforms existing methods on the Calinski-Harabasz index (69.97% improvement over Palantir, 63.58% over Diffmap), indicating more compact and better-separated clusters.
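For readers unfamiliar with the Calinski-Harabasz index, the following NumPy sketch computes it from its standard definition: between-cluster dispersion divided by within-cluster dispersion, each scaled by its degrees of freedom. The function name and the four-point toy dataset are assumptions for illustration. This is not the evaluation code used in the benchmark.

```python
import numpy as np

def calinski_harabasz(X, labels):
    """Calinski-Harabasz index: higher means tighter, better-separated clusters."""
    X, labels = np.asarray(X, dtype=float), np.asarray(labels)
    n, clusters = len(X), np.unique(labels)
    k = len(clusters)
    overall_mean = X.mean(axis=0)
    between = within = 0.0
    for c in clusters:
        members = X[labels == c]
        centroid = members.mean(axis=0)
        between += len(members) * np.sum((centroid - overall_mean) ** 2)
        within += np.sum((members - centroid) ** 2)
    return (between / (k - 1)) / (within / (n - k))

# Two tight pairs of points, far apart along the first axis.
X = np.array([[0.0, 0.0], [0.0, 1.0], [4.0, 0.0], [4.0, 1.0]])
good = calinski_harabasz(X, [0, 0, 1, 1])   # labels follow the true grouping
bad = calinski_harabasz(X, [0, 1, 0, 1])    # labels cut across the grouping
print(good, bad)  # 32.0 0.125
```

The same points score 32.0 under the sensible labeling and 0.125 under the poor one, which is why a higher index is read as clearer cluster structure.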

Installation

Prerequisites

  • Python 3.8 or higher
  • PyTorch 1.12 or higher (with CUDA support recommended for GPU acceleration)
  • PyTorch Geometric

Install from Source

# Clone the repository
git clone https://github.com/PeterPonyu/GNODEVAE.git
cd GNODEVAE

# Install dependencies
pip install torch torch-geometric scanpy anndata numpy pandas scikit-learn tqdm psutil torchdiffeq

Dependencies

The main dependencies include:

  • torch - PyTorch deep learning framework
  • torch-geometric - Geometric deep learning extension for PyTorch
  • scanpy - Single-cell analysis toolkit
  • anndata - Annotated data structures for single-cell data
  • torchdiffeq - Differentiable ODE solvers for PyTorch
  • numpy, pandas - Data manipulation
  • scikit-learn - Machine learning utilities
  • tqdm - Progress bars
  • psutil - System resource monitoring

Quick Start

Basic Usage

import scanpy as sc
from GNODEVAE import agent_r  # GraphVAE with refined architecture
# OR
from GNODEVAE import agent  # Standard GraphVAE
# For full GNODEVAE with ODE support, use:
# from GNODEVAE.GODEVAE_agent import GNODEVAE_agent_r

# Load your single-cell data
adata = sc.read_h5ad('your_data.h5ad')

# Initialize the GNODEVAE agent
model = agent_r(
    adata=adata,
    layer='counts',           # Layer containing count data
    n_var=2000,              # Number of highly variable genes
    tech='PCA',              # Dimensionality reduction technique
    n_neighbors=15,          # Number of neighbors for graph construction
    latent_dim=10,           # Latent space dimension
    hidden_dim=128,          # Hidden layer dimension
    encoder_type='graph',    # Use graph encoder
    graph_type='GAT',        # Graph Attention Network
    lr=1e-4,                 # Learning rate
    device='cuda'            # Requires a CUDA-capable GPU; pass 'cpu' otherwise
)

# Train the model
model.fit(epochs=300, update_steps=10, silent=False)

# Extract latent representations
latent = model.get_latent()

# Store latent representation in AnnData object
adata.obsm['X_gnodevae'] = latent

# Perform downstream analysis (e.g., clustering)
sc.pp.neighbors(adata, use_rep='X_gnodevae')
sc.tl.leiden(adata)
sc.tl.umap(adata)

Using Standard GraphVAE (without ODE)

from GNODEVAE import agent

# Initialize standard GraphVAE agent
model = agent(
    adata=adata,
    layer='counts',
    n_var=2000,
    tech='PCA',
    n_neighbors=15,
    latent_dim=10,
    hidden_dim=128,
    encoder_type='graph',
    graph_type='GAT',
    lr=1e-4
)

# Train and extract embeddings
model.fit(epochs=300)
latent = model.get_latent()

Key Parameters

Data Preprocessing Parameters

  • layer (str): Layer of AnnData to use (default: 'counts')
  • n_var (int): Number of highly variable genes to select (default: None, uses all)
  • tech (str): Dimensionality reduction method - 'PCA', 'NMF', 'FastICA', 'TruncatedSVD', 'FactorAnalysis', or 'LatentDirichletAllocation' (default: 'PCA')
  • n_neighbors (int): Number of neighbors for graph construction (default: 15)
  • batch_tech (str): Batch correction method - 'harmony' or 'scvi' (default: None)
  • all_feat (bool): Whether to use all features or only highly variable genes (default: True)
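As a rough picture of what `n_neighbors` controls, the sketch below builds a directed k-nearest-neighbor edge list from a reduced representation using plain NumPy. The helper `knn_edges` and the random stand-in for the reduced matrix are hypothetical. The package's own graph construction may differ, for example in the distance metric or in how edges are symmetrized.

```python
import numpy as np

def knn_edges(X, n_neighbors=15):
    """Directed kNN edge list of shape (2, n_cells * n_neighbors), Euclidean distance."""
    X = np.asarray(X, dtype=float)
    sq = np.sum(X ** 2, axis=1)
    dist = sq[:, None] + sq[None, :] - 2.0 * X @ X.T   # squared pairwise distances
    np.fill_diagonal(dist, np.inf)                     # a cell is not its own neighbor
    nbrs = np.argsort(dist, axis=1)[:, :n_neighbors]   # indices of the k closest cells
    src = np.repeat(np.arange(len(X)), n_neighbors)
    return np.stack([src, nbrs.ravel()])               # rows: source, target

rng = np.random.default_rng(0)
reduced = rng.normal(size=(100, 20))    # stand-in for e.g. 20 principal components
edge_index = knn_edges(reduced, n_neighbors=15)
```

Raising `n_neighbors` adds edges per cell and yields a denser graph, which smooths the neighborhood signal at the cost of more memory and computation.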

Model Architecture Parameters

  • hidden_dim (int): Hidden layer dimension (default: 128)
  • latent_dim (int): Latent space dimension for embeddings (default: 10)
  • encoder_type (str): Encoder type - 'graph' or 'linear' (default: 'graph')
  • graph_type (str): Graph convolution type - 'GAT', 'GCN', 'SAGE', 'Transformer', etc. (default: 'GAT')
  • structure_decoder_type (str): Structure decoder type - 'mlp', 'bilinear', or 'inner_product' (default: 'mlp')
  • feature_decoder_type (str): Feature decoder type - 'linear' or 'graph' (default: 'linear')
  • hidden_layers (int): Number of hidden layers (default: 2)
  • dropout (float): Dropout rate (default: 0.05)
  • use_residual (bool): Whether to use residual connections (default: True)
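Of the structure decoder options, the inner-product variant is the simplest to illustrate. In its standard form, known from variational graph autoencoders, the probability of an edge between two cells is the sigmoid of the dot product of their latent vectors. The sketch below shows that general form in NumPy. It is an assumption that the package's 'inner_product' option follows it exactly.

```python
import numpy as np

def inner_product_decoder(z):
    """Edge probabilities from latent vectors: sigmoid(z @ z.T)."""
    logits = z @ z.T
    return 1.0 / (1.0 + np.exp(-logits))

rng = np.random.default_rng(0)
z = 0.3 * rng.normal(size=(6, 10))      # 6 cells, latent_dim=10 (toy values)
adj_prob = inner_product_decoder(z)     # (6, 6) matrix of edge probabilities
```

The result is symmetric by construction, so this decoder can only express undirected similarity. An MLP or bilinear decoder adds learnable parameters on top of the latent vectors.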

Training Parameters

  • lr (float): Learning rate for optimizer (default: 1e-4)
  • beta (float): Weight for KL divergence loss term (default: 1.0)
  • graph (float): Weight for graph reconstruction loss (default: 1.0)
  • epochs (int): Number of training epochs (default: 300)
  • device (str or torch.device): Computing device - 'cuda' or 'cpu' (default: auto-detect)
  • num_parts (int): Number of graph partitions for mini-batch training (default: 10)
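The `beta` and `graph` weights can be read as coefficients in a weighted sum of loss terms. The sketch below assumes the usual VAE form: a feature reconstruction loss, plus a KL term against a standard normal prior, plus a graph reconstruction loss. The function names and the exact way the terms are combined are assumptions. The package may normalize or combine them differently.

```python
import numpy as np

def kl_standard_normal(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ), averaged over cells."""
    per_cell = -0.5 * np.sum(1.0 + logvar - mu ** 2 - np.exp(logvar), axis=1)
    return float(np.mean(per_cell))

def total_loss(recon_loss, kl_loss, graph_loss, beta=1.0, graph=1.0):
    """Weighted sum of the three terms (assumed form, for illustration)."""
    return recon_loss + beta * kl_loss + graph * graph_loss

mu = np.zeros((4, 10))        # 4 cells, latent_dim=10
logvar = np.zeros((4, 10))    # unit variance

kl_zero = kl_standard_normal(mu, logvar)           # posterior equals the prior
kl_shifted = kl_standard_normal(mu + 1.0, logvar)  # every mean shifted by 1
loss = total_loss(recon_loss=2.0, kl_loss=kl_shifted, graph_loss=0.5,
                  beta=0.5, graph=2.0)
print(kl_zero, kl_shifted, loss)
```

A larger `beta` pulls the latent distribution toward the prior at the expense of reconstruction quality, while a larger `graph` weight makes the embedding preserve neighborhood structure more strongly.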

GNODEVAE-Specific Parameters (agent_r)

  • n_ode_hidden (int): Number of hidden units in ODE function (default: varies)
  • w_recon (float): Weight for reconstruction loss (default: 1.0)
  • w_kl (float): Weight for KL divergence loss (default: 1.0)
  • w_adj (float): Weight for adjacency matrix loss (default: 1.0)
  • w_recon_ode (float): Weight for ODE reconstruction loss (default: 1.0)

Model Architecture

GNODEVAE consists of three main components:

  1. Graph Encoder: Encodes cell-cell relationships and gene expression using Graph Attention Networks (GAT) or other graph convolution layers
  2. Neural ODE: Models continuous developmental trajectories in the latent space
  3. Decoder: Reconstructs both graph structure and gene expression from latent representations

The model learns a low-dimensional latent representation that captures:

  • Cell type identity
  • Developmental state
  • Cell-cell relationships
  • Temporal dynamics (with ODE component)

Output

After training, GNODEVAE produces:

  • Latent representations: Low-dimensional embeddings for each cell
  • Clustering metrics: ARI, NMI, Silhouette score, Calinski-Harabasz index, Davies-Bouldin index
  • Pseudo-time: Developmental trajectory information (for GNODEVAE_agent_r, which includes the ODE component)
  • Graph structure: Learned cell-cell similarity graph

Advanced Usage

Custom Graph Construction

# Use custom graph construction parameters
model = agent_r(
    adata=adata,
    n_neighbors=30,      # Increase neighbors for denser graph
    graph_type='Transformer',  # Use Transformer convolution
    alpha=0.5            # Set alpha for specific layers
)

Interpretable Mode

# Use interpretable GraphVAE
model = agent(
    adata=adata,
    interpretable=True,  # Enable interpretable mode
    idim=2              # Interpretable dimension
)

Extract Pseudo-time

# For GNODEVAE models with ODE component
# Note: Use GNODEVAE_agent_r from GODEVAE_agent module for pseudo-time functionality
from GNODEVAE.GODEVAE_agent import GNODEVAE_agent_r

model = GNODEVAE_agent_r(adata=adata)  # add further keyword arguments as needed
model.fit(epochs=300)

# Get pseudo-time for cells
pseudotime_df = model.partition_time()

Evaluation Metrics

GNODEVAE automatically computes several clustering quality metrics during training:

  • ARI (Adjusted Rand Index): Measures clustering agreement with ground truth
  • NMI (Normalized Mutual Information): Information-theoretic clustering metric
  • ASW (Average Silhouette Width): Measures cluster separation
  • C_H (Calinski-Harabasz Index): Ratio of between-cluster to within-cluster variance
  • D_B (Davies-Bouldin Index): Average similarity between each cluster and its most similar cluster (lower is better)
  • P_C (Pearson Correlation): Correlation between latent dimensions
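As an example of how one of these metrics is defined, the sketch below computes the Adjusted Rand Index from the contingency table of two labelings, using the standard pair-counting formula. It is written in NumPy for illustration and is not the package's evaluation code. The function name is an assumption, and the degenerate case where both labelings are trivial (zero denominator) is not handled.

```python
from math import comb

import numpy as np

def adjusted_rand_index(labels_true, labels_pred):
    """ARI from pair counts: 1.0 for identical partitions, about 0 for random ones."""
    _, t = np.unique(np.asarray(labels_true), return_inverse=True)
    _, p = np.unique(np.asarray(labels_pred), return_inverse=True)
    contingency = np.zeros((t.max() + 1, p.max() + 1), dtype=int)
    np.add.at(contingency, (t, p), 1)                  # count co-occurring labels
    sum_cells = sum(comb(int(v), 2) for v in contingency.ravel())
    sum_rows = sum(comb(int(v), 2) for v in contingency.sum(axis=1))
    sum_cols = sum(comb(int(v), 2) for v in contingency.sum(axis=0))
    expected = sum_rows * sum_cols / comb(len(t), 2)   # agreement expected by chance
    max_index = 0.5 * (sum_rows + sum_cols)
    return (sum_cells - expected) / (max_index - expected)

truth = [0, 0, 1, 1]
same_up_to_renaming = adjusted_rand_index(truth, [1, 1, 0, 0])
crossed = adjusted_rand_index(truth, [0, 1, 0, 1])
print(same_up_to_renaming, crossed)  # 1.0 -0.5
```

Because ARI compares partitions rather than label values, renaming the clusters leaves the score at 1.0, while a labeling that systematically splits the true groups can score below zero.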

Citation

If you use GNODEVAE in your research, please cite:

@article{fu2025gnodevae,
  title={GNODEVAE: a graph-based ODE-VAE enhances clustering for single-cell data},
  author={Fu, Z. and Chen, C. and Wang, S. and others},
  journal={BMC Genomics},
  volume={26},
  pages={767},
  year={2025},
  doi={10.1186/s12864-025-11946-7}
}

Full Citation: Fu, Z., Chen, C., Wang, S. et al. GNODEVAE: a graph-based ODE-VAE enhances clustering for single-cell data. BMC Genomics 26, 767 (2025). https://doi.org/10.1186/s12864-025-11946-7

License

See LICENSE file for details.

Contact

For questions and feedback, please open an issue on the GitHub repository.
