PRISM - Partitioning Residue Identity in Somatic Maturation for antibody language modeling
This is the official repository for the paper:
Explicit representation of germline and non-germline residues improves antibody language modeling
PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.
Part 1: User Guide
Everything you need to run inference or finetune PRISM on your own antibody data.
Installation
pip install prism-antibody
Or install from source:
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .
Verify Installation
import prism
print(prism.__version__)
API Overview
PRISM has 3 core methods that cover all use cases:
| Method | Cost | Returns |
|---|---|---|
| `forward()` | 1 forward pass | logits, embeddings, origin, alpha |
| `pseudo_log_likelihood()` | L forward passes | PLL, perplexity, per-position log-probs (4 modes) |
| `score_mutations()` | 2M forward passes | masked marginal mutation scores (4 modes) |
All methods accept single strings or lists (batch). All support paired (VH+VL) and unpaired input.
forward() --- Logits, Embeddings, Everything
Single forward pass through the model. Returns all outputs as a dict of numpy arrays.
import prism
model = prism.pretrained("RomeroLab-Duke/prism-antibody")
# Paired heavy + light chain (typical usage)
result = model.forward(
heavy_chains="EVQLVESGGGLVQPGGSLRL",
light_chains="DIQMTQSPSSLSASVG",
)
# result["heavy"]["final_logits"] -> [L_vh, 53] alpha-gated combined logits
# result["heavy"]["aa_logits"] -> [L_vh, 33] AA head logits (pre-gating)
# result["heavy"]["origin_logits"] -> [L_vh] GL/NGL classification logits
# result["heavy"]["alpha"] -> [L_vh] gating values
# result["heavy"]["embedding"] -> [L_vh, H] per-residue hidden states
# result["light"] has the same keys for the light chain
Derive Any Signal You Need
import numpy as np
# GL/NGL log-probabilities (slice from 53-vocab)
gl_logits = result["heavy"]["final_logits"][:, model.GL_INDICES] # [L_vh, 20]
ngl_logits = result["heavy"]["final_logits"][:, model.NGL_INDICES] # [L_vh, 20]
# Mean-pooled embedding
vh_emb = result["heavy"]["embedding"].mean(axis=0) # [H]
vl_emb = result["light"]["embedding"].mean(axis=0) # [H]
# GL/NGL origin probability per residue
origin_prob = 1 / (1 + np.exp(-result["heavy"]["origin_logits"])) # sigmoid
# origin_prob > 0.5 = predicted NGL (somatic mutation)
Masked Prediction
Mask specific positions to get context-only predictions:
result = model.forward(
heavy_chains="EVQLVESGGGLVQPGGSLRL",
light_chains="DIQMTQSPSSLSASVG",
mask_positions=[5, 10, 15], # mask these heavy chain positions
)
Batch Inference
results = model.forward(
heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# List of 2 dicts, each with "heavy" and "light" sub-dicts
Unpaired Input
For heavy-only, light-only, or any single-chain input:
result = model.forward("EVQLVESGGGLVQPGGSLRL") # positional arg
result = model.forward(heavy_chains="EVQLVESGGGLVQPGGSLRL") # explicit heavy
result = model.forward(light_chains="DIQMTQSPSSLSASVG") # explicit light
# Returns flat dict (no "heavy"/"light" nesting)
pseudo_log_likelihood() --- PLL and Perplexity
Masked marginal scoring: each position is masked in turn, the model predicts it, and log P(true token) is accumulated. All 4 scoring modes are returned in one pass, with a progress bar for batch input.
# Paired (typical)
result = model.pseudo_log_likelihood(
heavy_chains="EVQLVESGGGLVQPGGSLRL",
light_chains="DIQMTQSPSSLSASVG",
)
# {
# "marginalized": {"pll": -45.3, "perplexity": 2.34, "per_position": [L_vh + L_vl]},
# "gl": {"pll": -50.1, "perplexity": 2.71, "per_position": [L_vh + L_vl]},
# "ngl": {"pll": -48.2, "perplexity": 2.56, "per_position": [L_vh + L_vl]},
# "exact": {"pll": -50.1, "perplexity": 2.71, "per_position": [L_vh + L_vl]},
# }
# Quick access
ppl = result["marginalized"]["perplexity"]
per_pos = result["gl"]["per_position"] # per-residue GL log-probs
# Batch paired (with progress bar)
results = model.pseudo_log_likelihood(
heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
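The masked-marginal loop behind these calls can be sketched in a few lines of plain Python. Here `log_prob_fn` is a hypothetical stand-in for one masked forward pass (it is not part of the PRISM API):

```python
import math

def masked_marginal_pll(seq, log_prob_fn):
    """Sketch of masked-marginal scoring: mask each position in turn,
    accumulate log P(true token | context), and derive perplexity.
    `log_prob_fn(seq, pos)` is a hypothetical callable returning the
    model's log-probability of the true residue at `pos` when masked."""
    per_position = [log_prob_fn(seq, i) for i in range(len(seq))]
    pll = sum(per_position)
    perplexity = math.exp(-pll / len(seq))
    return pll, perplexity, per_position

# Sanity check with a toy uniform 20-letter "model":
# every position scores log(1/20), so perplexity comes out to 20.
uniform = lambda seq, pos: math.log(1 / 20)
pll, ppl, _ = masked_marginal_pll("EVQLVESGG", uniform)
```

The real method does one forward pass per masked position (L passes total), which is where the cost in the API table comes from.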
Scoring Modes
| Mode | What it scores | Use case |
|---|---|---|
| `marginalized` | logsumexp(GL, NGL) | General-purpose |
| `gl` | Uppercase (germline) token log-prob | Germline likeness |
| `ngl` | Lowercase (non-germline) token log-prob | Somatic mutation preference |
| `exact` | Raw token log-prob in 53-vocab | Direct 53-vocab scoring |
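The `marginalized` mode combines the germline and non-germline log-probabilities of the same residue. A minimal sketch of that combination for a single residue (plain Python, standing in for the slicing over `model.GL_INDICES` / `model.NGL_INDICES`):

```python
import math

def marginalized_log_prob(gl_log_prob, ngl_log_prob):
    """logsumexp over the GL and NGL log-probabilities of one residue:
    log(P_gl + P_ngl), computed stably in log space."""
    m = max(gl_log_prob, ngl_log_prob)
    return m + math.log(math.exp(gl_log_prob - m) + math.exp(ngl_log_prob - m))

# If P_gl = 0.3 and P_ngl = 0.1, the marginalized probability is 0.4.
lp = marginalized_log_prob(math.log(0.3), math.log(0.1))
```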
Unpaired
result = model.pseudo_log_likelihood("EVQLVESGGGLVQPGGSLRL")
score_mutations() --- Mutation Effect Prediction
Masked marginal scoring at mutation positions. For each mutation, the position is masked in both the WT and mutant sequences and the log-likelihood difference is computed. Returns all 4 modes.
# Paired (typical)
result = model.score_mutations(
wt="EVQLVESGGGLVQPGGSLRL",
mutant="EVQLVASGGGLVQPGGSLRL", # E6A
wt_light_chains="DIQMTQSPSSLSA",
mut_light_chains="DIQMTQSPSSLSA", # same light chain, or different
)
# {
# "positions": [5], # 0-indexed mutation positions
# "marginalized": {"score": 0.42, "per_position": [1]},
# "gl": {"score": 0.31, "per_position": [1]},
# "ngl": {"score": 0.55, "per_position": [1]},
# "exact": {"score": 0.31, "per_position": [1]},
# }
# score > 0 = mutant preferred over WT at that position
# Batch (multiple WT/mutant pairs)
results = model.score_mutations(
wt=["EVQLVESGG", "DIQMTQSPS"],
mutant=["EVQLVASGG", "DIQMAQSPS"],
wt_light_chains=["DIQMT", "EVQLV"],
mut_light_chains=["DIQMT", "EVQLV"],
)
Unpaired
result = model.score_mutations(wt="ACDEF", mutant="GCKEF")
# result["positions"] == [0, 2] (A->G, D->K)
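The mutated positions are simply where the two equal-length sequences differ; a minimal stand-in (not the library's actual implementation) would be:

```python
def mutation_positions(wt, mutant):
    """Return 0-indexed positions where wt and mutant differ.
    Assumes the two sequences are already aligned and equal length."""
    if len(wt) != len(mutant):
        raise ValueError("WT and mutant must be the same length")
    return [i for i, (a, b) in enumerate(zip(wt, mutant)) if a != b]

positions = mutation_positions("ACDEF", "GCKEF")  # -> [0, 2]
```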
Reference
forward() Return Dict
| Key | Shape | Description |
|---|---|---|
| `final_logits` | [L, 53] | Alpha-gated combined logits (53-vocab) |
| `aa_logits` | [L, 33] | AA head logits, before gating |
| `origin_logits` | [L] | GL/NGL binary classification logits |
| `alpha` | [L] | Per-position gating values |
| `embedding` | [L, H] | Per-residue hidden states from backbone |
When paired, returns {"heavy": {dict}, "light": {dict}} with the above keys in each sub-dict.
Index Constants
- `model.GL_INDICES` --- 20 uppercase AA token indices in the 53-vocab
- `model.NGL_INDICES` --- 20 lowercase AA token indices in the 53-vocab
- `model.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"` --- column order for the 20 AA indices
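With these constants, a sliced GL (or NGL) logit row can be decoded back to an amino acid by taking the argmax over its 20 columns. A toy sketch, using a plain list in place of a real logits row (the `AA_ORDER` literal below mirrors the constant above):

```python
AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # column order of the 20 GL/NGL indices

def decode_residue(logit_row):
    """Map the argmax of a 20-column GL (or NGL) logit row to its amino acid."""
    best = max(range(len(logit_row)), key=lambda i: logit_row[i])
    return AA_ORDER[best]

# Toy row with the largest logit in column 4 -> "F"
row = [0.0] * 20
row[4] = 3.2
```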
Input Modes
All 3 methods accept the same input patterns:
# Paired (returns "heavy"/"light" split for forward, combined for PLL/mutations)
model.forward(heavy_chains="VH...", light_chains="VL...")
# Unpaired (positional or keyword)
model.forward("SEQ...")
model.forward(heavy_chains="VH...")
model.forward(light_chains="VL...")
# Batch (list of strings)
model.forward(heavy_chains=["VH1", "VH2"], light_chains=["VL1", "VL2"])
Finetuning on Your Data
import prism
model = prism.pretrained("RomeroLab-Duke/prism-antibody")
best_checkpoint = model.finetune(
data_path="my_antibodies.parquet", # parquet, pkl, or csv
output_dir="outputs/my_finetune",
max_steps=5000,
learning_rate=1e-4,
batch_size=32,
)
# Model is now finetuned --- use immediately
result = model.forward(heavy_chains="EVQLVESGGGLVQ...", light_chains="DIQMT...")
Data Format
Your data file needs at least one of these columns:
| Column | Required | Description |
|---|---|---|
| `HEAVY_CHAIN_AA_SEQUENCE` | At least one | Heavy chain amino acid sequence |
| `LIGHT_CHAIN_AA_SEQUENCE` | At least one | Light chain amino acid sequence |

Optional columns (auto-detected):

| Column | Description |
|---|---|
| `split` | "train" / "valid" / "test" (auto-generated 90/5/5 if absent) |
| `hc_mut_codes`, `lc_mut_codes` | NGL mutation codes (e.g. "S30A;T52N") |
| `v_gene_heavy`, `j_gene_heavy`, `v_gene_light`, `j_gene_light` | Gene labels |
| `region_mask_heavy`, `region_mask_light` | Region annotations |
Finetune Parameters
model.finetune(
data_path="data.parquet",
output_dir="prism_finetune_output",
max_steps=5000, # total training steps
learning_rate=1e-4, # peak LR (cosine schedule with warmup)
batch_size=32, # per-device batch size
warmup_steps=500, # linear warmup steps
weight_decay=0.01, # AdamW weight decay
mask_prob=0.15, # MLM masking probability
gradient_accumulation_steps=1,
devices=1, # number of GPUs
precision="bf16-mixed", # training precision
val_check_interval=500, # validate every N steps
num_workers=4,
seed=42,
)
Part 2: Developer Guide
For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.
Development Installation
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e ".[dev,analysis]"
Project Structure
prism/
├── src/prism/ # Core Python package
│ ├── __init__.py # Package exports
│ ├── api.py # High-level inference & finetune API
│ ├── model.py # SFT_ESM2 PyTorch Lightning module
│ ├── io_utils.py # Dataset & DataModule classes
│ ├── multimodal_io.py # Gene vocabulary & antibody dataset
│ └── utils.py # Utility functions
│
├── configs/ # Training configuration files
├── script/ # Training, inference, analysis scripts
├── tests/ # Test suite
├── pyproject.toml # Package configuration
└── README.md
Training from Scratch
Two-Stage Training Protocol
Stage 1 --- Pretraining on a large unpaired OAS dataset (~60M sequences):
python script/train_esm.py --config configs/v34_pretrain.yaml
Stage 2 --- Finetuning on paired antibody sequences (~764K):
python script/train_esm.py --config configs/v34_1b_finetune.yaml
Multi-GPU Training
CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml
Testing
pytest tests/ -v
License
MIT License --- see LICENSE for details.