
Lorentzian Interpretable Variational Autoencoder for single-cell data


LiVAE: Lorentzian Interpretable Variational Autoencoder

Python 3.10+ · PyTorch · MIT License

LiVAE (Lorentzian Interpretable Variational Autoencoder) learns interpretable latent representations for single-cell RNA-seq data using Lorentzian (hyperbolic) geometry and multi-component regularization.

Installation

# Clone repository and install dependencies
git clone https://github.com/PeterPonyu/LiVAE.git
cd LiVAE
pip install -r requirements.txt
pip install -e .

Note: LiVAE is also published on PyPI as `livae`; installing from source as above gives the latest development version.

Core Requirements

torch>=2.3.0,<2.5.0
torchvision>=0.18.0,<0.20.0
scanpy>=1.10.0,<1.11.0
scvi-tools>=1.1.0,<1.2.0
anndata>=0.10.0,<0.11.0
scib>=1.0.0
numpy>=1.26.0,<1.27.0
pandas>=2.2.0,<2.3.0
scipy>=1.12.0,<1.13.0
scikit-learn>=1.5.0,<1.6.0
tqdm>=4.66.0,<5.0.0
fastapi>=0.117.0,<0.118.0
uvicorn[standard]>=0.36.0,<0.37.0
python-multipart>=0.0.6

Python Quick Start

import scanpy as sc
from livae import agent

# Load your single-cell data
adata = sc.read_h5ad('your_data.h5ad')

# Initialize LiVAE agent
model = agent(
    adata=adata,
    layer='counts',         # Data layer to use
    latent_dim=10,          # Primary latent dimension
    i_dim=2,                # Interpretable latent dimension
    hidden_dim=128,         # Hidden layer dimension
    lr=1e-4,                # Learning rate
    # Regularization weights
    beta=1.0,               # β-VAE weight
    lorentz=1.0,            # Lorentzian regularization
    irecon=1.0,             # Interpretable reconstruction
)

# Train the model
model.fit(epochs=100)

# Extract embeddings
latent = model.get_latent()          # Primary latent representation
interpretable = model.get_iembed()   # Interpretable embedding (i_dim)

# Access results
print(f"Latent shape: {latent.shape}")
print(f"Interpretable embedding shape: {interpretable.shape}")

Web Interface

LiVAE includes an integrated web interface for interactive training and visualization.

Launch the Application

# Start the FastAPI server (serves both API and frontend)
uvicorn api.main:app --host 0.0.0.0 --port 8000

Access the web interface at http://localhost:8000

Features

  • Data Upload: Upload single-cell datasets (H5AD format)
  • Training Configuration: Configure model parameters and hyperparameters
  • Real-time Monitoring: Track loss curves and metrics during training
  • Results Visualization: Download embeddings and view training summaries

API Endpoints

  • POST /upload - Upload dataset
  • POST /train/start - Start training
  • GET /train/metrics - Get training metrics
  • GET /embeddings/interpretable - Get interpretable embeddings
  • GET /embeddings/latent - Get latent embeddings
  • GET /download/embeddings/{type} - Download embeddings as CSV

Architecture Overview

LiVAE consists of three main components working in concert:

1. Encoder Network

Input (gene expression) → Hidden → Hidden → μ, σ (latent parameters)
  • Maps high-dimensional gene expression to latent distribution parameters
  • Uses reparameterization trick for differentiable sampling
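
The reparameterization step can be illustrated with a minimal NumPy sketch (a stand-in for LiVAE's PyTorch encoder; the function name and shapes here are illustrative, not LiVAE's API):

```python
import numpy as np

def reparameterize(mu, log_var, rng=None):
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    Isolating the randomness in eps keeps the sample differentiable
    with respect to mu and log_var (the reparameterization trick).
    """
    rng = rng or np.random.default_rng(0)
    sigma = np.exp(0.5 * log_var)
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps

# Toy encoder output for a batch of 4 cells, latent_dim=10
mu = np.zeros((4, 10))
log_var = np.zeros((4, 10))   # sigma = 1
z = reparameterize(mu, log_var)
```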

2. Hyperbolic Transformation

Latent Sample → Tangent Space → Exponential Map → Lorentzian Manifold
  • Projects latent vectors onto hyperbolic manifold
  • Enables natural representation of hierarchical relationships
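
A minimal NumPy sketch of the exponential map at the hyperboloid origin shows how a Euclidean latent sample lands on the Lorentz manifold (assuming curvature -1 and the standard Lorentz model; function names are illustrative):

```python
import numpy as np

def exp_map_origin(v_tangent):
    """Exponential map at the origin o = (1, 0, ..., 0).

    v_tangent: (n, d) Euclidean vectors, embedded in the tangent space
    at o as (0, v). Returns points on {x : <x, x>_L = -1, x_0 > 0}.
    """
    norm = np.linalg.norm(v_tangent, axis=-1, keepdims=True)
    norm = np.clip(norm, 1e-9, None)           # avoid division by zero
    x0 = np.cosh(norm)                         # time-like coordinate
    xr = np.sinh(norm) * v_tangent / norm      # space-like coordinates
    return np.concatenate([x0, xr], axis=-1)

def lorentz_inner(x, y):
    """Minkowski inner product <x, y>_L = -x0*y0 + sum_i xi*yi."""
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

z = np.random.default_rng(0).normal(size=(5, 10))
x = exp_map_origin(z)   # shape (5, 11): one extra time-like coordinate
```

Since cosh²(r) − sinh²(r) = 1, every mapped point satisfies ⟨x, x⟩_L = −1, i.e. it lies exactly on the hyperboloid.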

3. Dual Decoder Pathway

Primary:      Latent → Decoder → Reconstruction
Interpretable: Latent → Compress → Decompress → Decoder → Reconstruction
  • Dual reconstruction paths for enhanced representation learning
  • Compression bottleneck encourages essential feature extraction
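
The bottleneck can be sketched with untrained random linear maps standing in for the learned compress/decompress layers (dimensions follow the quick-start example; this is an illustration, not LiVAE's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, i_dim = 10, 2   # as in the quick-start example

# Hypothetical linear compress/decompress pair
W_c = rng.normal(size=(latent_dim, i_dim)) / np.sqrt(latent_dim)
W_d = rng.normal(size=(i_dim, latent_dim)) / np.sqrt(i_dim)

z = rng.normal(size=(4, latent_dim))   # batch of latent samples
i_embed = z @ W_c                      # (4, 2): interpretable embedding
z_recon = i_embed @ W_d                # (4, 10): fed to the shared decoder
```

Forcing the reconstruction through the 2-dimensional `i_embed` is what makes the interpretable pathway retain only the dominant factors of variation.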

Loss Function

LiVAE optimizes a composite loss function:

L_total = L_recon + λ_irecon · L_irecon + λ_lorentz · L_lorentz + β · L_KL

Where:

  • L_recon: Negative binomial reconstruction loss on the primary pathway
  • L_irecon: Reconstruction loss through the interpretable bottleneck (weighted by the `irecon` parameter)
  • L_lorentz: Lorentz distance regularization on the hyperbolic embedding (weighted by the `lorentz` parameter)
  • L_KL: KL divergence between the approximate posterior and the standard normal prior (weighted by the `beta` parameter)
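
Two of these terms have simple closed forms that can be sketched in NumPy (the negative binomial uses the mean/inverse-dispersion parameterization common in single-cell models; function names and the `total_loss` helper are illustrative, not LiVAE's API):

```python
import numpy as np
from scipy.special import gammaln

def nb_log_likelihood(x, mu, theta):
    """log NB(x; mean=mu, inverse dispersion=theta)."""
    return (gammaln(x + theta) - gammaln(theta) - gammaln(x + 1)
            + theta * np.log(theta / (theta + mu))
            + x * np.log(mu / (theta + mu)))

def kl_standard_normal(mu, log_var):
    """KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def total_loss(l_recon, l_irecon, l_lorentz, l_kl,
               beta=1.0, lorentz=1.0, irecon=1.0):
    """Composite objective with the weights exposed by the agent."""
    return l_recon + irecon * l_irecon + lorentz * l_lorentz + beta * l_kl
```

L_recon is then the negative of `nb_log_likelihood` summed over genes, and `beta`, `lorentz`, `irecon` are exactly the regularization weights passed to `agent(...)` in the quick start.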

Evaluation Metrics

Clustering Quality

  • ARI (Adjusted Rand Index): Clustering agreement with ground truth
  • NMI (Normalized Mutual Information): Information-theoretic clustering measure
  • ASW (Average Silhouette Width): Cluster cohesion and separation

Cluster Validity

  • C_H (Calinski-Harabasz): Ratio of between to within cluster variance
  • D_B (Davies-Bouldin): Average similarity between clusters
  • P_C (Graph Connectivity): Connectivity within clusters
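
Most of these metrics are available in scikit-learn (already a LiVAE requirement), so a toy evaluation might look like the following (the synthetic clusters here are illustrative, not real single-cell data):

```python
import numpy as np
from sklearn.metrics import (adjusted_rand_score, normalized_mutual_info_score,
                             silhouette_score, calinski_harabasz_score,
                             davies_bouldin_score)

rng = np.random.default_rng(0)
# Two well-separated toy "cell type" clusters in a 2-D embedding
latent = np.vstack([rng.normal(0, 0.1, (50, 2)), rng.normal(5, 0.1, (50, 2))])
truth = np.repeat([0, 1], 50)
pred = np.repeat([1, 0], 50)   # same partition, label names swapped

ari = adjusted_rand_score(truth, pred)       # 1.0: label names do not matter
nmi = normalized_mutual_info_score(truth, pred)
asw = silhouette_score(latent, truth)        # near 1 for tight, distant clusters
ch = calinski_harabasz_score(latent, truth)  # higher is better
db = davies_bouldin_score(latent, truth)     # lower is better
```

In practice `latent` would be `model.get_latent()` and `truth` a cell-type annotation from `adata.obs`.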

Batch Integration

  • cLISI: Cell-type Local Inverse Simpson's Index (biological conservation)
  • iLISI: Integration LISI for batch mixing
  • bASW: Batch-corrected Average Silhouette Width

Performance Tips

Memory Optimization

# For large datasets, reduce batch size
model = agent(adata, percent=0.005)  # Use 0.5% of data per batch

# Use CPU if GPU memory is limited
import torch
model = agent(adata, device=torch.device('cpu'))

Training Efficiency

# Progressive training with increasing regularization
model = agent(adata, beta=0.1, lorentz=0.0)
model.fit(epochs=500)

# Increase regularization for fine-tuning
model.beta = 1.0
model.lorentz = 0.1
model.fit(epochs=500)

License

This project is licensed under the MIT License - see the LICENSE file for details.

