Skip to main content

Training and inference on protein sets (genomes)

Project description

Protein Set Transformer

This repository contains the Protein Set Transformer (PST) framework for contextualizing protein language model embeddings at genome-scale to produce genome embeddings. You can use this code to train your own models. Using our foundation model pre-trained on viruses (vPST), you can also generate genome embeddings for input viruses.

For more information, see our manuscript:

Protein Set Transformer: A protein-based genome language model to power high diversity viromics.
Cody Martin, Anthony Gitter, and Karthik Anantharaman. bioRxiv, 2024, doi: 10.1101/2024.07.26.605391.

Installation

You can try simply doing:

pip install ptn-set-transformer

But I prefer to manually setup the PyTorch installation to control CPU/GPU availability.

This full installation can be achieved with mamba and pip, which should take no more than 5 minutes.

Note: you will likely need to link your git command line interface with an online github account. Follow this link for help setting up git at the command line.

Without GPUs

# setup torch first -- conda does this so much better than pip
mamba create -n pst -c pytorch -c pyg -c conda-forge 'python<3.12' 'pytorch>=2.0' cpuonly pyg pytorch-scatter

mamba activate pst
pip install ptn-set-transformer

With GPUs

# setup torch first -- conda does this so much better than pip
mamba create -n pst -c pytorch -c nvidia -c pyg -c conda-forge 'python<3.12' 'pytorch>=2.0' pytorch-cuda=11.8 pyg pytorch-scatter

mamba activate pst
pip install ptn-set-transformer

Installing for training a new PST

We implemented a hyperparameter tuning cross validation workflow implemented using Lightning Fabric in a base library called lightning-crossval. Part of our specific implementation for hyperparameter tuning is also implemented in the PST library.

If you want to include the optional dependendings for training a new PST, you can follow the corresponding installation steps above with the following change:

pip install .[tune]

Test run

Upon successful installation, you will have the pst executable to train, tune, and predict. There are also other modules included as utilties that you can see using pst -h.

You will need to first download a trained vPST model:

pst download --trained-models

This will download both vPST models into ./pstdata, but you can change the download location using --outdir.

You can use the test data for a test prediction run:

pst predict \
    --file test/test_data.graphfmt.h5 \ # this is in the git repo
    --checkpoint pstdata/pst-small_trained_model.ckpt \
    --outdir test_run

The results from the above command are available at test/test_run/predictions.h5. This test run takes fewer than 1 minute using a single CPU.

If you are unfamiliar with .h5 files, you can use pytables (installed with PST as a dependency) to inspect .h5 files in python, or you can install hdf5 and use the h5ls to inspect the fields in the output file.

There should be 3 fields in the prediciton file:

  1. attn which contains the per-protein attention values (shape: $N_{prot} \times N_{heads}$)
  2. ctx_ptn which contains the contextualized PST protein embeddings (shape: $N_{prot} \times D$)
  3. genome which contains the PST genome embeddings (shape: $N_{genome} \times D$)
    • Prior to version 1.2.0, this was called data.

Data availability

All data associated with the initial training model training can be found here: https://doi.org/10.5061/dryad.d7wm37q8w

We have provided the README to the DRYAD data repository to render here. Additionally, we have provided a programmatic way to access the data from the command line using pst download:

NOTE: we have recently changed the DRYAD repository corresponding to manuscript resubmission, so these commands will not work at the moment. However, the latest dataset will be available to download directly through DRYAD soon.

usage: pst download [-h] [--all] [--outdir PATH] [--esm-large] [--esm-small] [--vpst-large] [--vpst-small] [--genome] [--genslm]
                    [--trained-models] [--genome-clusters] [--protein-clusters] [--aai] [--fasta] [--host-prediction] [--no-readme]
                    [--supplementary-data] [--supplementary-tables]

help:
  -h, --help            show this help message and exit

DOWNLOAD:
  --all                 download all files from the DRYAD repository (default: False)
  --outdir PATH         output directory to save files (default: ./pstdata)

EMBEDDINGS:
  --esm-large           download ESM2 large [t33_150M] PROTEIN embeddings for training and test viruses (esm-large_protein_embeddings.tar.gz)
                        (default: False)
  --esm-small           download ESM2 small [t6_8M] PROTEIN embeddings for training and test viruses (esm-small_protein_embeddings.tar.gz)
                        (default: False)
  --vpst-large          download vPST large PROTEIN embeddings for training and test viruses (pst-large_protein_embeddings.tar.gz) (default:
                        False)
  --vpst-small          download vPST small PROTEIN embeddings for training and test viruses (pst-small_protein_embeddings.tar.gz) (default:
                        False)
  --genome              download all genome embeddings for training and test viruses (genome_embeddings.tar.gz) (default: False)
  --genslm              download GenSLM ORF embeddings (genslm_protein_embeddings.tar.gz) (default: False)

TRAINED_MODELS:
  --trained-models      download trained vPST models (trained_models.tar.gz) (default: False)

CLUSTERS:
  --genome-clusters     download genome cluster labels (genome_clusters.tar.gz) (default: False)
  --protein-clusters    download protein cluster labels (protein_clusters.tar.gz) (default: False)

