
FAESM: A Drop-in Efficient PyTorch Implementation of ESM2

Flash Attention ESM (FAESM) is an efficient PyTorch implementation of the Evolutionary Scale Modeling (ESM) family of protein language models (pLMs), which can be used for a wide range of protein sequence analysis tasks. FAESM is designed to be more efficient than the official ESM implementation: in our benchmarks it saves up to 60% of memory usage and up to 70% of inference time. The key features of FAESM are:

  1. Flash Attention: FAESM uses FlashAttention, one of the most efficient implementations of the self-attention mechanism.
  2. Scaled Dot-Product Attention (SDPA): FAESM also provides an implementation based on PyTorch's Scaled Dot-Product Attention, which is a bit slower than FlashAttention but is compatible with most systems and still faster than the official ESM implementation.
  3. Same Checkpoint: FAESM is a drop-in replacement for ESM2, with the same API and checkpoints.
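Both backends compute the same exact attention, softmax(QK^T / sqrt(d)) V; they differ in how the work is scheduled on the GPU, so their outputs should agree up to floating-point error. For reference, the framework-free sketch below spells out that formula for a single attention head using plain Python lists. It is an illustration only, not FAESM's code, and `sdpa_reference` is a hypothetical name; in PyTorch the fused equivalent is torch.nn.functional.scaled_dot_product_attention.

```python
import math


def sdpa_reference(q, k, v):
    """Hypothetical reference: single-head softmax(q k^T / sqrt(d)) v.

    q: list of n query vectors of length d
    k: list of m key vectors of length d
    v: list of m value vectors of length d_v
    Returns a list of n output vectors of length d_v.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        # Scaled dot product of this query with every key
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Numerically stable softmax: subtract the max score before exp
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Attention-weighted sum of the value vectors
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


# A zero query attends uniformly, so the output is the mean of the values
print(sdpa_reference([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
```

The loops make the two steps explicit (softmax over scaled scores, then a weighted sum of values); FlashAttention computes the same result in tiles without materializing the full score matrix.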

Installation

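  1. Install PyTorch 1.12 or above if you haven't already: pip install torch. Note that torch.nn.functional.scaled_dot_product_attention, which PyTorch SDPA refers to, was introduced in PyTorch 2.0, so a 2.x release is the safer choice.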

  2. [Optional] Install flash-attn if you want to use the flash attention implementation, which is the fastest and most memory-efficient option. It can be tricky to install, so you can safely skip this step; FAESM will then use PyTorch SDPA attention instead.

pip install flash-attn --no-build-isolation

Having trouble installing flash attention but still want to use it? One workaround is a Docker container: the official NVIDIA PyTorch containers include all the dependencies flash attention needs.

  3. Install FAESM from GitHub:
pip install git+https://github.com/pengzhangzhi/faesm.git
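After installing, you may want to confirm which attention backend is available. The snippet below is a minimal sketch using a hypothetical helper, `attention_backend`, which is not part of FAESM's API: it only checks whether the `flash_attn` module can be found, following the rule above that FAESM uses PyTorch SDPA when flash-attn is not installed.

```python
import importlib.util


def attention_backend():
    """Hypothetical helper: report which attention backend should be usable.

    Returns "flash_attn" if the flash-attn package is importable, else "sdpa".
    This only checks that the module exists; it does not import it or
    verify that your GPU supports it.
    """
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attn"
    return "sdpa"


print("Attention backend:", attention_backend())
```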

Usage

FAESM is a drop-in replacement for the official ESM implementation, so you can use the same code you would write for the official ESM2 models. For example:

```python
import torch
from faesm.esm import FAEsmForMaskedLM

# Step 1: Load the FAESM model (the tokenizer is attached as model.tokenizer)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use half precision on GPU; keep float32 on CPU, where float16 support is limited
dtype = torch.float16 if device == "cuda" else torch.float32
model = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval().to(dtype)
# Step 2: Prepare a sample input sequence
sequence = "MAIVMGRWKGAR"
inputs = model.tokenizer(sequence, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# Step 3: Run inference with the FAESM model (no gradients needed)
with torch.no_grad():
    outputs = model(**inputs)
# Step 4: Process and print the output logits and representations
print("Logits shape:", outputs["logits"].shape)  # (batch_size, sequence_length, vocab_size)
print("Repr shape:", outputs["last_hidden_state"].shape)  # (batch_size, sequence_length, hidden_size)
# Step 5: star the repo if the code works for you!
```
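A common next step is to turn `last_hidden_state` into one vector per protein by averaging over residue positions. ESM tokenizers add a start token (`<cls>`) and an end token (`<eos>`) around each sequence, so a typical recipe drops those, along with any padding, before pooling. The sketch below does this for a single sequence with plain Python lists to stay framework-free; `mean_pool_residues` is a hypothetical helper, and it assumes right padding with `<cls>` at position 0 and `<eos>` at the last non-padding position, so check this against your tokenizer.

```python
def mean_pool_residues(hidden, attention_mask):
    """Hypothetical helper: average per-residue vectors of one sequence.

    hidden: list over token positions of hidden-state vectors
    attention_mask: list of 0/1 flags, 1 for real (non-padding) tokens
    Assumes right padding, <cls> at position 0 and <eos> as the last real token.
    """
    n_real = sum(attention_mask)        # real tokens, including <cls> and <eos>
    residues = hidden[1:n_real - 1]     # drop <cls> (first) and <eos> (last real token)
    if not residues:
        raise ValueError("sequence has no residue tokens")
    dim = len(hidden[0])
    return [sum(vec[c] for vec in residues) / len(residues) for c in range(dim)]


# Positions: <cls>, residue, residue, <eos>, <pad>
hidden = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [0.0, 0.0]]
print(mean_pool_residues(hidden, [1, 1, 1, 1, 0]))
```

With tensors, the same idea is a masked mean over the sequence dimension of `outputs["last_hidden_state"]`.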

Training [WIP]

An example training script for masked language modeling (MLM) on UniRef50 is a work in progress. In the meantime, you can train FAESM with the same logic you would use for the official ESM, since the model architecture is identical. We recommend flash attention for training: in the forward pass it unpads the input sequences to remove all padding tokens, which 1) speeds up training and reduces memory usage, and 2) removes the need to batch sequences of similar length to avoid padding. SDPA remains a good alternative if you cannot install flash attention.
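To make the unpadding idea concrete, the sketch below packs a right-padded batch into one flat token list plus cumulative sequence lengths (`cu_seqlens`), the layout taken by flash-attn's variable-length kernels such as `flash_attn_varlen_func`, and then restores the padded layout. It uses plain Python lists and hypothetical helper names (`unpad`, `repad`); FAESM's own implementation works on GPU tensors and may differ in detail.

```python
from itertools import accumulate


def unpad(batch, attention_mask):
    """Hypothetical helper: pack a padded batch into one flat token list.

    Returns (packed_tokens, cu_seqlens, max_seqlen). Sequence i occupies
    packed_tokens[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    lengths = [sum(row) for row in attention_mask]
    packed = [tok for row, mask in zip(batch, attention_mask)
              for tok, keep in zip(row, mask) if keep]
    cu_seqlens = [0] + list(accumulate(lengths))
    return packed, cu_seqlens, max(lengths)


def repad(packed, cu_seqlens, width, pad=0):
    """Hypothetical inverse of unpad for right-padded batches."""
    rows = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        row = packed[start:end]
        rows.append(row + [pad] * (width - len(row)))
    return rows


batch = [[5, 6, 7, 0], [8, 9, 0, 0]]
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
packed, cu_seqlens, max_seqlen = unpad(batch, mask)
print(packed, cu_seqlens, max_seqlen)
print(repad(packed, cu_seqlens, width=4))
```

In this toy batch, 8 padded slots shrink to 5 real tokens, so attention only runs over tokens that carry information; this is why unpadded batches need no length-sorting.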

Benchmarking

Below we benchmark the peak memory usage and inference time of FAESM against the official ESM2 implementation. FAESM reduces peak memory usage by up to 60% and inference time by up to 70% (at sequence length 1000). The benchmark uses ESM2-650M with batch size 8 on a single A100 GPU with 80 GB of memory.
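The percentages above are relative reductions: saving = 1 - FAESM / ESM2, applied to peak memory or to time. If you want to benchmark your own setup, the sketch below shows a minimal timing harness and that formula; the helper names are hypothetical and this is not the code in tests/benchmark.py. When timing GPU code, call torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously; peak memory can be measured with torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated().

```python
import statistics
import time


def median_runtime(fn, *args, warmup=2, repeats=5):
    """Hypothetical helper: median wall-clock time of fn(*args) in seconds.

    Warm-up runs are discarded. For GPU work, fn should call
    torch.cuda.synchronize() before returning, otherwise the timer
    only measures kernel launch time.
    """
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def relative_saving(baseline, candidate):
    """Fraction saved by candidate relative to baseline (0.6 means 60% less)."""
    return 1.0 - candidate / baseline


print(f"{relative_saving(10.0, 4.0):.0%}")
```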

[Figure: benchmark of peak memory usage and inference time, FAESM vs. official ESM2]

You can reproduce the benchmark by running the following command:

pytest tests/benchmark.py

To measure the numerical differences between FAESM and the official ESM2 implementation, run:

pytest tests/test_compare_esm.py
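If you compare outputs yourself, for example the logits of FAESM and the official ESM2 on the same input, a max-absolute-error summary plus an elementwise tolerance check is a common approach. The sketch below applies the same rule as torch.allclose, |a - b| <= atol + rtol * |b|, to flat lists of numbers. The helper names and default tolerances are hypothetical and are not taken from tests/test_compare_esm.py; half-precision runs usually need looser tolerances than float32.

```python
def max_abs_error(a, b):
    """Largest elementwise absolute difference between two flat sequences."""
    return max(abs(x - y) for x, y in zip(a, b))


def all_close(a, b, rtol=1e-3, atol=1e-3):
    """Hypothetical check using torch.allclose's rule: |a - b| <= atol + rtol * |b|."""
    return len(a) == len(b) and all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


faesm_logits = [1.0, 2.0, -0.5]
esm_logits = [1.0005, 2.0, -0.5]
print(max_abs_error(faesm_logits, esm_logits), all_close(faesm_logits, esm_logits))
```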

TODOs

  • Training script
  • Integrate FAESM into ESMFold

Appreciation

  • The Rotary code is from esm-efficient.
  • The ESM modules and the SDPA attention module are inspired by ESM and DPLM.
  • I want to highlight that esm-efficient also supports Flash Attention and offers more features, such as quantization and LoRA. Please check it out!

This project started from a shared frustration with Alex Tong (@atong01) that there was no efficient implementation of ESM (we wasted a lot of compute training pLMs :( ). He later helped me debug the precision errors in my implementation and organize this repo. Along the way, I talked with @MuhammedHasan about his esm-efficient implementation (see issues 1 and 2) and with Tri Dao about flash attention (see the issue). And of course, a shoutout to the ESM team for creating the ESM family. None of this code would have been possible without their help.

Citation

Please cite this repo if you use it in your work.

@misc{faesm2024,
  author       = {Fred Zhangzhi Peng and contributors},
  title        = {FAESM: An efficient PyTorch implementation of Evolutionary Scale Modeling (ESM)},
  year         = {2024},
  howpublished = {\url{https://github.com/pengzhangzhi/faesm}},
  note         = {Efficient PyTorch implementation of ESM with FlashAttention and Scaled Dot-Product Attention (SDPA)},
  abstract     = {FAESM is a drop-in replacement for the official ESM implementation, designed to save up to 60% memory usage and 70% inference time, while maintaining compatibility with the ESM API.},
}

Download files

Source Distribution

faesm-0.0.4.tar.gz (19.7 kB view details)

Uploaded Source

Built Distribution

faesm-0.0.4-py3-none-any.whl (16.1 kB view details)

Uploaded Python 3

File details

Details for the file faesm-0.0.4.tar.gz.

File metadata

  • Download URL: faesm-0.0.4.tar.gz
  • Size: 19.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.15

File hashes

Hashes for faesm-0.0.4.tar.gz
Algorithm Hash digest
SHA256 ca28ce93288a93eb36f1bcd5245f2ead41839c1a45ba6a676e259ceca096b75e
MD5 41c72f949d9b46b4b8b08ca3b1e6874a
BLAKE2b-256 233315d5956ed957a005f476bf8fa773592094bf3d112442e33bcc27724664d2

File details

Details for the file faesm-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: faesm-0.0.4-py3-none-any.whl
  • Size: 16.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.15

File hashes

Hashes for faesm-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 f2a8f639c9dd936b256fa202747688b5cdad8effe13f1e15cf6292e240c64c6a
MD5 1a5d4b26821dd5cd1140b25509ba87db
BLAKE2b-256 269f47b387cac5c629276c5b760b0c07ba86821f0cc3975f0cad103f27033289
