PRISM
Partitioning Residue Identity in Somatic Maturation
This is the official repository for the paper:
Explicit representation of germline and non-germline residues improves antibody language modeling
PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.
Part 1: User Guide
Everything you need to run inference with PRISM on your own antibody data.
Installation
pip install prism-antibody
Or install from source:
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .
Verify Installation
import prism
print(prism.__version__)
Quick Start
import prism
model = prism.pretrained("RomeroLab-Duke/prism-antibody")
tokenizer = model.get_tokenizer()
# Tokenize → model (standard HuggingFace-style pipeline)
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
# Forward pass — logits, embeddings, origin, alpha
result = model.forward(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
# Predict germline — revert somatic mutations
germline = model.predict_germline(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
# germline["heavy"]["predicted_germline"] → germline-reverted heavy chain
# germline["heavy"]["ngl_positions"] → which positions were mutated
Tokenizer
PrismTokenizer wraps the ESM2 tokenizer with PRISM's 53-token vocabulary (33 ESM2 base + 20 lowercase NGL tokens).
tokenizer = prism.PrismTokenizer() # standalone (no model needed)
tokenizer = model.get_tokenizer() # or from a loaded model
# Paired heavy + light chain
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
# inputs["input_ids"] -> [1, L_H+L_L+4] (CLS + VH + CLS + CLS + VL + EOS)
# inputs["attention_mask"] -> [1, L_H+L_L+4]
# Batch
inputs = tokenizer(
["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
light_chain=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
return_tensors="pt",
)
# Unpaired (single chain)
inputs = tokenizer("EVQLVESGGGLVQ", return_tensors="pt")
# Encode / decode (paired)
ids = tokenizer.encode_paired("EVQLV", "DIQMT")
heavy, light = tokenizer.decode_paired(ids) # ("EVQLV", "DIQMT")
# Encode / decode (unpaired)
ids = tokenizer.encode("EVQLV") # [CLS, E, V, Q, L, V, EOS]
seq = tokenizer.decode(ids) # "EVQLV"
NGL-Aware Tokenization
By default, all amino acids are tokenized as uppercase (GL) tokens — this is the standard mode and matches the training format. Use preserve_case=True only when you need exact mode in pseudo_log_likelihood(), which scores each position using its actual GL or NGL log-probability.
# For exact PLL: lowercase = NGL (somatic mutation) positions
inputs = tokenizer("EvQLvESGGglvq", preserve_case=True, return_tensors="pt")
# 'v', 'g', 'l', 'v' → NGL token IDs; 'E', 'Q', 'L', ... → GL token IDs
result = model.pseudo_log_likelihood(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# result["exact"] now uses NGL log-prob at lowercase positions
GL/NGL Token Mappings
tokenizer.gl_token_ids # {"A": 5, "C": 23, ...} — 20 uppercase (germline)
tokenizer.ngl_token_ids # {"a": 33, "c": 34, ...} — 20 lowercase (non-germline)
tokenizer.gl_to_ngl # {5: 33, 23: 34, ...} — GL→NGL token ID mapping
tokenizer.vocab_size # 53
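If you already know which positions are somatically mutated (for example, from an external germline alignment), you can mark them for exact-mode scoring by lowercasing those residues before tokenizing with preserve_case=True. A minimal sketch; mark_ngl is a hypothetical helper, not part of the library:
def mark_ngl(seq: str, ngl_positions: list[int]) -> str:
    # Hypothetical helper: lowercase the given 0-indexed positions so that
    # preserve_case=True tokenization assigns them NGL token IDs.
    chars = list(seq.upper())
    for i in ngl_positions:
        chars[i] = chars[i].lower()
    return "".join(chars)

marked = mark_ngl("EVQLVESGGGLVQ", [1, 4])   # -> "EvQLvESGGGLVQ"
inputs = tokenizer(marked, preserve_case=True, return_tensors="pt")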
API Overview
PRISM has 5 core methods. All accept pre-tokenized input_ids (recommended) or raw strings.
| Method | Cost | Returns |
|---|---|---|
| forward() | 1 forward pass | logits, embeddings, origin, alpha |
| pseudo_log_likelihood() | ceil(L / batch_size) forward passes | PLL, perplexity, per-position log-probs (4 modes) |
| score_mutations() | 2 × ceil(M / batch_size) forward passes | masked marginal mutation scores (4 modes) |
| predict_germline() | 1 forward pass | predicted germline sequences, NGL positions/probs |
| generate() | L + N forward passes | PLL-guided antibody variants |
forward() --- Logits, Embeddings, Everything
Single forward pass through the model. Returns all outputs as numpy arrays.
import prism
import numpy as np
model = prism.pretrained("RomeroLab-Duke/prism-antibody")
tokenizer = model.get_tokenizer()
# Standard: tokenize → forward (paired)
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
result = model.forward(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
# result["final_logits"] -> [L, 53] alpha-gated combined logits
# result["aa_logits"] -> [L, 33] AA head logits (pre-gating)
# result["origin_logits"] -> [L] GL/NGL classification logits
# result["alpha"] -> [L] gating values
# result["embedding"] -> [L, H] per-residue hidden states
# GL/NGL log-probabilities (slice from 53-vocab)
gl_logits = result["final_logits"][:, model.GL_INDICES] # [L, 20]
ngl_logits = result["final_logits"][:, model.NGL_INDICES] # [L, 20]
Batch, Unpaired (string convenience)
# Batch (returns list of {"heavy": {...}, "light": {...}})
results = model.forward(
heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# String convenience (paired)
result = model.forward(heavy_chains="EVQLVESGGGLVQ", light_chains="DIQMTQSPSSLSA")
# Unpaired (single chain)
result = model.forward("EVQLVESGGGLVQPGGSLRL")
pseudo_log_likelihood() --- PLL and Perplexity
Masks each position one at a time, accumulates log P(true token). Returns 4 scoring modes in one pass.
# Standard: tokenize → PLL
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
result = model.pseudo_log_likelihood(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# {
# "marginalized": {"pll": -45.3, "perplexity": 2.34, "per_position": [L]},
# "gl": {"pll": -50.1, "perplexity": 2.71, "per_position": [L]},
# "ngl": {"pll": -48.2, "perplexity": 2.56, "per_position": [L]},
# "exact": {"pll": -50.1, "perplexity": 2.71, "per_position": [L]},
# }
ppl = result["marginalized"]["perplexity"]
NGL-Aware Scoring with exact Mode
When the input contains NGL tokens (lowercase via preserve_case=True --- see NGL-Aware Tokenization), the exact mode scores each position using its actual token: uppercase log-prob for GL positions, lowercase log-prob for NGL positions.
Scoring Modes
All modes are computed from the 53-vocab alpha-gated logits. gl, ngl, and marginalized extract the GL/NGL slots and combine them back into 20-AA probabilities; a sketch follows the table.
| Mode | What it scores | Use case |
|---|---|---|
| marginalized | logsumexp(GL, NGL) per AA | General-purpose scoring |
| gl | Uppercase (GL) token log-prob | Germline likeness |
| ngl | Lowercase (NGL) token log-prob | Somatic mutation preference |
| exact | Actual input token log-prob | NGL-aware scoring (with preserve_case=True) |
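As a concrete illustration of the table above, this numpy sketch derives the marginalized 20-AA log-probabilities from the 53-vocab logits returned by forward(). It mirrors, rather than calls, the library's internal computation, and assumes result is a forward() output on an unpaired sequence:
import numpy as np
logits = result["final_logits"]                        # [L, 53]
shifted = logits - logits.max(axis=-1, keepdims=True)  # numerically stable
log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
gl = log_probs[:, model.GL_INDICES]    # [L, 20] uppercase (GL) slots
ngl = log_probs[:, model.NGL_INDICES]  # [L, 20] lowercase (NGL) slots
marginalized = np.logaddexp(gl, ngl)   # logsumexp(GL, NGL) per amino acid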
Batch Processing
batch_size controls how many masked positions are processed in a single forward pass. Higher values use more GPU memory but run faster.
# Fast: 64 positions per forward pass (needs ~2x memory vs default)
result = model.pseudo_log_likelihood(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
batch_size=64,
)
For multiple sequences, pass a list --- they are scored sequentially, each with the same batch_size parallelism:
# Multiple sequences (processed one at a time, results in order)
results = model.pseudo_log_likelihood(
heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# results[0] → first pair, results[1] → second pair
String Convenience
# Paired
result = model.pseudo_log_likelihood(
heavy_chains="EVQLVESGGGLVQ",
light_chains="DIQMTQSPSSLSA",
)
# Unpaired
result = model.pseudo_log_likelihood("EVQLVESGGGLVQPGGSLRL")
score_mutations() --- Mutation Effect Prediction
Masked marginal scoring at mutation positions. For each mutated position, masks that position in both WT and mutant, runs a forward pass, and computes the log-likelihood difference. Returns all 4 scoring modes.
# Standard: tokenize → score (paired)
wt_inputs = tokenizer("EVQLVESGGGLVQPGGSLRL", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
mut_inputs = tokenizer("EVQLVASGGGLVQPGGSLRL", light_chain="DIQMTQSPSSLSA", return_tensors="pt") # V6A
result = model.score_mutations(
wt_input_ids=wt_inputs["input_ids"],
mut_input_ids=mut_inputs["input_ids"],
)
# {
# "positions": [5], # 0-indexed mutation positions (detected from token diff)
# "marginalized": {"score": 0.42, "per_position": [1]},
# "gl": {"score": 0.31, "per_position": [1]},
# "ngl": {"score": 0.55, "per_position": [1]},
# "exact": {"score": 0.31, "per_position": [1]},
# }
# score > 0 = mutant preferred over WT
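A common workflow is a saturation scan at a single position. The loop below is illustrative glue code on top of the documented string API (wt= / mutant=) and model.AA_ORDER; the mutation-naming convention is ours, not the library's:
wt = "EVQLVESGGGLVQPGGSLRL"
pos = 5                               # 0-indexed position to scan (WT residue 'E')
scores = {}
for aa in model.AA_ORDER:             # "ACDEFGHIKLMNPQRSTVWY"
    if aa == wt[pos]:
        continue                      # skip the wild-type residue
    mutant = wt[:pos] + aa + wt[pos + 1:]
    res = model.score_mutations(wt=wt, mutant=mutant)
    scores[f"{wt[pos]}{pos + 1}{aa}"] = res["marginalized"]["score"]
best = max(scores, key=scores.get)    # substitution most preferred over WT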
Batch Processing
batch_size controls how many mutation positions are masked per forward pass. For sequences with many mutations, higher values are faster.
result = model.score_mutations(
wt_input_ids=wt_inputs["input_ids"],
mut_input_ids=mut_inputs["input_ids"],
batch_size=64,
)
For multiple WT/mutant pairs, pass lists --- they are scored sequentially:
results = model.score_mutations(
wt=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
mutant=["EVQLVASGGGLVQ", "QVQLVQSGAEVAK"],
wt_light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
mut_light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# results[0] → first pair, results[1] → second pair
String Convenience
# Paired
result = model.score_mutations(
wt="EVQLVESGGGLVQPGGSLRL",
mutant="EVQLVASGGGLVQPGGSLRL",
wt_light_chains="DIQMTQSPSSLSA",
mut_light_chains="DIQMTQSPSSLSA",
)
# Unpaired
result = model.score_mutations(
wt="EVQLVESGGGLVQPGGSLRL",
mutant="EVQLVASGGGLVQPGGSLRL",
)
predict_germline() --- Germline Sequence Prediction
Predicts the unmutated germline sequence from a somatically hypermutated antibody. Uses the origin head to identify non-germline (NGL) positions and reverts them to the top-scoring germline amino acid --- all in a single forward pass.
import prism
model = prism.pretrained("RomeroLab-Duke/prism-antibody")
tokenizer = model.get_tokenizer()
# Standard: tokenize → predict germline (paired)
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
result = model.predict_germline(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# {
# "heavy": {
# "sequence": "EVQLVESGGGLVQ", # original
# "predicted_germline": "EVQLVESGGGLVQ", # germline-reverted
# "ngl_positions": [5, 8], # 0-indexed NGL positions
# "ngl_count": 2,
# "ngl_probs": array([0.02, ..., 0.91, ..., 0.87, ...]), # [L_H] P(NGL)
# },
# "light": { ... same structure ... },
# }
How It Works
- The origin head classifies each position as GL or NGL via sigmoid(origin_logits).
- Positions with P(NGL) > ngl_threshold are identified as somatically mutated.
- At those positions, the residue is replaced with the argmax over the 20 GL amino acid logits from the final head.
- GL positions are left unchanged (see the sketch below).
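Putting these steps together, this sketch reproduces the reversion logic from raw forward() outputs on an unpaired chain. predict_germline() does the equivalent internally (plus paired handling); the 0.5 threshold and the 1:1 alignment of outputs to residues are assumptions made here for illustration:
import numpy as np
seq = "EVQLVESGGGLVQPGGSLRL"
out = model.forward(seq)                                 # unpaired string path
ngl_probs = 1.0 / (1.0 + np.exp(-out["origin_logits"]))  # sigmoid -> P(NGL)
gl_logits = out["final_logits"][:, model.GL_INDICES]     # [L, 20] germline slots
reverted = [
    model.AA_ORDER[int(np.argmax(gl_logits[i]))] if ngl_probs[i] > 0.5 else aa
    for i, aa in enumerate(seq)
]
germline = "".join(reverted)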
Paired Auto-Detection with input_ids
When using pre-tokenized input_ids, paired sequences are automatically detected by finding the <cls><cls> separator in the token IDs --- no additional parameters needed:
# Paired: auto-detected from <cls><cls> in input_ids
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
result = model.predict_germline(input_ids=inputs["input_ids"])
# → result["heavy"], result["light"]
# Unpaired: no <cls><cls> → flat output
inputs = tokenizer("EVQLVESGGGLVQ", return_tensors="pt")
result = model.predict_germline(input_ids=inputs["input_ids"])
# → result["sequence"], result["predicted_germline"], ...
Controlling the Threshold
# Aggressive: revert more positions (lower threshold)
result = model.predict_germline(
input_ids=inputs["input_ids"],
ngl_threshold=0.3,
)
# Conservative: only revert high-confidence NGL positions
result = model.predict_germline(
input_ids=inputs["input_ids"],
ngl_threshold=0.8,
)
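To pick a threshold, it can help to inspect the per-position P(NGL) returned in ngl_probs and count how many positions each cutoff would revert; an illustrative check:
result = model.predict_germline(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
probs = result["heavy"]["ngl_probs"]   # [L_H] P(NGL) from the origin head
for thr in (0.3, 0.5, 0.8):
    n = int((probs > thr).sum())
    print(f"threshold={thr}: {n} heavy-chain positions would be reverted")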
Batch Processing
inputs = tokenizer(
["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
light_chain=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
return_tensors="pt",
)
results = model.predict_germline(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# results[0] → first pair, results[1] → second pair
String Convenience
# Paired
result = model.predict_germline(
heavy_chains="EVQLVESGGGLVQ",
light_chains="DIQMTQSPSSLSA",
)
# Unpaired
result = model.predict_germline(heavy_chains="EVQLVESGGGLVQ")
generate() --- PLL-Guided Variant Generation
Generates antibody variants using pseudo-log-likelihood guided sampling:
- Collect --- mask each position one at a time and gather the pre-gating logits (L forward passes, cached and reusable)
- Select positions --- rank by WT log-probability, sample via Gumbel-Top-k with controllable temperature
- Sample amino acids --- draw from GL, NGL, marginalized, or region-specific logits with temperature, top-k, and nucleus sampling
# Standard: tokenize → generate
inputs = tokenizer(
"EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS",
light_chain="DIQMTQSPSSLSASVGDRVTITCRASQSISSYLN",
return_tensors="pt",
)
variants = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
n_samples=100, # number of variants to generate
n_mutations=5, # mutations per variant
mode="full", # gl | ngl | full | region_specific
seed=42,
)
# List of 100 dicts:
# [
# {"sequence": "EVQLVE...", "mutations": "S7A,G10D,...", "positions": [6, 9, ...],
# "mode": "full", "n_mut": 5},
# ...
# ]
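Generated variants can be re-ranked with pseudo_log_likelihood(). The snippet below is illustrative glue code, assuming the unpaired string path and that each variant's "sequence" field holds the full variant chain:
variants = model.generate(
    heavy_chains="EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS",
    n_samples=50, n_mutations=3, mode="full", seed=0,
)
ranked = sorted(
    variants,
    key=lambda v: model.pseudo_log_likelihood(v["sequence"])["marginalized"]["pll"],
    reverse=True,                      # highest (least negative) PLL first
)
top10 = ranked[:10]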
Sampling Modes
| Mode | Position scoring | AA sampling | Use case |
|---|---|---|---|
| "full" | Marginalized log P(wt) | logsumexp(GL, NGL) logits | General-purpose diversification |
| "gl" | GL log P(wt) | GL (germline) logits only | Germline reversion / humanization |
| "ngl" | NGL log P(wt) | NGL (non-germline) logits only | Affinity maturation mimicry |
| "region_specific" | FR: GL, CDR: NGL | FR: GL logits, CDR: NGL logits | Targeted: conserve FRs, diversify CDRs |
Controlling Generation
variants = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
n_samples=100,
n_mutations=5,
# --- Position selection ---
pool_size=30, # candidate pool (top-30 worst positions)
position_temperature=0.5, # lower = more deterministic position choice
exclude_positions=np.array([0, 1, 2]), # never mutate these (0-indexed)
# --- Amino acid sampling ---
temperature=0.8, # lower = more conservative AA choices
top_k=10, # only consider top-10 AAs per position
top_p=0.9, # nucleus sampling threshold
# --- Variation ---
randomize_n_mutations=True, # n_mut ~ Beta(2,1) in [1, n_mutations]
seed=42, # reproducibility
)
Region-Specific Mode
The "region_specific" mode uses framework region (FR) and complementarity-determining region (CDR) annotations to apply different sampling strategies: GL logits for FR positions (conserve structure) and NGL logits for CDR positions (diversify binding).
Regions are auto-detected using ANARCI (IMGT numbering). Pass heavy_chain_length so VH and VL are numbered separately:
variants = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
n_samples=100,
n_mutations=5,
mode="region_specific",
heavy_chain_length=len("EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS"),
)
Or provide region labels manually (0 = FR, 1 = CDR):
import numpy as np
# vh_seq / vl_seq are your full-length VH and VL sequences; the CDR index
# ranges below are illustrative, so adjust them to your numbering scheme.
L = len(vh_seq) + len(vl_seq)
region_labels = np.zeros(L, dtype=np.int32)
region_labels[26:34] = 1   # CDR-H1
region_labels[51:57] = 1   # CDR-H2
region_labels[93:102] = 1  # CDR-H3
# ... repeat for VL (CDR-L positions are offset by len(vh_seq))
variants = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
n_samples=100,
n_mutations=5,
mode="region_specific",
region_labels=region_labels,
)
Caching Masked Logits Across Modes
The most expensive step (L forward passes) can be computed once and reused across different modes:
# First call: collect masked logits + generate
variants_full, cache = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
n_samples=100, n_mutations=5, mode="full", seed=42,
return_masked_data=True,
)
# Subsequent calls: skip L forward passes (instant)
variants_gl = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
n_samples=100, n_mutations=5, mode="gl", seed=42,
masked_data=cache,
)
variants_ngl = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
n_samples=100, n_mutations=5, mode="ngl", seed=42,
masked_data=cache,
)
String Convenience
variants = model.generate(
heavy_chains="EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS",
light_chains="DIQMTQSPSSLSASVGDRVTITCRASQSISSYLN",
n_samples=100, n_mutations=5, mode="full", seed=42,
)
Reference
forward() Return Dict
| Key | Shape | Description |
|---|---|---|
| final_logits | [L, 53] | Alpha-gated combined logits (53-vocab) |
| aa_logits | [L, 33] | AA head logits, before gating |
| origin_logits | [L] | GL/NGL binary classification logits |
| alpha | [L] | Per-position gating values |
| embedding | [L, H] | Per-residue hidden states from backbone |

When paired (string API or auto-detected <cls><cls>), forward() returns {"heavy": {...}, "light": {...}}.
predict_germline() Return Dict
| Key | Type | Description |
|---|---|---|
| sequence | str | Original amino acid sequence |
| predicted_germline | str | Germline-reverted sequence (NGL positions replaced) |
| ngl_positions | list[int] | 0-indexed positions classified as NGL |
| ngl_count | int | Number of NGL positions |
| ngl_probs | [L] numpy array | Per-position P(NGL) from origin head |

When paired, returns {"heavy": {...}, "light": {...}} with per-chain values.
Index Constants
- model.GL_INDICES --- 20 uppercase AA token indices in the 53-vocab
- model.NGL_INDICES --- 20 lowercase AA token indices in the 53-vocab
- model.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY" --- column order for the 20 AA indices
Part 2: Developer Guide
For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.
Development Installation
git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e ".[dev,analysis]"
Project Structure
prism/
├── src/prism/ # Core Python package
│ ├── __init__.py # Package exports
│ ├── api.py # High-level inference & finetune API
│ ├── tokenizer.py # PrismTokenizer (53-vocab, paired support)
│ ├── model.py # SFT_ESM2 PyTorch Lightning module
│ ├── io_utils.py # Dataset & DataModule classes
│ ├── multimodal_io.py # Gene vocabulary & antibody dataset
│ └── utils.py # Utility functions
│
├── configs/ # Training configuration files
├── script/ # Training, inference, analysis scripts
├── tests/ # Test suite
├── pyproject.toml # Package configuration
└── README.md
Training from Scratch
Two-Stage Training Protocol
Stage 1 --- Pretraining on a large unpaired OAS dataset (60M+ sequences):
python script/train_esm.py --config configs/v34_pretrain.yaml
Stage 2 --- Finetuning on paired antibody sequences (~764K sequences):
python script/train_esm.py --config configs/v34_1b_finetune.yaml
Multi-GPU Training
CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml
Testing
pytest tests/ -v
License
MIT License --- see LICENSE for details.