
PRISM - Partitioning Residue Identity in Somatic Maturation for antibody language modeling


PRISM

Partitioning Residue Identity in Somatic Maturation

This is the official repository for the paper:

Explicit representation of germline and non-germline residues improves antibody language modeling

PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.


Part 1: User Guide

Everything you need to run inference or finetune PRISM on your own antibody data.

Installation

pip install prism-antibody

Or install from source:

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .

Verify Installation

import prism
print(prism.__version__)  # 0.4.2

Quick Start: Inference

import prism

# Load from Hugging Face Hub (auto-downloads and caches)
model = prism.pretrained("RomeroLab-Duke/prism-antibody")

# Or load from a local checkpoint
# model = prism.pretrained("path/to/checkpoint.ckpt")

# Extract germline log-probabilities — [L, 20] numpy array
gl = model.extract_GL_logit("EVQLVESGGGLVQPGGSLRL")

# Extract non-germline log-probabilities — [L, 20] numpy array
ngl = model.extract_NGL_logit("EVQLVESGGGLVQPGGSLRL")

# Marginalized log-probs: logsumexp(GL, NGL) — [L, 20]
marg = model.extract_marginalized_logit("EVQLVESGGGLVQPGGSLRL")

# Full 53-vocab log-probs — [L, 53]
full = model.extract_full_logit("EVQLVESGGGLVQPGGSLRL")

# Alpha gating values (GL/NGL mixture weights) — [L]
alpha = model.extract_alpha("EVQLVESGGGLVQPGGSLRL")

# Mean-pooled embedding — [H]
emb = model.extract_embedding("EVQLVESGGGLVQPGGSLRL")

# Perplexity — scalar
ppl = model.perplexity("EVQLVESGGGLVQPGGSLRL")

All methods accept a single string or a list of strings:

# Batch inference
sequences = ["EVQLVESGGGLVQ", "DIQMTQSPSSLSA", "QVQLVQSGAEVKK"]
embeddings = model.extract_embedding(sequences)  # [3, H] numpy array
ppls = model.perplexity(sequences)                # [3] numpy array

Paired Heavy + Light Chain Input (v0.4.2)

All methods support paired VH+VL input via the light_chains parameter. The API handles the <cls><cls> separator and tokenization automatically:

# Heavy chain only (default)
gl = model.extract_GL_logit("EVQLVESGGGLVQPGGSLRL")

# Paired heavy + light chain
gl = model.extract_GL_logit(
    "EVQLVESGGGLVQPGGSLRL",           # heavy chain
    light_chains="DIQMTQSPSSLSASVG",  # light chain
)

# Batch paired inference
gl = model.extract_GL_logit(
    ["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)

# Works with all methods: extract_*, perplexity, embed, predict_origin, etc.
ppl = model.perplexity("EVQLVESGGGLVQ", light_chains="DIQMTQSPSSLSA")
emb = model.extract_embedding("EVQLVESGGGLVQ", light_chains="DIQMTQSPSSLSA")
origin = model.predict_origin("EVQLVESGGGLVQ", light_chains="DIQMTQSPSSLSA")

For score_binding, use wt_light_chain and mut_light_chains:

score = model.score_binding(
    "EVQLVESGGGLVQPGGSLRL",           # WT heavy chain
    "EVQLVASGGGLVQPGGSLRL",           # mutant heavy chain
    wt_light_chain="DIQMTQSPSSLSA",   # WT light chain
    mut_light_chains="DIQMTQSPSSLSA", # mutant light chain (same VL)
)

Binding Affinity Scoring (v0.4.0)

Score how mutations affect binding affinity using PRISM's origin head signals. Validated on FLAb2 DMS benchmarks where it outperforms ESM2-650M, AbLang2, AntiBERTy, and Sapiens.

# Single mutant
score = model.score_binding(
    "EVQLVESGGGLVQPGGSLRL",   # wild-type
    "EVQLVASGGGLVQPGGSLRL",   # mutant (V6A)
)
# score > 0 → predicted to maintain/improve binding

# Multiple mutants
scores = model.score_binding(
    "EVQLVESGGGLVQPGGSLRL",
    ["EVQLVASGGGLVQPGGSLRL", "EVQLVESGGALVQPGGSLRL", "EVQLVESGGGLVQPGASLRL"],
)
# scores: [3] numpy array

# Alternative signals: origin_logit (default), origin_prob, ngl_logprob, alpha
# Alternative aggregations: sum (default), mean
wt = "EVQLVESGGGLVQPGGSLRL"
mutant = "EVQLVASGGGLVQPGGSLRL"
score = model.score_binding(wt, mutant, signal="ngl_logprob", aggregation="mean")

Developability Scoring (v0.4.0)

Score six biophysical properties, each using the per-property optimal method identified in an exhaustive evaluation of 432 zero-shot scoring methods.

dev = model.score_developability("EVQLVESGGGLVQPGGSLRL")
# {
#   'self_interaction': 0.82,     # fraction of low-NGL positions (higher = better)
#   'hydrophobicity': 0.12,       # min alpha gating (higher = better)
#   'thermal_stability': 1.45,    # std of GL log-probs (higher = better)
#   'immunogenicity': -2.33,      # mean GL log-prob (higher = better)
#   'polyreactivity': 0.67,       # std of alpha (lower = better)
#   'expression': -1.89,          # mean NGL log-prob (higher = better)
# }

# Score specific properties only
dev = model.score_developability("EVQLV...", properties=["immunogenicity", "expression"])

# Batch scoring
devs = model.score_developability(["EVQLV...", "DIQMT..."])  # list of dicts

Available Methods

Method                            Returns   Description
extract_GL_logit(seq)             [L, 20]   Germline (uppercase) log-probabilities
extract_NGL_logit(seq)            [L, 20]   Non-germline (lowercase) log-probabilities
extract_marginalized_logit(seq)   [L, 20]   logsumexp(GL, NGL) per amino acid
extract_full_logit(seq)           [L, 53]   Full-vocabulary log-probabilities
extract_alpha(seq)                [L]       Per-position alpha gating values
extract_embedding(seq)            [H]       Mean-pooled backbone embedding
embed(seq, mode=...)              varies    Embeddings ("mean", "per_residue", "cls")
perplexity(seq)                   scalar    Pseudo-perplexity
pseudo_log_likelihood(seq)        scalar    Pseudo-log-likelihood score
predict_origin(seq)               dict      GL/NGL origin predictions per residue
score_mutations(wt, mut)          scalar    Log-likelihood ratio at mutated positions
score_binding(wt, mut)            scalar    Binding affinity change (origin head, WT-centered)
score_developability(seq)         dict      6 biophysical property scores
logits(seq, head=...)             tensor    Raw logits from any head

All methods accept an optional light_chains= parameter for paired VH+VL input (v0.4.2). score_binding uses wt_light_chain= and mut_light_chains= instead. Column order for 20-AA outputs follows PrismModel.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY".
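
A few methods in the table are not exercised in the Quick Start. The sketch below shows plausible usage based on the signatures listed above; the only assumptions are the inlined AA_ORDER shortcut and any return-shape details beyond what the table states:

import numpy as np
import prism

model = prism.pretrained("RomeroLab-Duke/prism-antibody")
seq = "EVQLVESGGGLVQPGGSLRL"

# Per-residue embeddings instead of the default mean pooling
per_res = model.embed(seq, mode="per_residue")

# Pseudo-log-likelihood of the whole sequence
pll = model.pseudo_log_likelihood(seq)

# Per-residue germline/non-germline origin predictions
origin = model.predict_origin(seq)

# Log-likelihood ratio for a point mutant (V6A)
delta = model.score_mutations(seq, "EVQLVASGGGLVQPGGSLRL")

# Map argmax columns of a [L, 20] output back to amino acids
AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # PrismModel.AA_ORDER, per the note above
gl = model.extract_GL_logit(seq)
top_aa = "".join(AA_ORDER[i] for i in np.argmax(gl, axis=-1))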

Finetuning on Your Data

Finetune PRISM on your own antibody sequences with a single method call:

import prism

model = prism.pretrained("RomeroLab-Duke/prism-antibody")

best_checkpoint = model.finetune(
    data_path="my_antibodies.parquet",   # parquet, pickle, or csv
    output_dir="outputs/my_finetune",
    max_steps=5000,
    learning_rate=1e-4,
    batch_size=32,
)

# Model is now finetuned — use immediately
gl = model.extract_GL_logit("EVQLVESGGGLVQ...")

Data Format

Your data file needs at least one of these columns:

Column                    Required       Description
HEAVY_CHAIN_AA_SEQUENCE   At least one   Heavy chain amino acid sequence
LIGHT_CHAIN_AA_SEQUENCE   At least one   Light chain amino acid sequence

Optional columns (auto-detected):

Column                                                   Description
split                                                    "train" / "valid" / "test" (auto-generated 90/5/5 if absent)
hc_mut_codes, lc_mut_codes                               NGL mutation codes (e.g. "S30A;T52N")
v_gene_heavy, j_gene_heavy, v_gene_light, j_gene_light   V/J gene labels
region_mask_heavy, region_mask_light                     Region annotations

Accepted file formats: .parquet, .pkl/.pickle, .csv.
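
As a concrete example, a minimal training file can be built with pandas. This is a sketch: the column names follow the tables above, while the sequences are illustrative fragments:

import pandas as pd

df = pd.DataFrame({
    "HEAVY_CHAIN_AA_SEQUENCE": ["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    "LIGHT_CHAIN_AA_SEQUENCE": ["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
    # Optional: explicit split (otherwise auto-generated 90/5/5)
    "split": ["train", "valid"],
    # Optional: NGL mutation codes
    "hc_mut_codes": ["S30A;T52N", ""],
})
df.to_parquet("my_antibodies.parquet")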

Finetune Parameters

model.finetune(
    data_path="data.parquet",
    output_dir="prism_finetune_output",
    # Training
    max_steps=5000,             # total training steps
    learning_rate=1e-4,         # peak LR (cosine schedule with warmup)
    batch_size=32,              # per-device batch size
    warmup_steps=500,           # linear warmup steps
    weight_decay=0.01,          # AdamW weight decay
    mask_prob=0.15,             # MLM masking probability
    gradient_accumulation_steps=1,
    # Trainer
    devices=1,                  # number of GPUs
    precision="bf16-mixed",     # training precision
    val_check_interval=500,     # validate every N steps
    # Data
    num_workers=4,
    seed=42,
)
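
Since finetune returns the best checkpoint and prism.pretrained accepts local checkpoint paths (see Quick Start), a later session can plausibly restore the finetuned model as follows. This is a sketch that assumes the return value is a filesystem path to a .ckpt file:

import prism

# Assumption: model.finetune(...) returned a path to the best .ckpt
best_checkpoint = "outputs/my_finetune/..."  # hypothetical path from a previous run
model = prism.pretrained(best_checkpoint)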

Part 2: Developer Guide

For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.

Development Installation

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody

# Install with dev + analysis extras
pip install -e ".[dev,analysis]"

Project Structure

prism/
├── src/prism/                    # Core Python package
│   ├── __init__.py               # Package exports
│   ├── api.py                    # High-level inference & finetune API
│   ├── model.py                  # SFT_ESM2 PyTorch Lightning module
│   ├── io_utils.py               # Dataset & DataModule classes
│   ├── multimodal_io.py          # Gene vocabulary & antibody dataset
│   └── utils.py                  # Utility functions
│
├── configs/                      # Training configuration files
│   ├── v34_pretrain.yaml         # Stage 1: Pretraining config
│   └── v34_1b_finetune.yaml      # Stage 2: Finetuning config
│
├── script/                       # Executable scripts
│   ├── train_esm.py              # Main training script
│   ├── inference_esm.py          # Basic inference
│   ├── inference_esm_with_logprobs.py
│   ├── upload_to_hub.py          # Upload checkpoint to HF Hub
│   │
│   ├── data/                     # Data processing pipeline (1-8)
│   │   ├── 1.processing_and_filtering_*.py
│   │   ├── 2.visualize_unpaired_statistics.py
│   │   ├── 3.filter_by_p90_and_save.py
│   │   ├── 4-6.cluster_*.py
│   │   ├── 7.extract_gene_information.py
│   │   └── 8.extract_data_for_probe.py
│   │
│   └── analyze/                  # Analysis & evaluation scripts
│       ├── 1.pppl_calculation/   # Perplexity calculations
│       ├── 2.gl-ngl_calculation/ # GL/NGL embeddings & linear probes
│       ├── 3.zero-shot/          # Zero-shot prediction benchmarks
│       └── 4.thera-sabdab/       # Controllable generation experiments
│
├── tests/                        # Test suite
│   └── test_api.py               # API tests (59 tests)
│
├── pyproject.toml                # Package configuration
├── requirements.txt              # Explicit dependencies
└── README.md

Core Modules

model.py — SFT_ESM2

The central PyTorch Lightning module. Key components:

  • Base: ESM2 transformer (HuggingFace) with optional SwiGLU activation
  • Multi-head architecture: AA head + Origin head + Alpha gating + Final head (see the sketch after this list)
  • Gene conditioning: V/J gene embeddings + region embeddings
  • Loss: Focal loss with region-balanced and CDR-boosted variants
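
To make the head interaction concrete, here is a conceptual numpy sketch of how a per-position alpha gate can mix GL and NGL distributions so that logsumexp(GL, NGL), the marginalization documented for extract_marginalized_logit, yields a proper distribution. This illustrates the idea only; it is not the module's internal code, and the assumption that the heads' log-probabilities already carry the alpha weights is ours:

import numpy as np
from scipy.special import logsumexp

L, A = 20, 20                                  # sequence length, amino-acid alphabet
rng = np.random.default_rng(0)

# Stand-ins for the two heads: alpha-weighted per-position log-probabilities
alpha = rng.uniform(size=(L, 1))               # GL/NGL mixture weight per position
gl  = np.log(alpha)     + np.log(rng.dirichlet(np.ones(A), size=L))
ngl = np.log(1 - alpha) + np.log(rng.dirichlet(np.ones(A), size=L))

# Marginalizing over the latent GL/NGL origin gives a proper distribution,
# matching the documented logsumexp(GL, NGL)
marg = np.logaddexp(gl, ngl)                   # [L, A]
assert np.allclose(logsumexp(marg, axis=-1), 0.0)  # rows sum to 1 in prob space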

io_utils.py — Data Loading

  • SeqSeqDataset: Handles paired/unpaired antibody sequences with germline reconstruction
  • SFTDataModule: Standard PyTorch Lightning DataModule
  • LazyShardedDataModule: Memory-efficient sharded loading for large datasets
  • make_collate_fn_multihead: Collate with 80/10/10 masking, gene encoding, region IDs
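
For reference, the 80/10/10 scheme is the standard BERT-style MLM corruption: of the positions selected for prediction, 80% are replaced by the mask token, 10% by a random token, and 10% are left unchanged. A generic PyTorch sketch of the scheme, not the package's actual collate function:

import torch

def mask_tokens(ids: torch.Tensor, mask_id: int, vocab_size: int, mask_prob: float = 0.15):
    """Generic BERT-style 80/10/10 masking; labels are -100 at unselected positions."""
    labels = ids.clone()
    selected = torch.rand(ids.shape) < mask_prob       # positions the model must predict
    labels[~selected] = -100                           # ignored by cross-entropy

    corrupted = ids.clone()
    roll = torch.rand(ids.shape)
    corrupted[selected & (roll < 0.8)] = mask_id       # 80% -> <mask>
    swap = selected & (roll >= 0.8) & (roll < 0.9)     # 10% -> random token
    corrupted[swap] = torch.randint(vocab_size, ids.shape)[swap]
    # remaining 10% of selected positions keep their original token
    return corrupted, labels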

multimodal_io.py — Gene & Region Handling

  • GeneVocabulary: Maps V/J gene strings to integer IDs
  • AntibodyDataset / AntibodyDataCollator / AntibodyMLMCollator: Multi-modal data handling
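
GeneVocabulary's role is simple to picture. A toy stand-in (not the class's real interface) might look like:

class ToyGeneVocab:
    """Toy stand-in: map V/J gene strings to stable integer IDs, with <unk> = 0."""
    def __init__(self, genes):
        self.id_of = {"<unk>": 0}
        for g in sorted(set(genes)):
            self.id_of.setdefault(g, len(self.id_of))

    def encode(self, gene):
        return self.id_of.get(gene, 0)

vocab = ToyGeneVocab(["IGHV3-23", "IGHV1-69", "IGHJ4"])
vocab.encode("IGHV3-23")   # stable ID for a known gene
vocab.encode("IGHV9-99")   # 0 (<unk>) for an unseen gene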

Training from Scratch

Two-Stage Training Protocol

Stage 1 — Pretraining on the large unpaired OAS dataset (60M+ sequences):

python script/train_esm.py --config configs/v34_pretrain.yaml

Stage 2 — Finetuning on paired antibody sequences (~764K sequences):

# Set pretrained_checkpoint_path in the config to point to Stage 1 best checkpoint
python script/train_esm.py --config configs/v34_1b_finetune.yaml

Multi-GPU Training

CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml

Key Config Options

data:
  data_path: "path/to/data"
  batch_size: 256
  mask_prob: 0.15

model:
  model_identifier: "esm2_t12_35M_UR50D"
  use_multihead_architecture: true
  use_alpha_gating: true
  ngl_loss_alpha: 3.0

training:
  max_steps: 100000
  peak_learning_rate: 4e-4
  warmup_steps: 2000
  gradient_accumulation_steps: 8

trainer:
  devices: 4
  precision: "bf16-mixed"

Analysis Pipeline

Located in script/analyze/, numbered by experiment stage:

Stage   Directory             Description
1       pppl_calculation/     Pseudo-perplexity comparison across models
2       gl-ngl_calculation/   Embedding extraction, linear probes, UMAP
3       zero-shot/            Binding affinity and developability benchmarks
4       thera-sabdab/         Controllable generation and mutation recovery

Data Processing Pipeline

Located in script/data/, processes OAS (Observed Antibody Space) data:

Step   Purpose
1      Filter OAS data, identify NGL positions
2      Visualize data distribution
3      Quality filtering by sequence coverage
4-6    MMseqs2 sequence clustering & deduplication
7      Parse V/J gene annotations
8      Prepare data for linear probe training

Testing

# Run all tests
pytest tests/ -v

# Run specific test class
pytest tests/test_api.py::TestFinetune -v

Linting

black --line-length 100 src/
isort --profile black --line-length 100 src/
flake8 src/
mypy src/

License

MIT License — see LICENSE for details.

