RiTINI is a Python package for modeling and analyzing gene regulatory networks from single-cell trajectories.
RiTINI: Regulatory Temporal Interaction Network Inference
A Graph Neural Ordinary Differential Equation (ODE) framework for modeling and predicting gene expression trajectories over time. RiTINI combines Graph Attention Networks (GATs) with Neural ODEs to capture the continuous temporal dynamics of gene regulatory networks.
Overview
RiTINI (Regulatory Temporal Interaction Network Inference) is designed to:
- Model temporal gene expression data using graph neural networks
- Learn attention-based gene regulatory networks from trajectory data
- Predict future gene expression states through continuous-time ODEs
- Visualize learned attention patterns and regulatory interactions
Installation
Requirements
- Python >= 3.10
- PyTorch >= 2.8.0
- PyTorch Geometric >= 2.6.1
Setup
- Clone the repository:
git clone git@github.com:KrishnaswamyLab/RiTINI.git
cd RiTINI
- Install dependencies using uv (recommended):
uv venv
source .venv/bin/activate
uv sync
Or using pip:
pip install -e .
Usage
Notebook-Friendly API (pip install)
After installing with pip install ritini, you can run the full workflow directly from Python:
import ritini
# 1) Preprocess raw inputs -> processed trajectory / prior graph files
ritini.preprocess(config_path="configs/config.yaml")
# 2) Train model (alias: ritini.train(...))
checkpoint = ritini.fit(config_path="configs/config.yaml")
# 3a) Focus-gene storyboard visualization
ritini.focus_storyboard(checkpoint_path=checkpoint, focus_gene="G51")
# 3b) Gene trajectory visualizations
ritini.trajectory_viz(checkpoint_path=checkpoint, visualization_config_path="configs/visualization.yaml")
# 3c) Graph inference visualizations
ritini.graph_inference(checkpoint_path=checkpoint, focus_gene="G51")
Running the Full Pipeline
To run preprocessing and training sequentially:
python main.py
Running Steps Independently
Each step can be run as a standalone script, which is useful for iterating on specific stages.
Preprocessing Only
python preprocess.py
Preprocessing performs:
- Load raw trajectory data from a .npy file
- Filter genes based on the interest genes list
- Compute the prior adjacency matrix (Granger causality by default)
- Average all trajectories into a single representative trajectory
- Normalize features using z-score normalization
- Save preprocessed data to the data/preprocessed/ directory
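The last two steps (averaging trajectories and z-score normalization) can be sketched in NumPy. This is an illustrative helper under assumed shapes, not RiTINI's actual implementation:

```python
import numpy as np

def average_and_normalize(traj):
    """Hypothetical sketch of the averaging + normalization steps.

    traj has shape (n_timepoints, n_trajectories, n_genes).
    """
    # Average all trajectories into one representative trajectory:
    # (n_timepoints, n_trajectories, n_genes) -> (n_timepoints, n_genes)
    mean_traj = traj.mean(axis=1)

    # Z-score normalize each gene across timepoints
    mu = mean_traj.mean(axis=0, keepdims=True)
    sigma = mean_traj.std(axis=0, keepdims=True)
    return (mean_traj - mu) / (sigma + 1e-8)

rng = np.random.default_rng(0)
traj = rng.normal(size=(10, 3, 5))  # 10 timepoints, 3 trajectories, 5 genes
z = average_and_normalize(traj)
print(z.shape)  # (10, 5)
```

After this step, each gene's expression has approximately zero mean and unit variance across the representative trajectory.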
Training Only
python train.py
Training performs:
- Load preprocessed data
- Create temporal graph dataset with sliding time windows
- Initialize and train the RiTINI model
- Apply graph regularization based on prior network
- Save best model based on total loss
- Store training history with all loss components
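The "sliding time windows" step above can be illustrated with a small NumPy helper (a hypothetical sketch of the idea, not RiTINI's dataset class):

```python
import numpy as np

def sliding_windows(x, time_window):
    """Split a trajectory into overlapping temporal windows.

    x: array of shape (n_timepoints, n_genes).
    Returns an array of shape (n_windows, time_window, n_genes).
    """
    n_timepoints = x.shape[0]
    windows = [x[i:i + time_window]
               for i in range(n_timepoints - time_window + 1)]
    return np.stack(windows)

x = np.arange(20, dtype=float).reshape(10, 2)  # 10 timepoints, 2 genes
w = sliding_windows(x, time_window=5)
print(w.shape)  # (6, 5, 2)
```

With `time_window = 5` (the default below), a 10-timepoint trajectory yields six overlapping training windows.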
Trajectory Inference and Visualization
python gene_inference_viz.py
Obtaining the Time-Varying GRNs
python gene_trajectory_viz.py
Input Data
The preprocessing script requires three input files:
- Trajectory Data (raw_trajectory_file): .npy file containing gene expression trajectories
  - Shape: (n_timepoints, n_trajectories, n_genes)
- Gene Names (raw_gene_names_file): .txt file with the names of all genes
- Interest Genes (interest_genes_file): .txt file with the subset of genes to analyze
Default Paths
raw_trajectory_file = 'data/raw/traj_data.npy'
raw_gene_names_file = 'data/raw/gene_names.txt'
interest_genes_file = 'data/raw/interest_genes.txt'
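A minimal way to generate stand-ins for these three inputs is sketched below. The layout mirrors the default data/raw/ paths (written to a scratch directory here), and the gene names (G0, G1, ...) are made up for illustration:

```python
import tempfile
import numpy as np
from pathlib import Path

# Scratch directory mimicking the default data/raw/ layout
root = Path(tempfile.mkdtemp()) / "data" / "raw"
root.mkdir(parents=True)

# Trajectory data: (n_timepoints, n_trajectories, n_genes)
n_timepoints, n_trajectories, n_genes = 10, 3, 5
traj = np.random.default_rng(0).normal(
    size=(n_timepoints, n_trajectories, n_genes))
np.save(root / "traj_data.npy", traj)

# Gene names: one name per line
gene_names = [f"G{i}" for i in range(n_genes)]
(root / "gene_names.txt").write_text("\n".join(gene_names) + "\n")

# Interest genes: a subset of the gene names
(root / "interest_genes.txt").write_text("\n".join(gene_names[:3]) + "\n")

print(np.load(root / "traj_data.npy").shape)  # (10, 3, 5)
```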
Training on Synthetic Data
python test_toy_data_ritini.py
Model Architecture
RiTINI consists of three main components:
- GAT Convolutional Layer (gatConvwithAttention.py)
  - Multi-head attention mechanism
  - Learns edge weights between genes
  - Aggregates neighbor information
- Graph Differential Equation (graphDifferentialEquation.py)
  - Wraps the GAT layer as an ODE function
  - Computes derivatives for continuous-time evolution
- ODE Block (ode.py)
  - Integrates dynamics using Neural ODE solvers
  - Supports multiple integration methods (RK4, Dopri5, etc.)
  - Optional adjoint method for memory-efficient backprop
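To illustrate what the ODE block does when `ode_method = 'rk4'`, here is a minimal fixed-step RK4 integrator in NumPy applied to toy dynamics. The real model integrates the GAT-derived vector field with torchdiffeq; this sketch only shows the numerical scheme:

```python
import numpy as np

def rk4_integrate(f, y0, t):
    """Fixed-step classical RK4 integration of dy/dt = f(t, y) at times t."""
    ys = [y0]
    for t0, t1 in zip(t[:-1], t[1:]):
        h = t1 - t0
        y = ys[-1]
        k1 = f(t0, y)
        k2 = f(t0 + h / 2, y + h * k1 / 2)
        k3 = f(t0 + h / 2, y + h * k2 / 2)
        k4 = f(t1, y + h * k3)
        ys.append(y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4))
    return np.stack(ys)

# Toy dynamics: exponential decay dy/dt = -y, exact solution exp(-t)
t = np.linspace(0.0, 1.0, 11)
ys = rk4_integrate(lambda t, y: -y, np.array([1.0]), t)
print(float(ys[-1]))  # close to exp(-1) ≈ 0.3679
```

In RiTINI, `f` is the graph differential equation (the wrapped GAT layer), so each integration step propagates gene expression states forward in continuous time.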
RiTINI Model Parameters
- Input features: 1 (gene expression value per node)
- Output features: 1 (predicted expression value)
- Architecture: Temporal Graph Attention Network
- Attention mechanism: Multi-head attention with configurable heads
Hyperparameters
Architecture
n_heads = 1 # Number of attention heads
feat_dropout = 0.1 # Feature dropout rate
attn_dropout = 0.1 # Attention dropout rate
activation_func = nn.Tanh() # Activation function
residual = False # Use residual connections
negative_slope = 0.2 # LeakyReLU negative slope
ODE Integration (Model Defaults)
The RiTINI model uses Neural ODEs for continuous-time modeling:
ode_method = 'rk4' # ODE solver (rk4, dopri5, etc.)
atol = 1e-3 # Absolute tolerance
rtol = 1e-4 # Relative tolerance
use_adjoint = False # Use adjoint method for memory efficiency
Training
n_epochs = 200 # Number of training epochs
learning_rate = 0.001 # Initial learning rate
batch_size = 4 # Batch size
time_window = 5 # Temporal window length (None = all timepoints)
Loss Function
graph_reg_weight = 0.1 # Weight for graph regularization loss
Total loss = Feature reconstruction loss + (graph_reg_weight × Graph regularization loss)
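The combination above is simple to state in code. This is a sketch of the arithmetic only; the two component losses would come from the model's predictions and the prior network:

```python
graph_reg_weight = 0.1  # default weight from the section above

def total_loss(feature_loss, graph_reg_loss, weight=graph_reg_weight):
    """Combine reconstruction and graph-regularization terms (hypothetical helper)."""
    return feature_loss + weight * graph_reg_loss

# e.g. reconstruction loss 0.50, graph regularization loss 0.20
print(total_loss(0.50, 0.20))  # 0.52
```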
Learning Rate Scheduler
lr_factor = 0.5 # LR reduction factor
lr_patience = 10 # Epochs to wait before reducing LR
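The schedule behaves like PyTorch's ReduceLROnPlateau: when the monitored loss fails to improve for `lr_patience` epochs, the learning rate is multiplied by `lr_factor`. A rough pure-Python emulation of that logic (an illustrative approximation, not the torch scheduler itself):

```python
lr, lr_factor, lr_patience = 0.001, 0.5, 10

def step_scheduler(losses, lr, factor, patience):
    """Multiply lr by factor whenever the loss stalls for more than `patience` epochs."""
    best, wait = float("inf"), 0
    for loss in losses:
        if loss < best:
            best, wait = loss, 0  # improvement: reset the patience counter
        else:
            wait += 1
            if wait > patience:
                lr *= factor  # plateau detected: reduce the learning rate
                wait = 0
    return lr

# 13 epochs with no improvement after the first -> one reduction
print(step_scheduler([1.0] * 13, lr, lr_factor, lr_patience))  # 0.0005
```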
Configuration
Key hyperparameters in training:
# Data parameters
n_top_genes = 20 # Number of genes to model
time_window = 5 # Length of temporal sequences
batch_size = 4
# Model parameters
n_heads = 1 # GAT attention heads
feat_dropout = 0.1 # Feature dropout rate
attn_dropout = 0.1 # Attention dropout rate
activation_func = nn.Tanh()
residual = False # Residual connections
# Training parameters
n_epochs = 200
learning_rate = 0.001
lr_factor = 0.5 # Scheduler reduction factor
lr_patience = 10 # Scheduler patience
Data Format
Trajectory Data
Expected format: (n_timepoints, n_trajectories, n_genes)
from ritini.data.trajectory_loader import prepare_trajectories_data
data = prepare_trajectories_data(
trajectory_file='data/trajectories/traj_data.pkl',
prior_graph_file='data/trajectories/prior_graph.pkl',
gene_names_file='data/trajectories/gene_names.txt',
n_top_genes=20,
use_mean_trajectory=True
)
Testing
Run the test suite:
# Test on toy data
pytest tests/test_toy_data_ritini.py
# Test on real data
pytest tests/test_real_data_gat.py
# Run all tests
pytest tests/
Dependencies
Core dependencies:
- torch >= 2.8.0 - Deep learning framework
- torch-geometric >= 2.6.1 - Graph neural network library
- torchdiffeq >= 0.2.3 - Neural ODE solvers
- networkx >= 3.0 - Graph manipulation
- numpy >= 1.24.0 - Numerical computing
- matplotlib >= 3.10.6 - Plotting
- seaborn >= 0.13.2 - Statistical visualization
- scanpy >= 1.11.4 - Single-cell analysis
- scikit-misc >= 0.5.1 - Scientific computing utilities
See pyproject.toml for full dependency list.
Citation
If you use RiTINI in your research, please cite:
@misc{bhaskar2023ritini,
doi = {10.48550/ARXIV.2306.07803},
url = {https://arxiv.org/abs/2306.07803},
author = {Bhaskar, Dhananjay and Magruder, Sumner and De Brouwer, Edward and Venkat, Aarthi and Wenkel, Frederik and Wolf, Guy and Krishnaswamy, Smita},
keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Inferring dynamic regulatory interaction graphs from time series data with perturbations},
publisher = {arXiv},
year = {2023},
copyright = {Creative Commons Attribution 4.0 International}
}
License
Yale License