Efficient Evolutionary Scale Modeling: Efficient and simplified implementation of protein language models for inference and training.
ESM-Efficient
Efficient implementation of the ESM family of models: ESM1b, ESM1v, ESM2, ESMC.
Installation
Download and install the appropriate version of PyTorch for your system, then install flash-attn and this package:
pip install flash-attn --no-build-isolation
pip install esm-efficient
Basic Usage
from esme import ESM
model = ESM.from_pretrained('esmc') # or 'esm1b', 'esm1v', 'esm2', 'esm2_8m', ...
This downloads the model weights from the HuggingFace model hub and loads the model. See the documentation for getting started.
Tokenization and Predicting Log Probabilities
Predict the log probabilities of a sequence of tokens using the model.
import torch
from esme import ESM2
from esme.alphabet import tokenize
# load the model
model = ESM2.from_pretrained("{model}.safetensors", device=0)
tokens = tokenize(['MEEPQSDPSVEPPLSQESTFSLDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
# predict logits
logits = model(tokens)
# logits.shape = (2, seq_len, embed_size)
# predict log probabilities
log_probs = model.predict_log_prob(tokens)
# log_probs.shape = (2, seq_len, embed_size)
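The relationship between the logits and the log probabilities above can be sketched in plain Python. This is a minimal illustration that assumes predict_log_prob applies a log-softmax over the last dimension of the logits; the toy values and the 4-entry vector are hypothetical and are not produced by the model.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)  # subtract the maximum to avoid overflow in exp
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

# Toy logits for a single sequence position (hypothetical values).
position_logits = [2.0, 0.5, -1.0, 0.0]
position_log_probs = log_softmax(position_logits)

# Exponentiating the log probabilities recovers a distribution that sums to 1.
total_prob = sum(math.exp(lp) for lp in position_log_probs)
```

Working in log space keeps very small probabilities representable, which matters when scores are later summed or subtracted.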
Tokenization without Padding
from esme.alphabet import tokenize_unpad
# tokenize without padding (more efficient: avoids computing on padding tokens)
tokens, indices, cu_lens, max_len = tokenize_unpad(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
cu_lens = cu_lens.to(0)
log_probs = model.predict_log_prob(tokens, (cu_lens, max_len))
# log_probs.shape = (seq_len_protein1 + seq_len_protein2, embed_size)
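To make the roles of cu_lens and max_len concrete, here is a minimal sketch of how cumulative sequence lengths can describe a batch that is concatenated without padding. It assumes each sequence gains two special tokens (start and end), which may not match what tokenize_unpad actually does; cumulative_lengths and extra_tokens are hypothetical names used only for illustration.

```python
from itertools import accumulate

def cumulative_lengths(seqs, extra_tokens=2):
    """Return (cu_lens, max_len) for sequences concatenated without padding.

    extra_tokens is an assumption: the number of special tokens added
    to each sequence by the tokenizer.
    """
    lengths = [len(s) + extra_tokens for s in seqs]
    cu_lens = [0] + list(accumulate(lengths))  # boundaries in the flat token array
    return cu_lens, max(lengths)

seqs = ['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG']
cu_lens, max_len = cumulative_lengths(seqs)

# Tokens of sequence i occupy flat_tokens[cu_lens[i]:cu_lens[i + 1]].
spans = [(cu_lens[i], cu_lens[i + 1]) for i in range(len(seqs))]
```

Because the boundaries are explicit, attention can be restricted to each span and no compute is spent on padding positions.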
Predict effect of variants
from esme.variant import predict_mask_margin
seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
df = predict_mask_margin(model, seq)
# ... pd.DataFrame({
# ... 'variant': ['M1A', 'M1C', ..., 'P16Y'],
# ... 'score': [-0.1, -0.2, ..., -0.3]
# ... }).set_index('variant')
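The variant names in the index follow the usual reference, position, alternative pattern. The sketch below enumerates all single substitutions of a sequence and computes a masked-margin style score as the log probability of the alternative minus that of the reference at the masked position. This scoring definition is an assumption about what predict_mask_margin computes, and the log probabilities used here are made-up values rather than model output.

```python
import math

AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'

def single_substitutions(seq):
    """Yield (name, position, ref, alt) for every single amino-acid substitution."""
    for pos, ref in enumerate(seq, start=1):  # positions are 1-based
        for alt in AMINO_ACIDS:
            if alt != ref:
                yield f'{ref}{pos}{alt}', pos, ref, alt

def mask_margin(log_probs_at_pos, ref, alt):
    """Score = log p(alt) - log p(ref) at the masked position (assumed definition)."""
    return log_probs_at_pos[alt] - log_probs_at_pos[ref]

seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
variants = list(single_substitutions(seq))

# Hypothetical log probabilities at position 1 with the residue masked.
toy_log_probs = {'M': math.log(0.7), 'A': math.log(0.2), 'C': math.log(0.1)}
score = mask_margin(toy_log_probs, ref='M', alt='A')
```

Under this definition a negative score means the model considers the alternative residue less likely than the reference at that position.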
Fine-tune the model with LoRA adapters:
# only the added adapter weights are trained by default
model.add_lora(rank=16, layers=('query', 'key', 'value'), adapter_names=['adapter1', 'adapter2'])
# mark only the LoRA weights as trainable (called by default when adding LoRA)
model.mark_only_lora_as_trainable()
# save the model with the lora weights
model.save_lora('<path>.safetensors', adapter_names=['adapter1'])
# load the model with the lora weights
model.load_lora('<path>.safetensors')
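As a rough illustration of why training only the adapters is cheap, the sketch below counts the trainable parameters of one low-rank adapter against the full weight matrix it modifies. It uses the general LoRA formulation, where a frozen weight W is augmented by a product of two small matrices B and A; the hidden size of 320 is an illustrative value and lora_param_count is a hypothetical helper, not part of the package.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters of one adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * d_in + d_out * rank

d_model = 320  # illustrative hidden size, not read from the model
rank = 16      # matches the rank used in the example above

full_params = d_model * d_model                      # one dense projection matrix
lora_params = lora_param_count(d_model, d_model, rank)
fraction = lora_params / full_params                 # share of weights that are trained
```

With these numbers each adapted projection trains a tenth of the parameters of the dense matrix, and the saving grows as the hidden size increases relative to the rank.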
Quantization of the model:
model = ESM2.from_pretrained('8M.safetensors', quantization='4bit', device=0)
Activation checkpointing of each transformer layer:
model = ESM2.from_pretrained('8M.safetensors', checkpointing=True)
Training the model
We provide a PyTorch Lightning trainer for training the model. The following code trains the model with the masked language modeling objective:
from esme import ESM2
from esme.data import MaskedFastaTokenDataModule
from esme.trainer import MaskedPLM
trainer = MaskedPLM(model) # pytorch lightning trainer
datamodule = MaskedFastaTokenDataModule(
    train_fasta='train.fasta',
    val_fasta='val.fasta',
    token_per_batch=50_000,
)  # data module for training
trainer.fit(datamodule)
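The token_per_batch argument suggests that batches are sized by total token count rather than by number of sequences. The following sketch shows one simple way to group sequences under such a budget. It is an assumption about the general idea, not the data module's actual algorithm, which may sort, shuffle, or pack sequences differently; batch_by_token_budget is a hypothetical helper.

```python
def batch_by_token_budget(lengths, token_per_batch):
    """Greedily group sequence indices so each batch stays within the token budget."""
    batches, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(lengths):
        if n_tokens > token_per_batch:
            raise ValueError(f'sequence {idx} exceeds the token budget')
        # Start a new batch when adding this sequence would overflow the budget.
        if current and current_tokens + n_tokens > token_per_batch:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        batches.append(current)
    return batches

# Toy sequence lengths and a small budget for illustration.
lengths = [300, 120, 450, 80, 500, 60]
batches = batch_by_token_budget(lengths, token_per_batch=600)
```

Bounding tokens instead of sequences keeps memory use roughly constant per step even when protein lengths vary widely.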
Model Weights
The model weights can be downloaded from HuggingFace: https://huggingface.co/mhcelik/esm-efficient/tree/main
Evaluation
To perform the evaluation reported in the paper, run the following command:
snakemake -n --use-conda
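This will download the data, train the models, and evaluate them. The results will be saved in the results directory. Note that the -n flag performs a dry run that only lists the planned jobs; remove it to actually execute the workflow.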
See the workflow/Snakefile for more details.
To generate a specific figure from the paper, run the following command:
snakemake reports/paper_figures/figure-2.pdf -n --use-conda
Testing
Install the renamed esm package for testing:
pip install git+https://github.com/MuhammedHasan/fair-esm.git
pip install esm
To run the tests, run the following command:
pytest tests/
Citation
Manuscript for the efficient implementation: https://www.biorxiv.org/content/10.1101/2024.10.22.619563v1
@article {Celik2024.10.22.619563,
author = {Celik, Muhammed Hasan and Xie, Xiaohui},
title = {Efficient Inference, Training, and Fine-tuning of Protein Language Models},
elocation-id = {2024.10.22.619563},
year = {2024},
doi = {10.1101/2024.10.22.619563},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2024/10/25/2024.10.22.619563},
eprint = {https://www.biorxiv.org/content/early/2024/10/25/2024.10.22.619563.full.pdf},
journal = {bioRxiv}
}
Also cite the original ESM papers for the models you use: https://github.com/facebookresearch/esm
LICENSE
This code implements ESM models from scratch and is licensed under the MIT License. Refer to the esm and fair-esm repositories for the licenses for the model weights.