
A transformer model for gene expression data


TranscriptFormer

Description

TranscriptFormer is a deep learning model for cross-species single-cell RNA sequencing analysis. It uses transformer-based architectures to learn representations of gene expression data across multiple species, leveraging protein sequence information through ESM-2 embeddings.

Installation

TranscriptFormer requires Python >= 3.11.

Install from source with uv

# Clone the repository
git clone https://github.com/czi-ai/transcriptformer.git
cd transcriptformer

# Create and activate a virtual environment with Python 3.11
uv venv --python=3.11
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

# Install in development mode
uv pip install -e .

Install from PyPI with uv

# Create and activate a virtual environment
uv venv --python=3.11
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

# Install from PyPI
uv pip install transcriptformer

Requirements

TranscriptFormer has the following core dependencies:

  • PyTorch (<=2.5.1, as 2.6.0+ may cause pickle errors)
  • PyTorch Lightning
  • anndata
  • scanpy
  • numpy
  • pandas
  • h5py
  • hydra-core

See the pyproject.toml file for the complete list of dependencies.
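Because the PyTorch pin above matters at runtime, a small version guard can catch an incompatible install early. The helper below is purely illustrative (not part of TranscriptFormer); in practice you would pass it `torch.__version__`:

```python
def torch_version_ok(version: str) -> bool:
    """Return True if a torch version string satisfies the <= 2.5.1 pin."""
    # Drop any local build tag such as "+cu121" before parsing.
    release = version.split("+")[0]
    major, minor, *rest = (int(part) for part in release.split("."))
    patch = rest[0] if rest else 0
    return (major, minor, patch) <= (2, 5, 1)
```

For example, `torch_version_ok("2.6.0")` is False, flagging a version that may trigger the pickle errors noted above.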

Hardware Requirements

Running this model requires a Python environment with GPU access. We have tested on NVIDIA A100 GPUs; smaller GPUs such as the T4 can run the smaller model variants.

Downloading Model Weights

Model weights and artifacts are available via AWS S3. You can download them using the provided download_artifacts.py script:

# Download a specific model
python download_artifacts.py tf-sapiens
python download_artifacts.py tf-exemplar
python download_artifacts.py tf-metazoa

# Download all models and embeddings
python download_artifacts.py all

# Download only the embedding files
python download_artifacts.py all-embeddings

# Specify a custom checkpoint directory
python download_artifacts.py tf-sapiens --checkpoint-dir /path/to/custom/dir

The script will download and extract the following files to the ./checkpoints directory (or your specified directory):

  • ./checkpoints/tf_sapiens/: Sapiens model weights
  • ./checkpoints/tf_exemplar/: Exemplar model weights
  • ./checkpoints/tf_metazoa/: Metazoa model weights
  • ./checkpoints/all_embeddings/: Embedding files for out-of-distribution species

The script includes progress bars for both download and extraction processes.
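The mapping from artifact names to checkpoint directories can be sketched as follows. This is a hypothetical helper mirroring the documented layout above, not the actual code of download_artifacts.py:

```python
# Artifact name -> checkpoint subdirectories, as documented above (assumed mapping).
CHECKPOINT_DIRS = {
    "tf-sapiens": ["tf_sapiens"],
    "tf-exemplar": ["tf_exemplar"],
    "tf-metazoa": ["tf_metazoa"],
    "all-embeddings": ["all_embeddings"],
}

def resolve_targets(artifact: str) -> list[str]:
    """Return the checkpoint subdirectories a given artifact name maps to."""
    if artifact == "all":
        # "all" downloads every model plus the embedding files.
        return [d for dirs in CHECKPOINT_DIRS.values() for d in dirs]
    return CHECKPOINT_DIRS[artifact]
```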

Running Inference

The inference.py script provides a convenient interface for running inference with TranscriptFormer. The script uses Hydra for configuration management, allowing flexible parameter specification.

Basic usage:

python inference.py --config-name=inference_config.yaml model.checkpoint_path=./checkpoints/tf_sapiens

Key Parameters:

  • model.checkpoint_path: Path to the checkpoint directory containing model weights and vocabulary files
  • model.inference_config.data_files: Path(s) to input data files (H5AD format)
  • model.inference_config.pretrained_embedding: Path(s) to pretrained embeddings (out-of-distribution species)
  • model.inference_config.output_path: Directory to save inference results
  • model.inference_config.batch_size: Batch size for inference (default: 32)
  • model.inference_config.precision: Numerical precision (default: "16-mixed")

Example:

For in-distribution species (e.g. human with TF-Sapiens):

# Inference on in-distribution species
python inference.py --config-name=inference_config.yaml \
  model.checkpoint_path=./checkpoints/tf_sapiens \
  model.inference_config.data_files.0=test/data/human_val.h5ad \
  model.inference_config.batch_size=8

For out-of-distribution species (e.g. mouse with TF-Sapiens), supply the embedding file:

# Inference on out-of-distribution species
python inference.py --config-name=inference_config.yaml \
  model.checkpoint_path=./checkpoints/tf_sapiens \
  model.inference_config.data_files.0=test/data/mouse_val.h5ad \
  model.inference_config.pretrained_embedding=./checkpoints/all_embeddings/mus_musculus_gene.h5 \
  model.inference_config.batch_size=8

To specify multiple input files, use indexed notation:

python inference.py --config-name=inference_config.yaml \
  model.checkpoint_path=./checkpoints/tf_sapiens \
  model.inference_config.data_files.0=test/data/human_val.h5ad \
  model.inference_config.data_files.1=test/data/mouse_val.h5ad

Or use the list notation:

python inference.py --config-name=inference_config.yaml \
  model.checkpoint_path=../checkpoints/tf_sapiens \
  "model.inference_config.data_files=[test/data/human_val.h5ad,test/data/mouse_val.h5ad]"
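When scripting many runs, the indexed overrides can also be generated programmatically. The helper below is purely illustrative (not part of TranscriptFormer):

```python
def data_file_overrides(paths: list[str]) -> list[str]:
    """Build Hydra overrides using the indexed data_files notation."""
    return [
        f"model.inference_config.data_files.{i}={path}"
        for i, path in enumerate(paths)
    ]

# The resulting strings can be appended to the inference.py command line, e.g.:
# subprocess.run(["python", "inference.py", "--config-name=inference_config.yaml",
#                 "model.checkpoint_path=./checkpoints/tf_sapiens",
#                 *data_file_overrides(["test/data/human_val.h5ad",
#                                       "test/data/mouse_val.h5ad"])])
```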

Input Data Format:

Input data files should be in H5AD format (AnnData objects) with the following requirements:

  • Gene IDs: The var dataframe must contain an ensembl_id column with Ensembl gene identifiers

    • Out-of-vocabulary gene IDs will be automatically filtered out during processing
    • Only genes present in the model's vocabulary will be used for inference
  • Expression Data: Raw count data should be stored in the adata.X matrix

    • The model expects raw (non-normalized) counts
    • Log-transformed or normalized data may lead to unexpected results
  • Cell Metadata: Any cell metadata in the obs dataframe will be preserved in the output
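The requirements above can be sanity-checked before running inference. The sketch below uses plain Python for illustration; in practice you would pass it `adata.var.columns` and a sample of entries from `adata.X`:

```python
def check_inputs(var_columns, sample_values):
    """Check an input file against the format requirements above.

    var_columns: column names of adata.var
    sample_values: a sample of entries drawn from adata.X
    """
    # Gene IDs: an 'ensembl_id' column must be present.
    if "ensembl_id" not in var_columns:
        raise ValueError("var must contain an 'ensembl_id' column")
    # Expression data: raw counts are non-negative integers;
    # log-transformed or normalized matrices typically fail this check.
    for v in sample_values:
        if v < 0 or not float(v).is_integer():
            raise ValueError(f"value {v!r} does not look like a raw count")
    return True
```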

Output Format:

The inference results are saved to the specified output directory (default: ./inference_results) in a file named embeddings.h5ad. This is an AnnData object where:

  • Cell embeddings are stored in obsm['embeddings']
  • Original cell metadata is preserved in the obs dataframe
  • Log-likelihood scores (if available) are stored in uns['llh']

Results are saved in HDF5 format with the same structure as the input data, plus the additional embedding matrices and likelihood scores.

For detailed configuration options, see the conf/inference_config.yaml file.

Contributing

This project adheres to the Contributor Covenant code of conduct. By participating, you are expected to uphold this code. Please report unacceptable behavior to opensource@chanzuckerberg.com.

Reporting Security Issues

Please note: If you believe you have found a security issue, please responsibly disclose by contacting us at security@chanzuckerberg.com.

Project details


Download files


Source Distribution

transcriptformer-0.1.0.tar.gz (57.9 MB)

Uploaded Source

Built Distribution


transcriptformer-0.1.0-py3-none-any.whl (31.9 kB)

Uploaded Python 3

File details

Details for the file transcriptformer-0.1.0.tar.gz.

File metadata

  • Download URL: transcriptformer-0.1.0.tar.gz
  • Size: 57.9 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for transcriptformer-0.1.0.tar.gz
Algorithm Hash digest
SHA256 f065b8fe89969c0a02f09f78a764db6c08696ec35205c93894aa54b67b4cd816
MD5 7c597cfbe498e21d7f1f5c34170e91be
BLAKE2b-256 b34f074b819beb25a109df6111818ef5d97b7f02e69db0a28712a87083f335f2


Provenance

The following attestation bundles were made for transcriptformer-0.1.0.tar.gz:

Publisher: publish-pypi.yml on czi-ai/transcriptformer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file transcriptformer-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for transcriptformer-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b78136c1b2cb370b3b82d9e01ce30583f91c4b90e5a0557f5cf234eea8ea9cb9
MD5 7135bb184883c33371e3ce93c7d4f306
BLAKE2b-256 f8f90572629d5601766d7535e20e8d7a55cfdc791e86d625c91df9253342ccfe


Provenance

The following attestation bundles were made for transcriptformer-0.1.0-py3-none-any.whl:

Publisher: publish-pypi.yml on czi-ai/transcriptformer

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
