PRISM - Partitioning Residue Identity in Somatic Maturation for antibody language modeling

This is the official repository for the paper:
Explicit representation of germline and non-germline residues improves antibody language modeling
PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.
Part 1: User Guide
Everything you need to run inference or finetune PRISM on your own antibody data.
Installation
pip install prism-antibody
Or install from source:
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .
Verify Installation
import prism
print(prism.__version__) # 0.3.0
Quick Start: Inference
import prism
# Load from Hugging Face Hub (auto-downloads and caches)
model = prism.pretrained("RomeroLab-Duke/prism-antibody")
# Or load from a local checkpoint
# model = prism.pretrained("path/to/checkpoint.ckpt")
# Extract germline log-probabilities — [L, 20] numpy array
gl = model.extract_GL_logit("EVQLVESGGGLVQPGGSLRL")
# Extract non-germline log-probabilities — [L, 20] numpy array
ngl = model.extract_NGL_logit("EVQLVESGGGLVQPGGSLRL")
# Marginalized log-probs: logsumexp(GL, NGL) — [L, 20]
marg = model.extract_marginalized_logit("EVQLVESGGGLVQPGGSLRL")
# Full 53-vocab log-probs — [L, 53]
full = model.extract_full_logit("EVQLVESGGGLVQPGGSLRL")
# Alpha gating values (GL/NGL mixture weights) — [L]
alpha = model.extract_alpha("EVQLVESGGGLVQPGGSLRL")
# Mean-pooled embedding — [H]
emb = model.extract_embedding("EVQLVESGGGLVQPGGSLRL")
# Perplexity — scalar
ppl = model.perplexity("EVQLVESGGGLVQPGGSLRL")
All methods accept a single string or a list of strings:
# Batch inference
sequences = ["EVQLVESGGGLVQ", "DIQMTQSPSSLSA", "QVQLVQSGAEVKK"]
embeddings = model.extract_embedding(sequences) # [3, H] numpy array
ppls = model.perplexity(sequences) # [3] numpy array
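The table in the README documents `perplexity()` as a pseudo-perplexity. As a point of reference, here is a minimal sketch of the conventional pseudo-perplexity formula (exponentiated negative mean of per-position masked log-likelihoods); PRISM computes this internally, so this is illustrative only, not the package's actual implementation:

```python
import numpy as np

def pseudo_perplexity(per_position_logprobs):
    """Conventional pseudo-perplexity: exp of the negative mean of the
    per-position log-likelihoods obtained by masking one residue at a time.
    (Sketch for intuition -- PRISM's perplexity() does this for you.)"""
    lp = np.asarray(per_position_logprobs, dtype=float)
    return float(np.exp(-lp.mean()))

# A sequence whose residues all score log-prob -1.0 has pseudo-perplexity e
print(pseudo_perplexity([-1.0, -1.0, -1.0]))  # ~2.718
```

Lower values mean the model finds the sequence more "antibody-like" under its learned distribution.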
Available Methods
| Method | Returns | Description |
|---|---|---|
| `extract_GL_logit(seq)` | `[L, 20]` | Germline (uppercase) log-probabilities |
| `extract_NGL_logit(seq)` | `[L, 20]` | Non-germline (lowercase) log-probabilities |
| `extract_marginalized_logit(seq)` | `[L, 20]` | `logsumexp(GL, NGL)` per amino acid |
| `extract_full_logit(seq)` | `[L, 53]` | Full-vocabulary log-probabilities |
| `extract_alpha(seq)` | `[L]` | Per-position alpha gating values |
| `extract_embedding(seq)` | `[H]` | Mean-pooled backbone embedding |
| `embed(seq, mode=...)` | varies | Embeddings (`"mean"`, `"per_residue"`, `"cls"`) |
| `perplexity(seq)` | scalar | Pseudo-perplexity |
| `pseudo_log_likelihood(seq)` | scalar | Pseudo-log-likelihood score |
| `predict_origin(seq)` | dict | GL/NGL origin predictions per residue |
| `score_mutations(wt, mut)` | scalar | Log-likelihood ratio at mutated positions |
| `logits(seq, head=...)` | tensor | Raw logits from any head |
Column order for 20-AA outputs follows PrismModel.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY".
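A quick sketch of how the 20-column outputs fit together, using fabricated arrays in place of real model output (only `AA_ORDER` and the logsumexp relation come from the documentation above; the Dirichlet draws and the 0.7/0.3 origin split are illustrative):

```python
import numpy as np

AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # documented column order for 20-AA outputs

# Stand-ins for model.extract_GL_logit(...) / model.extract_NGL_logit(...):
# [L, 20] log-probabilities whose GL/NGL masses sum to 1 per position.
rng = np.random.default_rng(0)
gl = np.log(rng.dirichlet(np.ones(20), size=5)) + np.log(0.7)   # GL mass 0.7
ngl = np.log(rng.dirichlet(np.ones(20), size=5)) + np.log(0.3)  # NGL mass 0.3

# Marginalize over origin: element-wise logsumexp of the two heads,
# matching what extract_marginalized_logit returns.
marg = np.logaddexp(gl, ngl)

# Most likely amino acid per position under the marginal distribution
top_aa = [AA_ORDER[i] for i in marg.argmax(axis=1)]
print(top_aa)
```

Because the GL and NGL masses sum to one here, `np.exp(marg)` rows sum to 1.0, which is the sanity check to run on real model output too.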
Finetuning on Your Data
Finetune PRISM on your own antibody sequences with a single method call:
import prism
model = prism.pretrained("RomeroLab-Duke/prism-antibody")
best_checkpoint = model.finetune(
data_path="my_antibodies.parquet", # parquet, pickle, or csv
output_dir="outputs/my_finetune",
max_steps=5000,
learning_rate=1e-4,
batch_size=32,
)
# Model is now finetuned — use immediately
gl = model.extract_GL_logit("EVQLVESGGGLVQ...")
Data Format
Your data file needs at least one of these columns:
| Column | Required | Description |
|---|---|---|
| `HEAVY_CHAIN_AA_SEQUENCE` | At least one | Heavy chain amino acid sequence |
| `LIGHT_CHAIN_AA_SEQUENCE` | At least one | Light chain amino acid sequence |
Optional columns (auto-detected):
| Column | Description |
|---|---|
| `split` | `"train"` / `"valid"` / `"test"` (auto-generated 90/5/5 if absent) |
| `hc_mut_codes`, `lc_mut_codes` | NGL mutation codes (e.g. `"S30A;T52N"`) |
| `v_gene_heavy`, `j_gene_heavy`, `v_gene_light`, `j_gene_light` | Gene labels |
| `region_mask_heavy`, `region_mask_light` | Region annotations |
Accepted file formats: .parquet, .pkl/.pickle, .csv.
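Putting the columns above together, a minimal finetuning file can be built with pandas. This sketch writes CSV for portability (`to_parquet` works the same if pyarrow is installed); the sequences and mutation codes are illustrative placeholders, not real antibodies:

```python
import pandas as pd

# Minimal finetuning table using the documented columns. Only one of the two
# *_AA_SEQUENCE columns is strictly required; the rest are optional.
df = pd.DataFrame({
    "HEAVY_CHAIN_AA_SEQUENCE": ["EVQLVESGGGLVQPGGSLRL", "QVQLVQSGAEVKKPGASVKV"],
    "LIGHT_CHAIN_AA_SEQUENCE": ["DIQMTQSPSSLSASVGDRVT", "EIVLTQSPGTLSLSPGERAT"],
    "hc_mut_codes": ["S30A;T52N", ""],   # optional NGL mutation codes
    "split": ["train", "valid"],         # optional; 90/5/5 auto-split if absent
})
df.to_csv("my_antibodies.csv", index=False)
```

The resulting file can be passed directly as `data_path` to `model.finetune(...)`.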
Finetune Parameters
model.finetune(
data_path="data.parquet",
output_dir="prism_finetune_output",
# Training
max_steps=5000, # total training steps
learning_rate=1e-4, # peak LR (cosine schedule with warmup)
batch_size=32, # per-device batch size
warmup_steps=500, # linear warmup steps
weight_decay=0.01, # AdamW weight decay
mask_prob=0.15, # MLM masking probability
gradient_accumulation_steps=1,
# Trainer
devices=1, # number of GPUs
precision="bf16-mixed", # training precision
val_check_interval=500, # validate every N steps
# Data
num_workers=4,
seed=42,
)
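For intuition about how `learning_rate` and `warmup_steps` interact: "cosine schedule with warmup" is commonly read as linear warmup to the peak followed by cosine decay. The sketch below shows that common reading with the defaults above; PRISM's exact schedule may differ (e.g. a nonzero floor), so treat it as an approximation:

```python
import math

def lr_at_step(step, peak_lr=1e-4, warmup_steps=500, max_steps=5000):
    """Linear warmup to peak_lr, then cosine decay to zero (common
    interpretation of a cosine schedule with warmup; sketch only)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_at_step(250))   # mid-warmup: half the peak, 5e-05
print(lr_at_step(500))   # end of warmup: peak, 1e-04
print(lr_at_step(5000))  # end of schedule: 0.0
```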
Part 2: Developer Guide
For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.
Development Installation
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
# Install with dev + analysis extras
pip install -e ".[dev,analysis]"
Project Structure
prism/
├── src/prism/ # Core Python package
│ ├── __init__.py # Package exports
│ ├── api.py # High-level inference & finetune API
│ ├── model.py # SFT_ESM2 PyTorch Lightning module
│ ├── io_utils.py # Dataset & DataModule classes
│ ├── multimodal_io.py # Gene vocabulary & antibody dataset
│ └── utils.py # Utility functions
│
├── configs/ # Training configuration files
│ ├── v34_pretrain.yaml # Stage 1: Pretraining config
│ └── v34_1b_finetune.yaml # Stage 2: Finetuning config
│
├── script/ # Executable scripts
│ ├── train_esm.py # Main training script
│ ├── inference_esm.py # Basic inference
│ ├── inference_esm_with_logprobs.py
│ ├── upload_to_hub.py # Upload checkpoint to HF Hub
│ │
│ ├── data/ # Data processing pipeline (1-8)
│ │ ├── 1.processing_and_filtering_*.py
│ │ ├── 2.visualize_unpaired_statistics.py
│ │ ├── 3.filter_by_p90_and_save.py
│ │ ├── 4-6.cluster_*.py
│ │ ├── 7.extract_gene_information.py
│ │ └── 8.extract_data_for_probe.py
│ │
│ └── analyze/ # Analysis & evaluation scripts
│ ├── 1.pppl_calculation/ # Perplexity calculations
│ ├── 2.gl-ngl_calculation/ # GL/NGL embeddings & linear probes
│ ├── 3.zero-shot/ # Zero-shot prediction benchmarks
│ └── 4.thera-sabdab/ # Controllable generation experiments
│
├── tests/ # Test suite
│ └── test_api.py # API tests (59 tests)
│
├── pyproject.toml # Package configuration
├── requirements.txt # Explicit dependencies
└── README.md
Core Modules
model.py — SFT_ESM2
The central PyTorch Lightning module. Key components:
- Base: ESM2 transformer (HuggingFace) with optional SwiGLU activation
- Multi-head architecture: AA head + Origin head + Alpha gating + Final head
- Gene conditioning: V/J gene embeddings + region embeddings
- Loss: Focal loss with region-balanced and CDR-boosted variants
io_utils.py — Data Loading
- `SeqSeqDataset`: Handles paired/unpaired antibody sequences with germline reconstruction
- `SFTDataModule`: Standard PyTorch Lightning DataModule
- `LazyShardedDataModule`: Memory-efficient sharded loading for large datasets
- `make_collate_fn_multihead`: Collate with 80/10/10 masking, gene encoding, region IDs
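The 80/10/10 masking mentioned for the collate function is the standard BERT-style MLM corruption. A self-contained sketch of that rule (not PRISM's actual collate code, which also handles gene encoding and region IDs):

```python
import random

def mask_tokens(seq, mask_prob=0.15, mask_id="<mask>",
                vocab="ACDEFGHIKLMNPQRSTVWY", seed=0):
    """BERT-style corruption: positions are selected with prob mask_prob;
    of those, 80% become <mask>, 10% a random residue, 10% stay unchanged.
    Labels are set only at selected positions (None elsewhere)."""
    rng = random.Random(seed)
    corrupted, labels = list(seq), [None] * len(seq)
    for i, tok in enumerate(seq):
        if rng.random() < mask_prob:
            labels[i] = tok                       # only selected positions are scored
            r = rng.random()
            if r < 0.8:
                corrupted[i] = mask_id            # 80%: replace with mask token
            elif r < 0.9:
                corrupted[i] = rng.choice(vocab)  # 10%: random amino acid
            # else: 10% keep the original token
    return corrupted, labels

corrupted, labels = mask_tokens("EVQLVESGGGLVQPGGSLRL")
```

Keeping 10% of selected tokens unchanged forces the model to make predictions even at apparently intact positions, which is why unselected positions carry no loss.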
multimodal_io.py — Gene & Region Handling
- `GeneVocabulary`: Maps V/J gene strings to integer IDs
- `AntibodyDataset` / `AntibodyDataCollator` / `AntibodyMLMCollator`: Multi-modal data handling
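A minimal sketch of the gene-string-to-ID mapping idea behind `GeneVocabulary` (the class name `SimpleGeneVocab` and its methods here are hypothetical; the real API may differ). Reserving an `<unk>` ID keeps inference from failing on gene calls unseen during training:

```python
class SimpleGeneVocab:
    """Toy gene-string -> integer-ID vocabulary in the spirit of
    GeneVocabulary (illustrative sketch, not the package's class)."""
    def __init__(self, genes):
        self.unk_id = 0  # reserved ID for unseen genes
        self.stoi = {g: i + 1 for i, g in enumerate(sorted(set(genes)))}
    def encode(self, gene):
        return self.stoi.get(gene, self.unk_id)

vocab = SimpleGeneVocab(["IGHV3-23", "IGHV1-69", "IGHJ4"])
print(vocab.encode("IGHV3-23"))  # 3 (sorted order: IGHJ4, IGHV1-69, IGHV3-23)
print(vocab.encode("IGHV9-99"))  # 0 -> <unk>, gene not in the vocabulary
```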
Training from Scratch
Two-Stage Training Protocol
Stage 1 — Pretraining on a large unpaired OAS dataset (~60M sequences):
python script/train_esm.py --config configs/v34_pretrain.yaml
Stage 2 — Finetuning on paired antibody sequences (~764K sequences):
# Set pretrained_checkpoint_path in the config to point to Stage 1 best checkpoint
python script/train_esm.py --config configs/v34_1b_finetune.yaml
Multi-GPU Training
CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml
Key Config Options
data:
data_path: "path/to/data"
batch_size: 256
mask_prob: 0.15
model:
model_identifier: "esm2_t12_35M_UR50D"
use_multihead_architecture: true
use_alpha_gating: true
ngl_loss_alpha: 3.0
training:
max_steps: 100000
peak_learning_rate: 4e-4
warmup_steps: 2000
gradient_accumulation_steps: 8
trainer:
devices: 4
precision: "bf16-mixed"
Analysis Pipeline
Located in script/analyze/, numbered by experiment stage:
| Stage | Directory | Description |
|---|---|---|
| 1 | `pppl_calculation/` | Pseudo-perplexity comparison across models |
| 2 | `gl-ngl_calculation/` | Embedding extraction, linear probes, UMAP |
| 3 | `zero-shot/` | Binding affinity and developability benchmarks |
| 4 | `thera-sabdab/` | Controllable generation and mutation recovery |
Data Processing Pipeline
Located in script/data/, processes OAS (Observed Antibody Space) data:
| Step | Purpose |
|---|---|
| 1 | Filter OAS data, identify NGL positions |
| 2 | Visualize data distribution |
| 3 | Quality filtering by sequence coverage |
| 4-6 | MMseqs2 sequence clustering & deduplication |
| 7 | Parse V/J gene annotations |
| 8 | Prepare data for linear probe training |
Testing
# Run all tests
pytest tests/ -v
# Run specific test class
pytest tests/test_api.py::TestFinetune -v
Linting
black --line-length 100 src/
isort --profile black --line-length 100 src/
flake8 src/
mypy src/
License
MIT License — see LICENSE for details.