
Tahoe-x1: Scaling Perturbation-Trained Single-Cell Foundation Models to 3 Billion Parameters

📄 Paper | 🤗 HuggingFace | 🚀 Getting Started | 🧑‍🏫 Tutorials

Tahoe-x1 (Tx1) is a family of perturbation-trained single-cell foundation models with up to 3 billion parameters, developed by Tahoe Therapeutics. Tx1 is pretrained on large-scale single-cell transcriptomic datasets including the Tahoe-100M perturbation compendium, and fine-tuned for cancer-relevant tasks. Through architectural optimizations and efficient training strategies, Tx1 achieves 3โ€“30ร— higher compute efficiency than prior implementations while delivering state-of-the-art performance across disease-relevant benchmarks.


Repository Structure

This repository follows a similar structure to llm-foundry and imports several utility functions from it.

tahoe-x1/
├── tahoe_x1/                  # Core Tahoe-x1 library
│   ├── model/
│   │   ├── blocks/            # Building block modules used across models
│   │   └── model/             # Full architecture subclassed from ComposerModel
│   ├── tasks/                 # Helper functions for downstream tasks
│   ├── tokenizer/             # Vocabulary building and tokenization functions
│   ├── data/                  # Data loaders and collators
│   └── utils/                 # Utility functions
├── scripts/
│   ├── train.py               # Training script
│   ├── prepare_for_inference.py  # Prepares model for inference
│   ├── depmap/                # DepMap benchmark scripts
│   ├── msigdb/                # MSigDB pathway benchmark scripts
│   ├── state_transition/      # State transition prediction scripts
│   ├── data_prep/             # Dataset preparation scripts
│   └── inference/             # Inference utilities
├── tutorials/                 # Jupyter notebook tutorials
│   ├── clustering_tutorial.ipynb  # Cell clustering and UMAP visualization
│   └── training_tutorial.ipynb    # Training walkthrough
└── configs/
    ├── runai/                 # RunAI configuration files
    ├── mcli/                  # MosaicML platform configuration files
    ├── gcloud/                # Google Cloud configuration files
    └── test_run.yaml          # Sample config file

Installation

Docker installation provides better reproducibility and avoids dependency conflicts.

Docker Installation (Recommended)

# Clone the repository
git clone https://github.com/tahoebio/tahoe-x1.git
cd tahoe-x1

# Pull the latest Docker image with all the dependencies pre-installed
docker pull ghcr.io/tahoebio/tahoe-x1:latest

# Start an interactive container with GPU support
# Note that nvidia-container-toolkit is required for this step
docker run -it --rm \
  --gpus all \
  -v "$(pwd)":/workspace \
  -w /workspace \
  ghcr.io/tahoebio/tahoe-x1:latest \
  /bin/bash

# Inside the container, install the Tahoe-x1 package (dependencies are pre-installed)
pip install -e . --no-deps

The Docker image ships with all necessary dependencies, including PyTorch, CUDA libraries, and flash-attention for optimal performance.

Native Installation (Alternative)

For direct installation without Docker, we recommend using uv for dependency management:

# Clone the repository
git clone https://github.com/tahoebio/tahoe-x1.git
cd tahoe-x1

# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create and activate virtual environment
uv venv
source .venv/bin/activate

# Install the package with dependencies
uv pip install -e . --no-build-isolation-package flash-attn

Note: Native installation requires compatible CUDA drivers and may encounter dependency conflicts. Docker installation is recommended for the best experience.

System Requirements & Training Capabilities

Tahoe-x1 is built natively on Composer and llm-foundry, inheriting their full suite of large-scale training capabilities:

Hardware Requirements

  • GPU: NVIDIA Ampere (A100) or newer for FlashAttention support
  • CUDA: Version 12.1+
  • Python: 3.10+

Advanced Training Features

The codebase leverages Composer's state-of-the-art training stack, configurable via YAML:

  • Automatic micro-batching for optimal memory utilization
  • Mixed precision training with BF16/FP16, plus FP8 support on Hopper (H100) and newer GPUs
  • Multi-GPU and multi-node distributed training with FSDP (Fully Sharded Data Parallel)
  • Gradient accumulation and checkpointing for training larger models on limited hardware
  • Advanced optimizers and schedulers from the LLM training ecosystem
  • Streaming datasets for efficient data loading at scale

This infrastructure supports training models from 70M to 3B parameters and can scale to larger architectures.
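For orientation, here is a hedged sketch of how such features are typically switched on in a Composer/llm-foundry-style YAML. The key names follow llm-foundry conventions; the exact schema used by this repo should be checked against configs/test_run.yaml, and the values below are illustrative only:

```yaml
# Illustrative fragment, not a verbatim Tahoe-x1 config.
max_seq_len: 2048
global_train_batch_size: 256
device_train_microbatch_size: auto   # Composer's automatic micro-batching
precision: amp_bf16                  # mixed-precision training
fsdp_config:
  sharding_strategy: FULL_SHARD      # multi-GPU / multi-node FSDP
  activation_checkpointing: true     # trade compute for memory on large models
```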

Docker Images

We provide pre-built Docker images for ease of use:

| Image Name | Base Image | Description |
|---|---|---|
| ghcr.io/tahoebio/tahoe-x1:1.0.0 | mosaicml/llm-foundry:2.2.1_cu121_flash2-813d596 | Current release image for Tahoe-x1 |

Datasets

Tx1 was pretrained on 266 million single-cell profiles from three major sources. The following datasets were used for training and evaluation:

| Dataset | Description | Usage | Location |
|---|---|---|---|
| CellxGene 2025-01 | ~61M cells from Jan 2025 CellxGene release | Tx1-3B stage 1 pre-training | s3://tahoe-hackathon-data/MFM/cellxgene_2025_01_21_merged_MDS/ |
| scBaseCamp 2025-02 | ~112M cells from Feb 2025 scBaseCamp release | Tx1-3B stage 1 pre-training | s3://tahoe-hackathon-data/MFM/scbasecamp_2025_02_25_MDS_v2/ |
| Tahoe-100M | ~96M cells from Tahoe-100M | Tx1-3B stage 1 pre-training | s3://tahoe-hackathon-data/MFM/tahoe_100m_MDS_v2/ |
| filtered CellxGene 2025-01 | ~43M filtered cells from Jan 2025 CellxGene release | Tx1-3B stage 2 pre-training | s3://tahoe-hackathon-data/MFM/cellxgene_2025_01_21_merged_MDS_filtered/ |
| filtered scBaseCamp 2025-02 | ~76M filtered cells from Feb 2025 scBaseCamp release | Tx1-3B stage 2 pre-training | s3://tahoe-hackathon-data/MFM/scbasecamp_2025_02_25_MDS_v2_filtered/ |
| filtered Tahoe-100M | ~34M filtered cells from Tahoe-100M | Tx1-3B stage 2 pre-training | s3://tahoe-hackathon-data/MFM/tahoe_100m_MDS_v2_filtered/ |
| DepMap | Cancer cell line dependency data | DepMap benchmark | s3://tahoe-hackathon-data/MFM/benchmarks/depmap/ |
| MSigDB | Pathway signature data | MSigDB benchmark | s3://tahoe-hackathon-data/MFM/benchmarks/msigdb/ |

Filtered versions of the pre-training datasets above exclude cells with very few expressed genes and are used for stage 2 pre-training of Tx1-3B.
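The gene-count filter described above can be sketched in a few lines of numpy. The matrix and threshold here are illustrative stand-ins, not the values used to build the filtered datasets:

```python
import numpy as np

# Toy counts matrix: 5 cells x 4 genes (rows are cells).
counts = np.array([
    [3, 0, 1, 2],   # 3 expressed genes
    [0, 0, 0, 1],   # 1 expressed gene
    [5, 2, 0, 0],   # 2 expressed genes
    [1, 1, 1, 1],   # 4 expressed genes
    [0, 0, 0, 0],   # 0 expressed genes
])

# Number of genes with nonzero expression in each cell.
genes_per_cell = (counts > 0).sum(axis=1)

# Keep cells expressing at least `min_genes` genes (threshold is hypothetical).
min_genes = 2
filtered = counts[genes_per_cell >= min_genes]

print(filtered.shape)  # rows 0, 2, 3 survive -> (3, 4)
```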

Public access to datasets: s3://tahoe-hackathon-data/MFM/benchmarks/

If you require access to datasets not available in the public bucket, please open a GitHub issue or contact the team.

For more information on dataset preparation, see scripts/data_prep/README.md.

Pre-trained Models

We provide pre-trained Tahoe-x1 models of various sizes:

| Model Name | Parameters | Context Length | Checkpoint Path | WandB ID | Config File |
|---|---|---|---|---|---|
| Tx1-3B | 3B | 2048 | s3://tahoe-hackathon-data/MFM/ckpts/3b/ | mygjkq5c | ./configs/mcli/tahoe_x1-3b-v2-cont-train.yaml |
| Tx1-1.3B | 1.3B | 2048 | s3://tahoe-hackathon-data/MFM/ckpts/1b/ | 26iormxc | ./configs/gcloud/tahoe_x1-1_3b-merged.yaml |
| Tx1-70M | 70M | 1024 | s3://tahoe-hackathon-data/MFM/ckpts/70m/ | ftb65le8 | ./configs/gcloud/tahoe_x1-70m-merged.yaml |

Model weights are also available as safetensors files on our 🤗 Hugging Face model card.

Training and Fine-tuning

Training from Scratch

A sample test configuration is available at configs/test_run.yaml for quick experimentation.

Use the main training script with a YAML configuration file:

composer scripts/train.py -f configs/test_run.yaml

Or with command-line arguments:

composer scripts/train.py \
  --model_name tahoe_x1 \
  --data_path /path/to/data \
  --max_seq_len 2048 \
  --batch_size 32

Note that the current codebase only supports attn_impl: flash and use_attn_mask: False. The Triton backend and custom attention masks (used for training Tx1-1.3B and Tx1-70M) are no longer supported. If you have questions about using custom attention masks with the Triton backend, please contact us.

Fine-tuning

To fine-tune a pre-trained model on your own data:

  1. Download a pre-trained checkpoint from S3
  2. Modify the training configuration to load from checkpoint
  3. Prepare your dataset in the MDS format (see scripts/data_prep/README.md)
  4. Launch training with the --load_path argument
python scripts/train.py \
  -f configs/finetune_config.yaml \
  --load_path s3://path/to/checkpoint
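The checkpoint path can usually also be set in the YAML itself instead of on the command line. `load_path` is the conventional Composer/llm-foundry key; confirm the exact schema against the configs in this repo, and treat the other values below as illustrative:

```yaml
# Illustrative fine-tuning fragment, not a verbatim config.
load_path: s3://path/to/checkpoint   # pre-trained weights to start from
max_duration: 1ep                    # example fine-tuning budget
optimizer:
  name: decoupled_adamw
  lr: 1.0e-5                         # smaller learning rate, typical for fine-tuning
```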

Launching runs on different platforms

For launching runs on specific platforms such as MosaicML, Google Cloud, or RunAI, refer to the corresponding configuration folders under configs/ and their respective README files.

Preparing Models for Inference

Package a trained model with its vocabulary and metadata:

python scripts/prepare_for_inference.py \
  --model_path /path/to/checkpoint \
  --vocab_path /path/to/vocab.json \
  --output_path /path/to/inference_model

Generating Cell and Gene Embeddings

Quick Start with Inference Script

Extract cell embeddings from an AnnData object:

from omegaconf import OmegaConf as om
from scripts.inference.predict_embeddings import predict_embeddings

cfg = {
    "model_name": "Tx1-70m",
    "paths": {
        "hf_repo_id": "tahoebio/Tahoe-x1",
        "hf_model_size": "70m",
        "adata_input": "/path/to/your_data.h5ad",
    },
    "data": {
        "cell_type_key": "cell_type",
        "gene_id_key": "ensembl_id"
    },
    "predict": {
        "seq_len_dataset": 2048,
        "return_gene_embeddings": False,
    }
}

cfg = om.create(cfg)
adata = predict_embeddings(cfg)

# Access embeddings
cell_embeddings = adata.obsm["Tx1-70m"]
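Once extracted, the embeddings are plain vectors and can be used with standard vector math. A self-contained numpy sketch, with synthetic vectors standing in for adata.obsm["Tx1-70m"] and hypothetical cell-type labels, assigning each cell to its nearest cell-type centroid by cosine similarity:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for adata.obsm["Tx1-70m"]: 6 cells x 8 dimensions.
emb = rng.normal(size=(6, 8))
labels = np.array(["T", "T", "T", "B", "B", "B"])  # hypothetical labels

# L2-normalize rows so dot products become cosine similarities.
emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

# Mean embedding per label, re-normalized.
centroids = np.stack([emb[labels == lab].mean(axis=0) for lab in ("T", "B")])
centroids = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)

# Assign each cell to its most similar centroid.
sims = emb @ centroids.T                      # shape (6, 2)
assigned = np.array(["T", "B"])[sims.argmax(axis=1)]
print(assigned)
```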

Extracting Gene Embeddings

Set return_gene_embeddings: True in the configuration to extract gene-level representations.
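Gene-level representations can be compared the same way, for example via cosine similarity between gene vectors. A minimal numpy sketch with a synthetic embedding matrix (where Tx1 stores the real gene embeddings in the returned AnnData is not specified here):

```python
import numpy as np

# Synthetic gene embedding matrix: 4 genes x 2 dimensions.
gene_emb = np.array([
    [1.0, 0.0],   # gene A
    [0.0, 1.0],   # gene B (orthogonal to A)
    [1.0, 1.0],   # gene C (45 degrees from A and B)
    [2.0, 0.0],   # gene D (parallel to A)
])

# Row-normalize; the Gram matrix then holds pairwise cosine similarities.
unit = gene_emb / np.linalg.norm(gene_emb, axis=1, keepdims=True)
sim = unit @ unit.T

print(sim[0, 3])  # A vs D: parallel vectors -> 1.0
print(sim[0, 1])  # A vs B: orthogonal -> 0.0
```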

Tutorials and Benchmarks

Tutorials

Jupyter notebook tutorials live in tutorials/: clustering_tutorial.ipynb (cell clustering and UMAP visualization) and training_tutorial.ipynb (a training walkthrough).

Benchmarks

Tx1 achieves state-of-the-art performance across disease-relevant benchmarks. See our preprint for detailed results.

| Benchmark | Task | Code Location |
|---|---|---|
| DepMap Essentiality | Predict broad and context-specific gene dependencies | scripts/depmap/ |
| MSigDB Hallmarks | Recover 50 hallmark pathway memberships from gene embeddings | scripts/msigdb/ |
| Cell-Type Classification | Classify cell types across 5 tissues (Tabula Sapiens 2.0) | cz-benchmarks |
| Perturbation Prediction | Predict transcriptional responses in held-out contexts | scripts/state_transition/ |

Additional Resources

Troubleshooting

Common Issues and Solutions

  • PyTorch/CUDA mismatch: Ensure PyTorch is installed with the correct CUDA version for your system
  • Docker permission denied: Run Docker commands with sudo or add your user to the docker group
  • OOM (Out of Memory): Ensure half-precision, flash-attention are enabled, set microbatch_size to auto
  • S3 access denied: For public buckets, the code will automatically retry with unsigned requests

For additional help, please open an issue on GitHub with:

  • Your system configuration (OS, GPU, PyTorch version)
  • Complete error message and stack trace
  • Steps to reproduce the issue

Acknowledgements

We thank the developers of the open-source projects this work builds on, including Composer and llm-foundry.


For questions, issues, or collaboration inquiries, please open an issue on GitHub or write to us at admin@tahoebio.ai.
