# Reverse Distillation for Protein Language Models
Protein language models (PLMs) scale poorly: for many tasks, mid-sized models often outperform the largest in the same family. Reverse Distillation addresses this by decomposing large PLM representations into orthogonal subspaces guided by smaller models of the same family. The resulting embeddings have a Matryoshka-style nested structure — the first k dimensions of a larger model's embedding exactly match the smaller model's representation — ensuring larger reverse-distilled models consistently outperform smaller ones.
On ProteinGym benchmarks, reverse-distilled ESM-2 variants outperform their respective baselines at the same embedding dimensionality, with the reverse-distilled 15B model achieving the strongest performance.
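The nested structure can be illustrated with a toy numpy sketch. The vectors here are random placeholders, not real embeddings; only the dimensionalities (480 for the 35M model, 1280 for 650M) come from the model table below.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for reverse-distilled embeddings: random vectors with the
# real dimensionalities (480 = esm2.rd/35M, 1280 = esm2.rd/650M).
small = rng.normal(size=480)            # smaller model's representation
extra = rng.normal(size=1280 - 480)     # dimensions the larger model adds
large = np.concatenate([small, extra])  # nested, Matryoshka-style embedding

# The defining property: truncating the larger embedding to its first
# 480 dimensions recovers the smaller model's representation exactly.
print(large.shape, np.array_equal(large[:480], small))  # (1280,) True
```

This is what guarantees monotonicity: any task solvable with the smaller embedding is solvable with a prefix of the larger one.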
## Installation

Requires Python ≥ 3.12 and uv.

```shell
git clone https://github.com/rohitsinghlab/plm_reverse_distillation.git
cd plm_reverse_distillation
uv lock && uv sync
uv pip install -e '.[dev]'
source .venv/bin/activate
```
## Quick Start

See `inference_tutorial.ipynb` for a step-by-step walkthrough of loading pretrained models and extracting embeddings.

Pretrained scalers for all ESM-2 model pairs (8M → 35M → 150M → 650M → 3B → 15B) are available on HuggingFace and loaded automatically via the model registry:

`singhlab/plm_reverse_distillation`
## Available Models

All models use principal component regression (PCR) with PCA for dimensionality reduction. Each model applies the full chain of scalers from ESM-2 8M up to the target size.

| Model name | Chain | Output dim |
|---|---|---|
| `esm2.rd/35M` | 8M → 35M | 480 |
| `esm2.rd/150M` | 8M → 35M → 150M | 640 |
| `esm2.rd/650M` | 8M → 35M → 150M → 650M | 1280 |
| `esm2.rd/3B` | 8M → 35M → 150M → 650M → 3B | 2560 |
| `esm2.rd/15B` | 8M → 35M → 150M → 650M → 3B → 15B | 5120 |
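Principal component regression itself is simple to sketch: project the inputs onto their top principal components, then fit ordinary least squares in that reduced space. The numpy example below is a minimal illustration of the idea with toy random data and shapes; it is not this package's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data standing in for paired embeddings: X is the source
# representation, Y the targets a scaler would be fit to predict.
X = rng.normal(size=(200, 32))
Y = rng.normal(size=(200, 8))

# PCR step 1: center X and project onto its top-k principal components.
k = 16
Xc = X - X.mean(axis=0)
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
scores = Xc @ Vt[:k].T                  # (n_samples, k) component scores

# PCR step 2: ordinary least squares in the reduced space.
W, *_ = np.linalg.lstsq(scores, Y - Y.mean(axis=0), rcond=None)
Y_hat = scores @ W + Y.mean(axis=0)

print(scores.shape, Y_hat.shape)        # (200, 16) (200, 8)
```

Restricting the regression to the top components regularizes it, which matters when the number of embedding dimensions is large relative to the number of sequences.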
## Scripts

### Embedding extraction

Extract embeddings from a FASTA file using a pretrained RD model:

```shell
python scripts/extract.py \
    --fasta_file proteins.fasta \
    --output_dir embeddings/ \
    --repr_type mean \
    --batch_size 32
```

Key arguments: `--repr_type` (`per_tok` / `mean` / `bos`), `--repr_layers`, `--batch_size`, `--truncation_seq_length`.
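The three `--repr_type` options correspond to different poolings of a sequence's per-token embedding matrix. The sketch below uses a toy random array; the layout (a BOS token at index 0 followed by one row per residue) mirrors the usual ESM convention but is an assumption here, not taken from this package's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy per-token embeddings for one sequence: (BOS + L residues) x dim.
# Assumed layout: BOS token at index 0, residues after it (ESM-style).
L, dim = 10, 480
tokens = rng.normal(size=(1 + L, dim))

per_tok = tokens[1:1 + L]        # per_tok: one vector per residue
mean = per_tok.mean(axis=0)      # mean: average over residue positions
bos = tokens[0]                  # bos: the BOS token's embedding alone

print(per_tok.shape, mean.shape, bos.shape)  # (10, 480) (480,) (480,)
```

`per_tok` is the choice for residue-level tasks; `mean` and `bos` give fixed-size sequence-level representations.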
### Training scalers

Train new scalers on your own data:

```shell
python scripts/train.py \
    --dataset_path proteins.fasta \
    --scalar_path scalers/ \
    --regressor_type pcr \
    --scaler_type rd \
    --n_pretrained_seqs 5000
```

Key arguments: `--regressor_type` (`linear` / `ridge` / `pcr`), `--scaler_type` (`rd` / `naive`), `--pca_type` (`incremental` / `fbpca`), `--n_pretrained_seqs`.
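An incremental PCA fits the decomposition batch by batch, so the full embedding matrix never needs to be held in memory. The scikit-learn sketch below illustrates the pattern with toy data and dimensions; `--pca_type incremental` plausibly works along these lines, but this is an illustration, not this package's code.

```python
import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.default_rng(0)

# Fit PCA in streaming fashion over batches of 32-dim toy "embeddings".
ipca = IncrementalPCA(n_components=8)
for _ in range(5):
    batch = rng.normal(size=(64, 32))  # one batch loaded at a time
    ipca.partial_fit(batch)            # update the fit incrementally

# Once fitted, transform new data as with an ordinary PCA.
reduced = ipca.transform(rng.normal(size=(10, 32)))
print(reduced.shape)  # (10, 8)
```

The alternative `fbpca` option refers to Facebook's randomized PCA library, which trades exactness for speed on large matrices.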
## Citation

If you use reverse distillation, please cite:

```bibtex
@inproceedings{catrina2026reverse,
  title     = {Reverse Distillation: Consistently Scaling Protein Language Model Representations},
  author    = {Catrina, Darius and Bepler, Christian and Sledzieski, Samuel and Singh, Rohit},
  booktitle = {International Conference on Learning Representations},
  year      = {2026}
}
```
## License

This project is licensed under the MIT License; see `LICENSE` for details.