Efficient Evolutionary Scale Modeling: an efficient and simplified implementation of protein language models for inference and training.

ESM-Efficient

DOI: 10.1101/2024.10.22.619563

Efficient implementation of the ESM family of models: ESM1b, ESM1v, ESM2, ESMC.

Installation

First install the version of PyTorch appropriate for your platform and CUDA version, then install flash-attn and esm-efficient:

pip install flash-attn --no-build-isolation
pip install esm-efficient

Basic Usage

from esme import ESM

model = ESM.from_pretrained('esmc') # or 'esm1b', 'esm1v', 'esm2', 'esm2_8m', ...

This downloads the model weights from the HuggingFace model hub and loads the model. See the getting started documentation for details.

Tokenization and Predicting Log Probabilities

Predict the log probabilities of a sequence of tokens using the model.

import torch
from esme import ESM2
from esme.alphabet import tokenize

# load the model from a local safetensors file ({model} is a placeholder)
model = ESM2.from_pretrained("{model}.safetensors", device=0)

tokens = tokenize(['MEEPQSDPSVEPPLSQESTFSLDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)

# predict logits
logits = model(tokens)
# logits.shape = (2, seq_len, vocab_size)

# predict log probabilities
log_probs = model.predict_log_prob(tokens)
# log_probs.shape = (2, seq_len, vocab_size)
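
Presumably the two outputs above differ only by a normalization over the vocabulary. Assuming predict_log_prob applies a log-softmax to the logits along the last dimension (in PyTorch, torch.log_softmax(logits, dim=-1)), the following stdlib-only sketch shows that step on a toy tensor stored as nested lists. The helper names log_softmax and logits_to_log_probs are hypothetical and not part of esme.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of vocabulary logits."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def logits_to_log_probs(batch: List[List[List[float]]]) -> List[List[List[float]]]:
    """Apply log_softmax to every position of a (batch, seq_len, vocab_size) nested list."""
    return [[log_softmax(position) for position in seq] for seq in batch]


# toy input: 1 sequence, 2 positions, a vocabulary of 3 tokens
toy_logits = [[[2.0, 1.0, 0.1], [1000.0, 1000.0, 1000.0]]]
for row in logits_to_log_probs(toy_logits)[0]:
    # the probabilities at each position sum to 1
    print(round(sum(math.exp(v) for v in row), 6))
```

The max subtraction is what keeps the second toy position (logits of 1000) from overflowing; the result is unchanged because a constant shift cancels inside the softmax.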

Tokenization without Padding

from esme.alphabet import tokenize_unpad
# tokenize without padding (more efficient: no computation is spent on padding tokens)
tokens, indices, cu_lens, max_len = tokenize_unpad(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
cu_lens = cu_lens.to(0)
log_probs = model.predict_log_prob(tokens, (cu_lens, max_len))
# log_probs.shape = (seq_len_protein1 + seq_len_protein2, vocab_size)
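
The difference between the padded and unpadded layouts can be sketched without a GPU. The code below assumes the fair-esm alphabet used by ESM-1b and ESM2 (special tokens <cls>, <pad>, <eos>, <unk> at ids 0 to 3, <mask> at 32) and that each sequence is wrapped as <cls> + residues + <eos>; esme's actual tokenizer, especially for ESMC, may differ. All function names are hypothetical. cu_lens here follows flash-attn's cu_seqlens convention (cumulative lengths starting at 0), and the indices value returned by tokenize_unpad is not modelled.

```python
from itertools import accumulate
from typing import Dict, List, Tuple

# Hypothetical vocabulary laid out like the fair-esm ESM-1b/ESM2 alphabet.
SPECIAL = ["<cls>", "<pad>", "<eos>", "<unk>"]
RESIDUES = list("LAGVSERTIDPKQNFYMHWCXBUZO.-")
VOCAB: Dict[str, int] = {
    tok: i for i, tok in enumerate(SPECIAL + RESIDUES + ["<null_1>", "<mask>"])
}
CLS, PAD, EOS, UNK = (VOCAB[t] for t in ("<cls>", "<pad>", "<eos>", "<unk>"))


def encode(seq: str) -> List[int]:
    """<cls> + residues + <eos>; unknown letters map to <unk>."""
    return [CLS] + [VOCAB.get(aa, UNK) for aa in seq] + [EOS]


def tokenize_padded(seqs: List[str]) -> List[List[int]]:
    """Rectangular (batch, seq_len) layout: shorter rows are filled with <pad>."""
    encoded = [encode(s) for s in seqs]
    width = max(len(e) for e in encoded)
    return [e + [PAD] * (width - len(e)) for e in encoded]


def tokenize_packed(seqs: List[str]) -> Tuple[List[int], List[int], int]:
    """Flat layout with no padding, plus cumulative lengths and the max length."""
    encoded = [encode(s) for s in seqs]
    tokens = [t for e in encoded for t in e]
    cu_lens = [0] + list(accumulate(len(e) for e in encoded))
    max_len = max(len(e) for e in encoded)
    return tokens, cu_lens, max_len


def split_packed(rows: list, cu_lens: List[int]) -> list:
    """Slice a packed per-token output back into one chunk per sequence."""
    return [rows[start:end] for start, end in zip(cu_lens[:-1], cu_lens[1:])]


seqs = ["MEEPQ", "MAD"]
print(tokenize_padded(seqs))  # → [[0, 20, 9, 9, 14, 16, 2], [0, 20, 5, 13, 2, 1, 1]]
tokens, cu_lens, max_len = tokenize_packed(seqs)
print(cu_lens, max_len)  # → [0, 7, 12] 7
```

In this toy batch the padded layout spends 2 of its 14 positions on <pad>, while the packed layout stores only the 12 real tokens; split_packed shows how a packed output such as the unpadded log_probs could be sliced back per protein.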

Predict effect of variants

from esme.variant import predict_mask_margin

seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
df = predict_mask_margin(model, seq)
# ... pd.DataFrame({
# ...    'variant': ['M1A', 'M1C', ..., 'K24Y'],
# ...    'score': [-0.1, -0.2, ..., -0.3]
# ... }).set_index('variant')
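
Assuming predict_mask_margin implements the masked-marginal score described by Meier et al. (2021) for ESM-1v, each variant is scored by masking its position and comparing the model's log-probability of the mutant residue with that of the wild-type residue; negative scores mean the model prefers the wild type. The sketch below takes precomputed per-position log-probabilities (one masked forward pass per position) as input. masked_marginal_scores is a hypothetical helper, and the toy numbers are not a normalized distribution.

```python
import math
from typing import Dict, List

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def masked_marginal_scores(
    seq: str, masked_log_probs: List[Dict[str, float]]
) -> Dict[str, float]:
    """score = log p(mut at i | i masked) - log p(wt at i | i masked).

    masked_log_probs[i] maps each amino acid to its log-probability at
    position i when that position is masked. Variant names use 1-based
    positions, e.g. 'M1A'; synonymous substitutions are skipped.
    """
    scores = {}
    for i, wt in enumerate(seq):
        lp = masked_log_probs[i]
        for mut in AMINO_ACIDS:
            if mut != wt:
                scores[f"{wt}{i + 1}{mut}"] = lp[mut] - lp[wt]
    return scores


seq = "MK"
uniform = {aa: math.log(1 / 20) for aa in AMINO_ACIDS}
# pretend the model strongly prefers the wild-type residue at both positions
toy = [dict(uniform, M=math.log(0.5)), dict(uniform, K=math.log(0.5))]
scores = masked_marginal_scores(seq, toy)
print(len(scores))  # → 38 (19 substitutions x 2 positions)
print(round(scores["M1A"], 3))  # → -2.303, i.e. log(0.05) - log(0.5) = log(0.1)
```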

Fine-tune the model with LoRA adapters:

# by default, only the added LoRA weights will be trained
model.add_lora(rank=16, layers=('query', 'key', 'value'), adapter_names=['adapter1', 'adapter2'])

# mark only the LoRA weights as trainable (add_lora already calls this by default)
model.mark_only_lora_as_trainable()

# save the model with the lora weights
model.save_lora('<path>.safetensors', adapter_names=['adapter1'])

# load the model with the lora weights
model.load_lora('<path>.safetensors')
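
For intuition about what add_lora trains, here is the standard LoRA formulation of Hu et al. (2021) applied to a single linear layer, such as a query, key, or value projection: the frozen weight W is augmented with a low-rank update (alpha / r) B A, with B initialized to zero so the adapted layer starts out identical to the pretrained one. This is a pure-Python sketch; esme's actual initialization, scaling, and handling of adapter_names are not described here and may differ, and LoRALinear and matvec are hypothetical names.

```python
import random
from typing import List

Matrix = List[List[float]]


def matvec(m: Matrix, v: List[float]) -> List[float]:
    """Matrix-vector product for nested lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """y = W x + (alpha / r) * B (A x); W stays frozen, only A and B would be trained."""

    def __init__(self, weight: Matrix, rank: int, alpha: float = 1.0, seed: int = 0):
        d_out, d_in = len(weight), len(weight[0])
        rng = random.Random(seed)
        self.weight = weight  # frozen pretrained weight, shape (d_out, d_in)
        # A: small random init, shape (rank, d_in)
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        # B: zero init, shape (d_out, rank), so the update starts at exactly zero
        self.B = [[0.0] * rank for _ in range(d_out)]
        self.scale = alpha / rank

    def __call__(self, x: List[float]) -> List[float]:
        base = matvec(self.weight, x)
        delta = matvec(self.B, matvec(self.A, x))
        return [b + self.scale * d for b, d in zip(base, delta)]


layer = LoRALinear([[1.0, 0.0], [0.0, 1.0]], rank=1)
print(layer([2.0, 3.0]))  # → [2.0, 3.0]: B is zero, so the output equals W x
```

With rank r, an adapter adds r * (d_in + d_out) trainable parameters per layer instead of d_in * d_out, which is why fine-tuning only touches a small fraction of the model.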

Quantization of the model:

model = ESM2.from_pretrained('8M.safetensors', quantization='4bit', device=0)
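
Which 4-bit format quantization='4bit' uses is not stated here (bitsandbytes' NF4 and FP4 are common choices in PyTorch). The sketch below uses a simpler, swapped-in technique, blockwise symmetric absmax quantization to signed 4-bit integers, only to illustrate the memory/precision trade-off: each weight is stored as a 4-bit code plus one floating-point scale per block, and rounding error is at most half a block's scale. quantize_4bit and dequantize_4bit are hypothetical names.

```python
from typing import List, Tuple


def quantize_4bit(values: List[float], block_size: int = 4) -> Tuple[List[int], List[float]]:
    """Blockwise symmetric absmax quantization to signed 4-bit codes in [-7, 7]."""
    codes: List[int] = []
    scales: List[float] = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # all-zero block: any scale works
        scale = absmax / 7  # the largest magnitude maps to code +/-7
        scales.append(scale)
        codes.extend(max(-7, min(7, round(v / scale))) for v in block)
    return codes, scales


def dequantize_4bit(codes: List[int], scales: List[float], block_size: int = 4) -> List[float]:
    """Approximate reconstruction: each code times the scale of its block."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]


weights = [7.0, -3.0, 1.2, 0.0, 0.02, -0.01, 0.005, 0.0]
codes, scales = quantize_4bit(weights)
restored = dequantize_4bit(codes, scales)
print(codes)
print([round(r, 4) for r in restored])
```

The second block shows why per-block scales matter: with a single scale for all eight weights (7.0 / 7 = 1.0), every small weight in it would round to 0.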

Activation checkpointing of each transformer layer:

model = ESM2.from_pretrained('8M.safetensors', checkpointing=True)

Training the model

We provide a PyTorch Lightning trainer for training the model. The following code trains the model with the masked language modeling objective:

from esme import ESM2
from esme.data import MaskedFastaTokenDataModule
from esme.trainer import MaskedPLM

model = ESM2.from_pretrained('8M.safetensors')  # load pretrained weights
trainer = MaskedPLM(model)  # PyTorch Lightning trainer
datamodule = MaskedFastaTokenDataModule(
    train_fasta='train.fasta',
    val_fasta='val.fasta',
    token_per_batch=50_000,
) # data module for training
trainer.fit(datamodule) 
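
The masked language modeling objective can be illustrated on a single tokenized sequence. The sketch assumes the BERT-style corruption used to pretrain ESM-1b and ESM2 (select 15% of non-special tokens; replace 80% of those with <mask>, 10% with a random residue, and leave 10% unchanged); the masking scheme of MaskedFastaTokenDataModule and MaskedPLM is not stated here, and mask_tokens is a hypothetical helper. Labels are -100 at unselected positions, the default ignore_index of PyTorch's cross-entropy loss, so the loss is computed only where tokens were selected.

```python
import random
from typing import List, Optional, Sequence, Tuple

IGNORE = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_tokens(
    tokens: Sequence[int],
    mask_id: int,
    residue_ids: Sequence[int],
    special_ids: Sequence[int],
    p: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """BERT-style corruption: returns (model inputs, per-position labels)."""
    if rng is None:
        rng = random.Random(0)
    inputs = list(tokens)
    labels = [IGNORE] * len(tokens)
    for i, tok in enumerate(tokens):
        if tok in special_ids or rng.random() >= p:
            continue  # not selected: no loss at this position
        labels[i] = tok  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id  # 80%: replace with <mask>
        elif r < 0.9:
            inputs[i] = rng.choice(residue_ids)  # 10%: random residue
        # remaining 10%: keep the original token unchanged
    return inputs, labels


# toy sequence in fair-esm ids: <cls>=0, 20 standard residues=4..23, <eos>=2, <mask>=32
toks = [0] + list(range(4, 24)) * 5 + [2]
inputs, labels = mask_tokens(toks, mask_id=32, residue_ids=range(4, 24), special_ids=(0, 2))
n_selected = sum(label != IGNORE for label in labels)
print(f"{n_selected} of {len(toks) - 2} residues selected for the loss")
```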

Model Weights

The model weights can be downloaded from HuggingFace: https://huggingface.co/mhcelik/esm-efficient/tree/main

Evaluation

To reproduce the evaluation reported in the paper, run the following command:

snakemake -n --use-conda

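This downloads the data, trains the models, and evaluates them; the results are saved in the results directory. Note that -n is Snakemake's dry-run flag, which only lists the jobs that would be executed; remove it to actually run the workflow. See workflow/Snakefile for more details.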

To generate a specific figure from the paper, for example Figure 2, run:

snakemake reports/paper_figures/figure-2.pdf -n --use-conda

Testing

Install the renamed fork of fair-esm and the esm package, which are needed for testing:

pip install git+https://github.com/MuhammedHasan/fair-esm.git
pip install esm

To run the tests, run the following command:

pytest tests/

Citation

Manuscript for the efficient implementation: https://www.biorxiv.org/content/10.1101/2024.10.22.619563v1

@article {Celik2024.10.22.619563,
    author = {Celik, Muhammed Hasan and Xie, Xiaohui},
    title = {Efficient Inference, Training, and Fine-tuning of Protein Language Models},
    elocation-id = {2024.10.22.619563},
    year = {2024},
    doi = {10.1101/2024.10.22.619563},
    publisher = {Cold Spring Harbor Laboratory},
    URL = {https://www.biorxiv.org/content/early/2024/10/25/2024.10.22.619563},
    eprint = {https://www.biorxiv.org/content/early/2024/10/25/2024.10.22.619563.full.pdf},
    journal = {bioRxiv}
}

Please also cite the original ESM papers for the model you use: https://github.com/facebookresearch/esm

LICENSE

This code implements ESM models from scratch and is licensed under the MIT License. Refer to the esm and fair-esm repositories for the licenses for the model weights.

Download files

Download the file for your platform.

Source Distribution

esm_efficient-0.0.9.tar.gz (39.0 kB)

Uploaded Source

Built Distribution

esm_efficient-0.0.9-py3-none-any.whl (34.9 kB)

Uploaded Python 3

File details

Details for the file esm_efficient-0.0.9.tar.gz.

File metadata

  • Download URL: esm_efficient-0.0.9.tar.gz
  • Size: 39.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.9

File hashes

Hashes for esm_efficient-0.0.9.tar.gz
Algorithm Hash digest
SHA256 3c4518ed1a3d74bbbe3dcfda5f59c365e6f92e90218ef900a3b358828d25c370
MD5 bf522bbc3540354d75973b4f96805c77
BLAKE2b-256 642739f349b0dc044279818c5d498ef06fe5779ccc3681ce7a3f0bfac4618105

File details

Details for the file esm_efficient-0.0.9-py3-none-any.whl.

File metadata

  • Download URL: esm_efficient-0.0.9-py3-none-any.whl
  • Size: 34.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.9

File hashes

Hashes for esm_efficient-0.0.9-py3-none-any.whl
Algorithm Hash digest
SHA256 8c431a67f90a44e0955b8cacf63aebb9aed146f443a149fa6022256a38e6b060
MD5 e9afbbd7d3435c27611e438bf8257651
BLAKE2b-256 9644706cb82764ced52abfbc425e7bdb68293867300ea92640335c8a6c5c8c14
