Tahoe-x1: Scaling Perturbation-Trained Single-Cell Foundation Models to 3 Billion Parameters
Paper | HuggingFace | Getting Started | Tutorials
Tahoe-x1 (Tx1) is a family of perturbation-trained single-cell foundation models with up to 3 billion parameters, developed by Tahoe Therapeutics. Tx1 is pretrained on large-scale single-cell transcriptomic datasets including the Tahoe-100M perturbation compendium, and fine-tuned for cancer-relevant tasks. Through architectural optimizations and efficient training strategies, Tx1 achieves 3–30× higher compute efficiency than prior implementations while delivering state-of-the-art performance across disease-relevant benchmarks.
Table of Contents
- Repository Structure
- Installation
- Training Infrastructure
- Datasets
- Pre-trained Models
- Training and Fine-tuning
- Generating Cell and Gene Embeddings
- Tutorials and Benchmarks
- Developer Guidelines
- Acknowledgements
- License
Repository Structure
This repository follows a similar structure to llm-foundry and imports several utility functions from it.
```
tahoe-x1/
├── tahoe_x1/                      # Core Tahoe-x1 library
│   ├── model/
│   │   ├── blocks/                # Building block modules used across models
│   │   └── model/                 # Full architecture subclassed from ComposerModel
│   ├── tasks/                     # Helper functions for downstream tasks
│   ├── tokenizer/                 # Vocabulary building and tokenization functions
│   ├── data/                      # Data loaders and collators
│   └── utils/                     # Utility functions
├── scripts/
│   ├── train.py                   # Training script
│   ├── prepare_for_inference.py   # Prepares a model for inference
│   ├── depmap/                    # DepMap benchmark scripts
│   ├── msigdb/                    # MSigDB pathway benchmark scripts
│   ├── state_transition/          # State transition prediction scripts
│   ├── data_prep/                 # Dataset preparation scripts
│   └── inference/                 # Inference utilities
├── tutorials/                     # Jupyter notebook tutorials
│   ├── clustering_tutorial.ipynb  # Cell clustering and UMAP visualization
│   └── training_tutorial.ipynb    # Training walkthrough
└── configs/
    ├── runai/                     # RunAI configuration files
    ├── mcli/                      # MosaicML platform configuration files
    ├── gcloud/                    # Google Cloud configuration files
    └── test_run.yaml              # Sample config file
```
Installation
Docker installation provides better reproducibility and avoids dependency conflicts.
Docker Installation (Recommended)
```bash
# Clone the repository
git clone https://github.com/tahoebio/tahoe-x1.git
cd tahoe-x1

# Pull the latest Docker image with all dependencies pre-installed
docker pull ghcr.io/tahoebio/tahoe-x1:latest

# Start an interactive container with GPU support
# (nvidia-container-toolkit is required for this step)
docker run -it --rm \
  --gpus all \
  -v "$(pwd)":/workspace \
  -w /workspace \
  ghcr.io/tahoebio/tahoe-x1:latest \
  /bin/bash

# Inside the container, install the Tahoe-x1 package (dependencies are pre-installed)
pip install -e . --no-deps
```
The Docker image includes all necessary dependencies including PyTorch, CUDA drivers, and flash-attention for optimal performance.
Native Installation (Alternative)
For direct installation without Docker, we recommend using uv for dependency management:
```bash
# Clone the repository
git clone https://github.com/tahoebio/tahoe-x1.git
cd tahoe-x1

# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create and activate a virtual environment
uv venv
source .venv/bin/activate

# Install the package with dependencies
uv pip install -e . --no-build-isolation-package flash-attn
```
Note: Native installation requires compatible CUDA drivers and may encounter dependency conflicts. Docker installation is recommended for the best experience.
System Requirements & Training Capabilities
Tahoe-x1 is built natively on Composer and llm-foundry, inheriting their full suite of large-scale training capabilities:
Hardware Requirements
- GPU: NVIDIA Ampere (A100) or newer for FlashAttention support
- CUDA: Version 12.1+
- Python: 3.10+
Advanced Training Features
The codebase leverages Composer's state-of-the-art training stack, configurable via YAML:
- Automatic micro-batching for optimal memory utilization
- Mixed precision training with BF16/FP16, plus FP8 support on Hopper (H100) and newer GPUs
- Multi-GPU and multi-node distributed training with FSDP (Fully Sharded Data Parallel)
- Gradient accumulation and checkpointing for training larger models on limited hardware
- Advanced optimizers and schedulers from the LLM training ecosystem
- Streaming datasets for efficient data loading at scale
This infrastructure supports training models from 70M to 3B parameters and can scale to larger architectures.
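The features above are toggled through a handful of YAML fields. A minimal sketch, with key names assumed from llm-foundry conventions (check configs/test_run.yaml for the exact schema Tahoe-x1 uses):

```yaml
# Sketch of Composer-style training options; key names assumed from
# llm-foundry conventions, not taken from a Tahoe-x1 config.
precision: amp_bf16                  # mixed-precision training (BF16)
device_train_microbatch_size: auto   # automatic micro-batching
fsdp_config:
  sharding_strategy: FULL_SHARD      # multi-GPU/multi-node sharded training
  activation_checkpointing: true     # trade recompute for memory
```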
Docker Images
We provide pre-built Docker images for ease of use:
| Image Name | Base Image | Description |
|---|---|---|
| ghcr.io/tahoebio/tahoe-x1:1.0.0 | mosaicml/llm-foundry:2.2.1_cu121_flash2-813d596 | Current release image for Tahoe-x1 |
Datasets
Tx1 was pretrained on 266 million single-cell profiles from three major sources. The following datasets were used for training and evaluation:
| Dataset | Description | Usage | Location |
|---|---|---|---|
| CellxGene 2025-01 | ~61M cells from Jan 2025 CellxGene release | Tx1-3B stage 1 Pre-training | s3://tahoe-hackathon-data/MFM/cellxgene_2025_01_21_merged_MDS/ |
| scBaseCamp 2025-02 | ~112M cells from Feb 2025 scBaseCamp release | Tx1-3B stage 1 Pre-training | s3://tahoe-hackathon-data/MFM/scbasecamp_2025_02_25_MDS_v2/ |
| Tahoe 100M | ~96M cells from Tahoe-100M | Tx1-3B stage 1 Pre-training | s3://tahoe-hackathon-data/MFM/tahoe_100m_MDS_v2/ |
| filtered CellxGene 2025-01 | ~43M filtered cells from Jan 2025 CellxGene release | Tx1-3B stage 2 Pre-training | s3://tahoe-hackathon-data/MFM/cellxgene_2025_01_21_merged_MDS_filtered/ |
| filtered scBaseCamp 2025-02 | ~76M filtered cells from Feb 2025 scBaseCamp release | Tx1-3B stage 2 Pre-training | s3://tahoe-hackathon-data/MFM/scbasecamp_2025_02_25_MDS_v2_filtered/ |
| filtered Tahoe 100M | ~34M filtered cells from Tahoe-100M | Tx1-3B stage 2 Pre-training | s3://tahoe-hackathon-data/MFM/tahoe_100m_MDS_v2_filtered/ |
| DepMap | Cancer cell line dependency data | DepMap Benchmark | s3://tahoe-hackathon-data/MFM/benchmarks/depmap/ |
| MSigDB | Pathway signature data | MsigDB Benchmark | s3://tahoe-hackathon-data/MFM/benchmarks/msigdb/ |
Filtered versions of the pre-training datasets above exclude cells with very few expressed genes and are used for stage 2 pre-training of Tx1-3B.
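The filtering criterion can be sketched in a few lines of plain Python; the threshold below is a placeholder for illustration, not the cutoff actually used for Tx1-3B stage 2:

```python
# Sketch of the stage-2 filter: drop cells that express very few genes.
# MIN_GENES is a placeholder threshold, not the value used for Tx1-3B.
MIN_GENES = 200

def expressed_gene_count(counts):
    """Number of genes with nonzero counts in one cell."""
    return sum(1 for c in counts if c > 0)

def filter_cells(cells, min_genes=MIN_GENES):
    """Keep cells whose expressed-gene count meets the threshold."""
    return [cell for cell in cells if expressed_gene_count(cell) >= min_genes]

# Toy example: three "cells" with per-gene count vectors.
cells = [[0, 3, 1, 0], [0, 0, 0, 1], [2, 2, 2, 2]]
kept = filter_cells(cells, min_genes=2)  # drops the middle cell
```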
Public access to datasets: s3://tahoe-hackathon-data/MFM/benchmarks/
If you require access to datasets not available in the public bucket, please open a GitHub issue or contact the team.
For more information on dataset preparation, see scripts/data_prep/README.md.
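For ad hoc access without AWS credentials, objects in the public bucket can also be fetched over plain HTTPS. A sketch of the s3:// to HTTPS mapping, assuming standard virtual-hosted-style S3 addressing on the default endpoint:

```python
# Convert an s3:// URI to a virtual-hosted-style HTTPS URL
# (assumes a public bucket on the default S3 endpoint).
def s3_to_https(s3_uri):
    assert s3_uri.startswith("s3://"), "expected an s3:// URI"
    bucket, _, key = s3_uri[len("s3://"):].partition("/")
    return f"https://{bucket}.s3.amazonaws.com/{key}"

url = s3_to_https("s3://tahoe-hackathon-data/MFM/benchmarks/")
# → https://tahoe-hackathon-data.s3.amazonaws.com/MFM/benchmarks/
```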
Pre-trained Models
We provide pre-trained Tahoe-x1 models of various sizes:
| Model Name | Parameters | Context Length | Checkpoint Path | WandB ID | Config File |
|---|---|---|---|---|---|
| Tx1-3B | 3B | 2048 | s3://tahoe-hackathon-data/MFM/ckpts/3b/ | mygjkq5c | ./configs/mcli/tahoe_x1-3b-v2-cont-train.yaml |
| Tx1-1.3B | 1.3B | 2048 | s3://tahoe-hackathon-data/MFM/ckpts/1b/ | 26iormxc | ./configs/gcloud/tahoe_x1-1_3b-merged.yaml |
| Tx1-70M | 70M | 1024 | s3://tahoe-hackathon-data/MFM/ckpts/70m/ | ftb65le8 | ./configs/gcloud/tahoe_x1-70m-merged.yaml |
Model weights are also available as safetensors files on our Hugging Face model card.
Training and Fine-tuning
Training from Scratch
A sample test configuration is available at configs/test_run.yaml for quick experimentation.
Use the main training script with a YAML configuration file:

```bash
composer scripts/train.py -f configs/test_run.yaml
```

Or with command-line arguments:

```bash
composer scripts/train.py \
  --model_name tahoe_x1 \
  --data_path /path/to/data \
  --max_seq_len 2048 \
  --batch_size 32
```
Note that the current codebase only supports attn_impl: flash and use_attn_mask: False. The Triton backend and custom attention masks (used for training Tx1-1B and Tx1-70M) are no longer supported. If you have questions about using custom attention masks with the Triton backend, please contact us.
Fine-tuning
To fine-tune a pre-trained model on your own data:
- Download a pre-trained checkpoint from S3
- Modify the training configuration to load from the checkpoint
- Prepare your dataset in the MDS format (see scripts/data_prep/README.md)
- Launch training with the --load_path argument

```bash
python scripts/train.py \
  -f configs/finetune_config.yaml \
  --load_path s3://path/to/checkpoint
```
Launching runs on different platforms
For launching runs on specific platforms such as MosaicML, Google Cloud, or RunAI, refer to the corresponding configuration folders under configs/ and their respective README files.
Preparing Models for Inference
Package a trained model with its vocabulary and metadata:
```bash
python scripts/prepare_for_inference.py \
  --model_path /path/to/checkpoint \
  --vocab_path /path/to/vocab.json \
  --output_path /path/to/inference_model
```
Generating Cell and Gene Embeddings
Quick Start with Inference Script
Extract cell embeddings from an AnnData object:
```python
from omegaconf import OmegaConf as om
from scripts.inference.predict_embeddings import predict_embeddings

cfg = {
    "model_name": "Tx1-70m",
    "paths": {
        "hf_repo_id": "tahoebio/Tahoe-x1",
        "hf_model_size": "70m",
        "adata_input": "/path/to/your_data.h5ad",
    },
    "data": {
        "cell_type_key": "cell_type",
        "gene_id_key": "ensembl_id",
    },
    "predict": {
        "seq_len_dataset": 2048,
        "return_gene_embeddings": False,
    },
}
cfg = om.create(cfg)
adata = predict_embeddings(cfg)

# Access embeddings
cell_embeddings = adata.obsm["Tx1-70m"]
```
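Once extracted, the embedding matrix is a plain array that can feed any downstream tool. A minimal sketch of comparing two cell embeddings by cosine similarity, using toy vectors rather than real Tx1 output:

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

# Toy 4-dimensional "embeddings"; real ones come from adata.obsm["Tx1-70m"].
cell_a = [0.1, 0.9, 0.0, 0.4]
cell_b = [0.2, 0.8, 0.1, 0.5]
similarity = cosine_similarity(cell_a, cell_b)  # near 1 for similar cells
```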
Extracting Gene Embeddings
Set return_gene_embeddings: True in the configuration to extract gene-level representations.
Tutorials and Benchmarks
Tutorials
- Clustering Tutorial: Cell clustering and UMAP visualization
- Training Tutorial: Step-by-step guide to training Tahoe-x1 models
Benchmarks
Tx1 achieves state-of-the-art performance across disease-relevant benchmarks. See our preprint for detailed results.
| Benchmark | Task | Code Location |
|---|---|---|
| DepMap Essentiality | Predict broad and context-specific gene dependencies | scripts/depmap/ |
| MSigDB Hallmarks | Recover 50 hallmark pathway memberships from gene embeddings | scripts/msigdb/ |
| Cell-Type Classification | Classify cell types across 5 tissues (Tabula Sapiens 2.0) | cz-benchmarks |
| Perturbation Prediction | Predict transcriptional responses in held-out contexts | scripts/state_transition/ |
Additional Resources
- Data Preparation: scripts/data_prep/README.md
- Platform Usage: mcli/README.md and gcloud/README.md
Troubleshooting
Common Issues and Solutions
- PyTorch/CUDA mismatch: Ensure PyTorch is installed with the correct CUDA version for your system
- Docker permission denied: Run Docker commands with sudo, or add your user to the docker group
- OOM (Out of Memory): Ensure half precision and flash-attention are enabled, and set microbatch_size to auto
- S3 access denied: For public buckets, the code will automatically retry with unsigned requests
For additional help, please open an issue on GitHub with:
- Your system configuration (OS, GPU, PyTorch version)
- Complete error message and stack trace
- Steps to reproduce the issue
Acknowledgements
We thank the developers of the following open-source projects:
- scGPT: For inspiring the Tahoe-x1 architecture
- llm-foundry: Efficient training infrastructure for large language models
- streaming: Fast, efficient dataset streaming
- CZ CELLxGENE: Chan Zuckerberg Initiative's single-cell atlas
- Arc scBaseCount: Arc Institute's virtual cell atlas
- Arc Institute STATE: State Transition model for perturbation prediction
For questions, issues, or collaboration inquiries, please open an issue on GitHub or write to us at admin@tahoebio.ai.