
Calculate Fréchet distance for protein language model embeddings


plm-fid: Fréchet Distance from pLM Embeddings


This tool computes the Fréchet Distance between two sets of protein sequences based on their protein language model (pLM) embeddings. This metric is widely used for evaluating image generation (see the original paper) and has also gained some popularity for evaluating protein generation (see Acknowledgements).

When you pass in FASTA files, the specified model is loaded from the HuggingFace Hub using transformers.

[!NOTE]
If it's your first time using a given model, the weights will be downloaded and cached locally (this can take some time, especially for large models).

Each protein sequence is then embedded using the selected pLM. We apply mean pooling across the sequence dimension to obtain a fixed-size vector for each protein. After embedding both sets, we compute the mean vector and covariance matrix of each set. These summary statistics are then used to calculate the Fréchet Distance.
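Concretely, given means $\mu_A, \mu_B$ and covariances $\Sigma_A, \Sigma_B$ of the two embedding sets, the (squared) Fréchet distance is $\|\mu_A - \mu_B\|^2 + \mathrm{Tr}\left(\Sigma_A + \Sigma_B - 2(\Sigma_A \Sigma_B)^{1/2}\right)$. The computation can be sketched as follows (a minimal illustration using NumPy/SciPy, not the package's internal code):

```python
import numpy as np
from scipy import linalg

def frechet_distance(emb_a: np.ndarray, emb_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two embedding sets.

    emb_a, emb_b: arrays of shape [N, D], one pooled vector per protein.
    """
    mu_a, mu_b = emb_a.mean(axis=0), emb_b.mean(axis=0)
    cov_a = np.cov(emb_a, rowvar=False)
    cov_b = np.cov(emb_b, rowvar=False)
    # d^2 = ||mu_a - mu_b||^2 + Tr(cov_a + cov_b - 2 * sqrtm(cov_a @ cov_b))
    covmean = linalg.sqrtm(cov_a @ cov_b)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a + cov_b - 2.0 * covmean))

rng = np.random.default_rng(0)
a = rng.normal(size=(200, 8))            # toy "embeddings", D=8
b = rng.normal(loc=1.0, size=(200, 8))   # shifted distribution
print(frechet_distance(a, a) < frechet_distance(a, b))  # True: identical sets are closer
```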

Installation

[!IMPORTANT]
Python >= 3.10 is required. sentencepiece has known issues with Python 3.13, so I recommend a version >=3.10, <3.13.

# if you have torch>=2.0 already installed
pip install plm-fid

# else
pip install plm-fid[torch]

Usage

CLI

plm-fid setA.fasta setB.fasta

[!TIP] The CLI expects file paths only: either .fasta files (raw protein sequences) or .npy / .pt files (pre-computed embeddings).

[!TIP] To see all available options, run:

plm-fid --help
| CLI/API Argument | Description |
| --- | --- |
| `model-name` | The protein language model to use. Specify a lowercase string, such as `esm2_8m`, `protbert`, or `antiberta2_cssp`. In the API, list available models with `FrechetProteinDistance.available_models()`; in the CLI, use `--help`. Defaults to `esm2_650m`. |
| `device` | The device to run the model on, e.g., `cuda:0` or `cpu`. Defaults to `cuda` if available, otherwise `cpu`. |
| `max-length` | Maximum length for each protein sequence. Longer sequences are truncated according to the selected truncation style. Some models may require a smaller maximum (e.g., `antiberta2_cssp` supports up to 254). Defaults to 1000. |
| `truncation-style` | How to truncate sequences longer than `max-length`. Use `end` to truncate from the back, or `center` to remove from the middle, preserving the N- and C-termini. Defaults to `center`. |
| `batch-size` | Number of sequences to embed per batch. Adjust according to available memory. Defaults to 1. |
| `save-embeddings` | Whether to save the computed embeddings to `.npy` files. Useful for reuse or debugging. Disabled by default. |
| `output-dir` | Directory for output files if `--save-embeddings` is enabled. Defaults to the current directory (`.`). |

| CLI-Only Argument | Description |
| --- | --- |
| `round` | Number of decimal places to round the final Fréchet distance to. Defaults to 2. |
| `verbose` | Show progress messages. Disabled by default. |
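The `center` truncation style described above can be sketched roughly as follows (an illustrative sketch, not the package's actual implementation): keep a prefix and a suffix of about equal length and drop residues from the middle.

```python
def truncate_center(seq: str, max_length: int) -> str:
    """Illustrative center truncation: keep the N- and C-termini,
    removing residues from the middle of sequences longer than max_length."""
    if len(seq) <= max_length:
        return seq
    head = (max_length + 1) // 2            # keep a slightly longer prefix when odd
    tail = max_length - head
    return seq[:head] + seq[len(seq) - tail:]

print(truncate_center("ACDEFGHIKLMNPQRSTVWY", 8))  # "ACDETVWY": both termini preserved
```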

Python API

from plm_fid import FrechetProteinDistance
import numpy as np
import torch

fid = FrechetProteinDistance(model_name="esmplusplus_small")

# Using saved NumPy arrays or PyTorch tensors
emb_a = np.load("embeddings_a.npy")       # shape: [N, D]
emb_b = torch.load("embeddings_b.pt")     # shape: [N, D]

distance = fid.compute_fid(emb_a, emb_b)

[!NOTE] The API accepts both file paths and in-memory arrays/tensors. Argument names use underscores instead of dashes (e.g., model_name).

Automatic Format Resolution in API

The compute_fid() method accepts:

  • .fasta file paths
  • .npy or .pt file paths
  • In-memory NumPy arrays or PyTorch tensors

Any combination of the above is accepted by compute_fid.

# FASTA
fid.compute_fid("set_a.fasta", "set_b.fasta")

# Embeddings from disk
fid.compute_fid("emb_a.npy", "emb_b.pt")

# Mixed input
import numpy as np
set_a = np.random.randn(10, 1280)
fid.compute_fid(set_a, "emb_b.pt")

[!IMPORTANT] When using pre-computed embeddings, a stacked array/tensor with shape [batch, plm dim] is expected.
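For example, per-protein pooled vectors can be stacked into that shape before being passed in (a sketch; the dimension 1280 below is just illustrative):

```python
import numpy as np

# Hypothetical per-protein pooled embeddings: one D-dimensional vector each.
per_protein = [np.random.randn(1280) for _ in range(10)]

# Stack into the expected [batch, plm_dim] shape before calling compute_fid.
emb = np.stack(per_protein)        # shape: (10, 1280)
np.save("emb_a.npy", emb)          # can also be reused later as a file-path input
print(emb.shape)
```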

Important Notes

  • Computing the FID takes longer for models with larger embedding dimensions (e.g., ESM2 3B) because of the larger covariance matrix.
  • When mixing FASTA with pre-computed embeddings, make sure the model used to embed the FASTA file is the same as the one used to generate the .npy or .pt embeddings. A warning will be issued if there’s a mismatch.
    • If the embedding dimensions differ, calculate_frechet_distance() in distance.py will raise an error.

[!WARNING]
If two different models happen to produce embeddings of the same dimension, no error is raised, but the resulting FID is likely meaningless.

AntiBERTa2-Specific Notes

  • AntiBERTa2 has a max sequence length of 254. This will be enforced automatically.
  • For paired-chain FASTA input, format each entry as:
>name
heavy_sequence|light_sequence

[!IMPORTANT] If you're using a standard protein language model with paired-chain data, note that the | character may be treated as an unknown token by most tokenizers. This typically doesn't cause a crash, but it can affect embedding quality.

Examples

Here are two example use cases of pLM-based FID. These were primarily sanity checks to make sure the metric behaves as expected. The central idea in both examples is the following: given two known protein distributions, $P$ and $Q$, that are believed to be distinct:

  • Sample a reference set and a set $A$ from distribution $P$.
  • Sample a set $B$ from distribution $Q$.

then our hypothesis is:

\texttt{plmFID}(\text{ref}, A) < \texttt{plmFID}(\text{ref}, B)

where $\texttt{plmFID}$ refers to the Fréchet distance computed using the mean and covariance of their pLM embeddings.

Distinguishing CATH Classes

Using the CATH S20 dataset, this example tests whether pLM-based Fréchet distance can distinguish between CATH classes, specifically Class 1 and Class 2 proteins. Class 1 proteins are primarily alpha-helical, while Class 2 proteins are mostly beta-sheet. Even though class assignment is based on secondary structure, we expect pLM embeddings to encode features that reflect these differences in fold.

Let $P$ be the distribution of Class 1 proteins, and $Q$ the distribution of Class 2 proteins.

We provide the FASTA files in examples/cath containing 1953 sequences each.

  • reference.fasta: from $P$
  • class1.fasta: set $A$ from $P$
  • class2.fasta: set $B$ from $Q$

Results

| Model | plmFID(ref, $A$) | plmFID(ref, $B$) |
| --- | --- | --- |
| ESM2 (8M) | 0.11 | 2.52 |
| ESM2 (35M) | 0.16 | 2.15 |
| ESM2 (150M) | 0.15 | 1.70 |
| ESM2 (650M) | 0.38 | 2.18 |
| ESM2 (3B) | 1.07 | 3.86 |
| ESM2 (15B) | 7.94 | 18.17 |
| ProtBert | 0.08 | 0.49 |
| ProtBert BFD | 0.07 | 0.75 |
| ProtT5 | 0.15 | 0.74 |
| ESM++ (S) | 0.01 | 0.07 |
| ESM++ (L) | 0.01 | 0.11 |

Across all pLMs, the expected behavior is observed. However, this task may simply be too easy, so we move on to a second example.

Distinguishing antibodies binding to the SARS-CoV-2 Spike Protein's RBD vs NTD

The SARS-CoV-2 spike (S) protein contains two major domains that are frequent targets of neutralizing antibodies: the receptor-binding domain (RBD) and the N-terminal domain (NTD). Although both domains are part of the same protein, they differ in sequence and structure, raising the question of whether pLM embeddings can capture these biologically meaningful distinctions. In principle, antibodies that bind to the RBD versus the NTD should differ in their sequence features, as each domain presents distinct epitopes that shape antibody binding preferences.

From the CoV-AbDab database, I filtered for antibodies with both heavy and light chains that bind to either the NTD or RBD. There were only 587 NTD antibodies compared to 7141 RBD antibodies, so to get an equal number of sequences, I randomly sampled 587 RBD antibodies to be the reference and another, distinct set of 587 RBD antibodies to be $A$.
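The sampling step above can be sketched as follows (illustrative only; the IDs are hypothetical placeholders, not the actual CoV-AbDab entries):

```python
import random

# Hypothetical pool of RBD antibody IDs; in practice these come from CoV-AbDab.
rbd_ids = [f"rbd_{i}" for i in range(7141)]

random.seed(0)
drawn = random.sample(rbd_ids, k=2 * 587)  # draw without replacement
reference = drawn[:587]                    # reference set from P
set_a = drawn[587:]                        # disjoint set A, also from P
assert not set(reference) & set(set_a)     # the two samples do not overlap
```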

Let $P$ be the distribution of RBD-binding antibodies, and $Q$ the distribution of NTD-binding antibodies.

We provide the FASTA files in examples/cov containing 587 sequences each.

  • reference.fasta: from $P$
  • rbd.fasta: set $A$ from $P$
  • ntd.fasta: set $B$ from $Q$

Results

| Model | plmFID(ref, $A$) | plmFID(ref, $B$) |
| --- | --- | --- |
| AntiBERTa2-CSSP | 3.45 | 7.30 |

As expected, the distance between the reference and NTD antibodies is greater than the distance between the reference and RBD antibodies.

Conclusion

These initial sanity checks suggest that pLM-based FID can distinguish between protein groups with known structural and functional differences. That said, this was primarily a personal validation, a way to ensure the metric and code behave as expected. I’m not claiming this is a novel, comprehensive, or sufficient benchmark!

Acknowledgements

FID

Here's the paper that introduced FID: GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium. I also wrote a short memo about FID and its predecessor, the Inception Score, here.

Existing works using Fréchet pLM Distance

I'd like to make it clear that pLM-based FID is not a new idea; here are the papers I've seen that use pLM-based FID for assessing generated proteins[^1].

Code

Protein language models

| Model | Team | Link | Paper |
| --- | --- | --- | --- |
| ESM2 | FAIR | Github | Evolutionary-scale prediction of atomic-level protein structure with a language model |
| ProtBert (BFD) | RostLab | HuggingFace | ProtTrans: Toward Understanding the Language of Life Through Self-Supervised Learning |
| ProtT5 | RostLab | HuggingFace | ProtTrans: Toward Understanding the Language of Life Through Self-Supervised Learning |
| ESMplusplus (ESMC) | EvolutionaryScale/Synthra | HuggingFace/Github | - |
| AntiBERTa2-CSSP | Alchemab | HuggingFace | Enhancing Antibody Language Models with Structural Information |

Footnotes

[^1]: I'll be updating the list as I see more examples, as well as for any that I've missed.