MANUSCRIPT_DATA:
  --aai                 download intermediate files for AAI calculations in the manuscript (aai.tar.gz) (default: False)
  --fasta               download protein fasta files for training and test viruses (fasta.tar.gz) (default: False)
  --host-prediction     download all data associated with the host prediction proof of concept (host_prediction.tar.gz) (default: False)
  --no-readme           download the DRYAD README (README.md) (default: True)
  --supplementary-data  download supplementary data directly used to make the figures in the manuscript (supplementary_data.tar.gz) (default:
                        False)
  --supplementary-tables
                        download supplementary tables (supplementary_tables.zip) (default: False)

For flags relating to the download of specific files, you can add as many flags as you like.

Model information

Specifically at DRYAD link, trained_models.tar.gz contains both sizes of the vPST foundation model, pst-small and pst-large. Each model was trained with the same input data.

The training and test data are also available in the above data repository.

Here is a summary of each model:

Model # Encoder layers # Attention heads # Params Embedding dim
pst-small 5 4 5.4M 400
pst-large 20 32 177.9M 1280

Usage, Finetuning, and Model API

Please read the wiki for more information about how to use these models, extend them for finetuning and transfer learning, and the specific model API to integrate new models into your own workflows. Note: This is still a work in progress.

Expected runtime and memory consumption

The expected runtime for training the final models after hyperparameter tuning can be found in Supplementary Table 11 and ranged from 3.9-33.7h on 1 A100 GPU.

Inference times

These are estimates of inference times for a dataset composed of ~12k viral genomes encoding ~140k proteins (such as the MGnify test dataset):

Model Size Accelerator ESM2 embedding* PST inference Total Time
Large 1 A100 GPU 18 min <1 min ~19 min
Large 128 CPUs 6h <1 min ~6h
Large 8 CPUs 96h 11 min ~96h
Small 1 A100 GPU 9 min <1 min ~9 min
Small 128 CPUs 3h <1 min ~3h
Small 8 CPUs 48h 6 min ~48h

* ESM2 embeddings are computed independently for each protein, so input FASTA files can be split into equal batches and processed in parallel with as many GPUs as available.

  • These will need to be concatenated in the same order as the original FASTA file.

Memory

Memory usage should be negligible for inference, especially if using a LazyGenomeDataset. Less than 4GB of memory is needed for inference.

Manuscript

We have provided code for all analyses associated with the manuscript in the manuscript folder. The README in that folder links each method section from the manuscript to a specific Jupyter notebook code implementation.

Associated repositories

There are several other repositories associated with the model code and the manuscript:

Repository Description
lightning-crossval Our fold-synchronized cross validation strategy implemented with Lightning Fabric
esm_embed Our user-friendly way of embedding proteins from a FASTA file with ESM2 models
genslm_embed Code to generate GenSLM ORF and genome embeddings
hyena-dna-embed Code to generate Hyena-DNA genome embeddings
PST_host_prediction Model and evaluation code for our host prediction proof of concept analysis

Citation

Please cite our preprint if you find our work useful:

Martin C, Gitter A, Anantharaman K. (2024) "Protein Set Transformer: A protein-based genome language model to power high diversity viromics."

@article {
  author = {Cody Martin and Anthony Gitter and Karthik Anantharaman},
  title = {Protein Set Transformer: A protein-based genome language model to power high diversity viromics},
  elocation-id = {2024.07.26.605391},
  year = {2024},
  doi = {10.1101/2024.07.26.605391},
  publisher = {Cold Spring Harbor Laboratory},
  URL = {https://www.biorxiv.org/content/10.1101/2024.07.26.605391v1},
  eprint = {https://www.biorxiv.org/content/10.1101/2024.07.26.605391v1.full.pdf}
  journal = {bioRxiv},
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

ptn_set_transformer-2.6.0.tar.gz (98.0 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

ptn_set_transformer-2.6.0-py3-none-any.whl (113.5 kB view details)

Uploaded Python 3

File details

Details for the file ptn_set_transformer-2.6.0.tar.gz.

File metadata

  • Download URL: ptn_set_transformer-2.6.0.tar.gz
  • Upload date:
  • Size: 98.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for ptn_set_transformer-2.6.0.tar.gz
Algorithm Hash digest
SHA256 6ca61ce6cbac69f73b03e45abf7a497f0488b0a16f8b02f4195261c37b1be322
MD5 1d69a0914bda0d035ee5cefa93961626
BLAKE2b-256 7c4d2e03003c1384430392b3d21a7ebf6407625103c6156bfa71bc172b8a3059

See more details on using hashes here.

File details

Details for the file ptn_set_transformer-2.6.0-py3-none-any.whl.

File metadata

File hashes

Hashes for ptn_set_transformer-2.6.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f1672eb53eaeca184b93af334f019335536a813e129fa2b29a64540680a0e37b
MD5 5e84bbbcbef8a527c5d8a6d25cb8ff30
BLAKE2b-256 898cf99b2e322fa627fdbc11125571385ef8f6f322aef6012098ad924d407ac2

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page