
PRISM

Partitioning Residue Identity in Somatic Maturation

This is the official repository for the paper:

Explicit representation of germline and non-germline residues improves antibody language modeling

PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.


Part 1: User Guide

Everything you need to run inference or finetune PRISM on your own antibody data.

Installation

pip install prism-antibody

Or install from source:

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .

Verify Installation

import prism
print(prism.__version__)

API Overview

PRISM has three core methods that cover all use cases:

Method                   Cost               Returns
forward()                1 forward pass     logits, embeddings, origin, alpha
pseudo_log_likelihood()  L forward passes   PLL, perplexity, per-position log-probs (4 modes)
score_mutations()        2M forward passes  masked-marginal mutation scores (4 modes)

All methods accept either a single string or a list of strings (batched), and all support paired (VH+VL) and unpaired input.

forward() --- Logits, Embeddings, Everything

Single forward pass through the model. Returns all outputs as a dict of numpy arrays.

import prism
model = prism.pretrained("RomeroLab-Duke/prism-antibody")

# Paired heavy + light chain (typical usage)
result = model.forward(
    heavy_chains="EVQLVESGGGLVQPGGSLRL",
    light_chains="DIQMTQSPSSLSASVG",
)
# result["heavy"]["final_logits"]  -> [L_vh, 53]  alpha-gated combined logits
# result["heavy"]["aa_logits"]     -> [L_vh, 33]  AA head logits (pre-gating)
# result["heavy"]["origin_logits"] -> [L_vh]      GL/NGL classification logits
# result["heavy"]["alpha"]         -> [L_vh]      gating values
# result["heavy"]["embedding"]     -> [L_vh, H]   per-residue hidden states
# result["light"] has the same keys for the light chain

Derive Any Signal You Need

import numpy as np

# GL/NGL log-probabilities (slice from 53-vocab)
gl_logits  = result["heavy"]["final_logits"][:, model.GL_INDICES]   # [L_vh, 20]
ngl_logits = result["heavy"]["final_logits"][:, model.NGL_INDICES]  # [L_vh, 20]

# Mean-pooled embedding
vh_emb = result["heavy"]["embedding"].mean(axis=0)  # [H]
vl_emb = result["light"]["embedding"].mean(axis=0)  # [H]

# GL/NGL origin probability per residue
origin_prob = 1 / (1 + np.exp(-result["heavy"]["origin_logits"]))  # sigmoid
# origin_prob > 0.5 = predicted NGL (somatic mutation)

Masked Prediction

Mask specific positions to get context-only predictions:

result = model.forward(
    heavy_chains="EVQLVESGGGLVQPGGSLRL",
    light_chains="DIQMTQSPSSLSASVG",
    mask_positions=[5, 10, 15],  # mask these heavy chain positions
)

Batch Inference

results = model.forward(
    heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# List of 2 dicts, each with "heavy" and "light" sub-dicts

Unpaired Input

For heavy-only, light-only, or any single-chain input:

result = model.forward("EVQLVESGGGLVQPGGSLRL")                     # positional arg
result = model.forward(heavy_chains="EVQLVESGGGLVQPGGSLRL")        # explicit heavy
result = model.forward(light_chains="DIQMTQSPSSLSASVG")            # explicit light
# Returns flat dict (no "heavy"/"light" nesting)

pseudo_log_likelihood() --- PLL and Perplexity

Masked-marginal scoring: each position is masked in turn, the model predicts it from context, and log P(true token) is accumulated. All four scoring modes are returned in one pass, with progress bars.
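The aggregation step behind these numbers is simple to sketch: sum the per-position log-probs to get the PLL, then exponentiate the mean negative log-prob to get perplexity. This is a minimal NumPy sketch of that final step only, with made-up log-prob values; the masked model predictions that produce them are assumed.

```python
import numpy as np

def summarize_pll(per_position_logprobs):
    """Aggregate per-position log P(true token) into PLL and perplexity."""
    lp = np.asarray(per_position_logprobs, dtype=float)
    pll = lp.sum()                       # pseudo-log-likelihood
    perplexity = np.exp(-pll / len(lp))  # exp of mean negative log-prob
    return pll, perplexity

# Three illustrative per-position log-probs
pll, ppl = summarize_pll([-0.1, -0.5, -0.3])
```

A perplexity of 1.0 would mean every true residue was predicted with probability 1; higher values mean the model found the sequence less likely.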

# Paired (typical)
result = model.pseudo_log_likelihood(
    heavy_chains="EVQLVESGGGLVQPGGSLRL",
    light_chains="DIQMTQSPSSLSASVG",
)
# {
#   "marginalized": {"pll": -45.3, "perplexity": 2.34, "per_position": [L_vh + L_vl]},
#   "gl":           {"pll": -50.1, "perplexity": 2.71, "per_position": [L_vh + L_vl]},
#   "ngl":          {"pll": -48.2, "perplexity": 2.56, "per_position": [L_vh + L_vl]},
#   "exact":        {"pll": -50.1, "perplexity": 2.71, "per_position": [L_vh + L_vl]},
# }

# Quick access
ppl = result["marginalized"]["perplexity"]
per_pos = result["gl"]["per_position"]  # per-residue GL log-probs

# Batch paired (with progress bar)
results = model.pseudo_log_likelihood(
    heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)

Scoring Modes

Mode          What it scores                           Use case
marginalized  logsumexp(GL, NGL)                       General-purpose
gl            Uppercase (germline) token log-prob      Germline likeness
ngl           Lowercase (non-germline) token log-prob  Somatic mutation preference
exact         Raw token log-prob in 53-vocab           Direct 53-vocab scoring
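The relationship between the gl, ngl, and marginalized modes can be sketched in NumPy: marginalized sums the probability mass of the germline (uppercase) and non-germline (lowercase) token for the same amino acid, i.e. a logsumexp in log space. The log-prob values below are made up for illustration.

```python
import numpy as np

# Hypothetical per-position log-probs of the true residue under its
# germline (uppercase) and non-germline (lowercase) token, respectively.
gl_lp  = np.array([-0.7, -2.3, -1.1])
ngl_lp = np.array([-1.6, -0.4, -1.1])

# marginalized: P(aa) = P(GL token) + P(NGL token) -> logsumexp in log space
marginalized = np.logaddexp(gl_lp, ngl_lp)
```

Because it pools both tokens, the marginalized score is never lower than either the gl or ngl score at any position.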

Unpaired

result = model.pseudo_log_likelihood("EVQLVESGGGLVQPGGSLRL")

score_mutations() --- Mutation Effect Prediction

Masked-marginal scoring at mutation positions. For each mutation, the method masks that position in both the WT and mutant sequences and computes the log-likelihood difference. All four modes are returned.

# Paired (typical)
result = model.score_mutations(
    wt="EVQLVESGGGLVQPGGSLRL",
    mutant="EVQLVASGGGLVQPGGSLRL",  # V6A
    wt_light_chains="DIQMTQSPSSLSA",
    mut_light_chains="DIQMTQSPSSLSA",  # same light chain, or different
)
# {
#   "positions": [5],  # 0-indexed mutation positions
#   "marginalized": {"score": 0.42, "per_position": [1]},
#   "gl":           {"score": 0.31, "per_position": [1]},
#   "ngl":          {"score": 0.55, "per_position": [1]},
#   "exact":        {"score": 0.31, "per_position": [1]},
# }
# score > 0 = mutant preferred over WT at that position

# Batch (multiple WT/mutant pairs)
results = model.score_mutations(
    wt=["EVQLVESGG", "DIQMTQSPS"],
    mutant=["EVQLVASGG", "DIQMAQSPS"],
    wt_light_chains=["DIQMT", "EVQLV"],
    mut_light_chains=["DIQMT", "EVQLV"],
)

Unpaired

result = model.score_mutations(wt="ACDEF", mutant="GCKEF")
# result["positions"] == [0, 2]  (A->G, D->K)
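Mutation positions are simply the 0-indexed positions where the two equal-length sequences differ. A minimal sketch of that comparison (the helper name is ours, not part of the PRISM API):

```python
def mutation_positions(wt: str, mutant: str) -> list[int]:
    """0-indexed positions where mutant differs from wild type.

    Assumes equal-length sequences, as score_mutations() requires.
    """
    if len(wt) != len(mutant):
        raise ValueError("WT and mutant must be the same length")
    return [i for i, (a, b) in enumerate(zip(wt, mutant)) if a != b]

mutation_positions("ACDEF", "GCKEF")  # A->G at 0, D->K at 2 -> [0, 2]
```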

Reference

forward() Return Dict

Key            Shape    Description
final_logits   [L, 53]  Alpha-gated combined logits (53-vocab)
aa_logits      [L, 33]  AA head logits, before gating
origin_logits  [L]      GL/NGL binary classification logits
alpha          [L]      Per-position gating values
embedding      [L, H]   Per-residue hidden states from backbone

When paired, returns {"heavy": {dict}, "light": {dict}} with the above keys in each sub-dict.

Index Constants

  • model.GL_INDICES --- 20 uppercase AA token indices in the 53-vocab
  • model.NGL_INDICES --- 20 lowercase AA token indices in the 53-vocab
  • model.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY" --- column order for the 20 AA indices
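One common use of these constants is to collapse the 53-vocab into per-residue germline vs. non-germline probability mass. The sketch below assumes a softmax over the 53-vocab and uses placeholder index lists; in practice you would pass `model.GL_INDICES` and `model.NGL_INDICES`.

```python
import numpy as np

AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # column order of the 20 AA indices

def residue_origin_split(final_logits, gl_indices, ngl_indices):
    """Split [L, 53] logits into per-residue GL vs NGL probability mass.

    gl_indices / ngl_indices: 20 column indices each (stand-ins for
    model.GL_INDICES / model.NGL_INDICES).
    """
    # Numerically stable softmax over the 53-vocab
    probs = np.exp(final_logits - final_logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    gl_mass  = probs[:, gl_indices].sum(axis=-1)   # mass on uppercase tokens
    ngl_mass = probs[:, ngl_indices].sum(axis=-1)  # mass on lowercase tokens
    return gl_mass, ngl_mass
```

The two masses need not sum to 1, since the 53-vocab also contains special tokens.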

Input Modes

All three methods accept the same input patterns:

# Paired (returns "heavy"/"light" split for forward, combined for PLL/mutations)
model.forward(heavy_chains="VH...", light_chains="VL...")

# Unpaired (positional or keyword)
model.forward("SEQ...")
model.forward(heavy_chains="VH...")
model.forward(light_chains="VL...")

# Batch (list of strings)
model.forward(heavy_chains=["VH1", "VH2"], light_chains=["VL1", "VL2"])

Finetuning on Your Data

import prism

model = prism.pretrained("RomeroLab-Duke/prism-antibody")

best_checkpoint = model.finetune(
    data_path="my_antibodies.parquet",   # parquet, pkl, or csv
    output_dir="outputs/my_finetune",
    max_steps=5000,
    learning_rate=1e-4,
    batch_size=32,
)

# Model is now finetuned --- use immediately
result = model.forward(heavy_chains="EVQLVESGGGLVQ...", light_chains="DIQMT...")

Data Format

Your data file needs at least one of these columns:

Column                   Required      Description
HEAVY_CHAIN_AA_SEQUENCE  At least one  Heavy chain amino acid sequence
LIGHT_CHAIN_AA_SEQUENCE  At least one  Light chain amino acid sequence

Optional columns (auto-detected):

Column                                                  Description
split                                                   "train" / "valid" / "test" (auto-generated 90/5/5 if absent)
hc_mut_codes, lc_mut_codes                              NGL mutation codes (e.g. "S30A;T52N")
v_gene_heavy, j_gene_heavy, v_gene_light, j_gene_light  Gene labels
region_mask_heavy, region_mask_light                    Region annotations
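As a sketch, here is a minimal data file with the two required sequence columns and an explicit split, written as CSV with the stdlib csv module (finetune() also accepts parquet and pkl). The sequences are the toy fragments used elsewhere in this guide.

```python
import csv

rows = [
    {"HEAVY_CHAIN_AA_SEQUENCE": "EVQLVESGGGLVQ",
     "LIGHT_CHAIN_AA_SEQUENCE": "DIQMTQSPSSLSA",
     "split": "train"},
    {"HEAVY_CHAIN_AA_SEQUENCE": "QVQLVQSGAEVKK",
     "LIGHT_CHAIN_AA_SEQUENCE": "EIVLTQSPGTLSL",
     "split": "valid"},
]

with open("my_antibodies.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)
```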

Finetune Parameters

model.finetune(
    data_path="data.parquet",
    output_dir="prism_finetune_output",
    max_steps=5000,             # total training steps
    learning_rate=1e-4,         # peak LR (cosine schedule with warmup)
    batch_size=32,              # per-device batch size
    warmup_steps=500,           # linear warmup steps
    weight_decay=0.01,          # AdamW weight decay
    mask_prob=0.15,             # MLM masking probability
    gradient_accumulation_steps=1,
    devices=1,                  # number of GPUs
    precision="bf16-mixed",     # training precision
    val_check_interval=500,     # validate every N steps
    num_workers=4,
    seed=42,
)
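For intuition, the "peak LR, cosine schedule with warmup" described above can be sketched as linear warmup followed by cosine decay. This is a generic sketch using the defaults shown; PRISM's exact schedule (e.g. its floor value) may differ.

```python
import math

def lr_at_step(step, max_steps=5000, warmup_steps=500, peak_lr=1e-4):
    """Linear warmup to peak_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps          # linear ramp
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The LR peaks exactly at the end of warmup (step 500 here) and decays smoothly to ~0 at max_steps.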

Part 2: Developer Guide

For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.

Development Installation

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e ".[dev,analysis]"

Project Structure

prism/
├── src/prism/                    # Core Python package
│   ├── __init__.py               # Package exports
│   ├── api.py                    # High-level inference & finetune API
│   ├── model.py                  # SFT_ESM2 PyTorch Lightning module
│   ├── io_utils.py               # Dataset & DataModule classes
│   ├── multimodal_io.py          # Gene vocabulary & antibody dataset
│   └── utils.py                  # Utility functions
│
├── configs/                      # Training configuration files
├── script/                       # Training, inference, analysis scripts
├── tests/                        # Test suite
├── pyproject.toml                # Package configuration
└── README.md

Training from Scratch

Two-Stage Training Protocol

Stage 1 --- Pretraining on a large unpaired OAS dataset (60M+ sequences):

python script/train_esm.py --config configs/v34_pretrain.yaml

Stage 2 --- Finetuning on paired antibody sequences (~764K):

python script/train_esm.py --config configs/v34_1b_finetune.yaml

Multi-GPU Training

CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml

Testing

pytest tests/ -v

License

MIT License --- see LICENSE for details.
