PRISM

Partitioning Residue Identity in Somatic Maturation

This is the official repository for the paper:

Explicit representation of germline and non-germline residues improves antibody language modeling

PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.


Part 1: User Guide

Everything you need to run inference or finetune PRISM on your own antibody data.

Installation

pip install prism-antibody

Or install from source:

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .

Verify Installation

import prism
print(prism.__version__)  # 0.3.0

Quick Start: Inference

import prism

# Load from Hugging Face Hub (auto-downloads and caches)
model = prism.pretrained("RomeroLab-Duke/prism-antibody")

# Or load from a local checkpoint
# model = prism.pretrained("path/to/checkpoint.ckpt")

# Extract germline log-probabilities — [L, 20] numpy array
gl = model.extract_GL_logit("EVQLVESGGGLVQPGGSLRL")

# Extract non-germline log-probabilities — [L, 20] numpy array
ngl = model.extract_NGL_logit("EVQLVESGGGLVQPGGSLRL")

# Marginalized log-probs: logsumexp(GL, NGL) — [L, 20]
marg = model.extract_marginalized_logit("EVQLVESGGGLVQPGGSLRL")

# Full 53-vocab log-probs — [L, 53]
full = model.extract_full_logit("EVQLVESGGGLVQPGGSLRL")

# Alpha gating values (GL/NGL mixture weights) — [L]
alpha = model.extract_alpha("EVQLVESGGGLVQPGGSLRL")

# Mean-pooled embedding — [H]
emb = model.extract_embedding("EVQLVESGGGLVQPGGSLRL")

# Perplexity — scalar
ppl = model.perplexity("EVQLVESGGGLVQPGGSLRL")

All methods accept a single string or a list of strings:

# Batch inference
sequences = ["EVQLVESGGGLVQ", "DIQMTQSPSSLSA", "QVQLVQSGAEVKK"]
embeddings = model.extract_embedding(sequences)  # [3, H] numpy array
ppls = model.perplexity(sequences)                # [3] numpy array
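The `perplexity` method reports pseudo-perplexity, which is conventionally the exponential of the negative mean per-position masked log-likelihood. A toy sketch of that arithmetic with synthetic log-probabilities (the `pseudo_perplexity` helper is illustrative only, not part of the prism API):

```python
import numpy as np

def pseudo_perplexity(per_position_logprobs):
    """Pseudo-perplexity from per-position log-likelihoods log p(x_i | x_{\\i}).

    Each entry is the log-probability the model assigns to the true residue
    when that position is masked. Hypothetical helper, not the library API.
    """
    lp = np.asarray(per_position_logprobs, dtype=float)
    return float(np.exp(-lp.mean()))

# a model that is uniform over 20 amino acids gives pseudo-perplexity 20
pseudo_perplexity(np.full(7, np.log(1 / 20)))  # 20.0
```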

Available Methods

| Method | Returns | Description |
| --- | --- | --- |
| extract_GL_logit(seq) | [L, 20] | Germline (uppercase) log-probabilities |
| extract_NGL_logit(seq) | [L, 20] | Non-germline (lowercase) log-probabilities |
| extract_marginalized_logit(seq) | [L, 20] | logsumexp(GL, NGL) per amino acid |
| extract_full_logit(seq) | [L, 53] | Full vocabulary log-probabilities |
| extract_alpha(seq) | [L] | Per-position alpha gating values |
| extract_embedding(seq) | [H] | Mean-pooled backbone embedding |
| embed(seq, mode=...) | varies | Embeddings ("mean", "per_residue", "cls") |
| perplexity(seq) | scalar | Pseudo-perplexity |
| pseudo_log_likelihood(seq) | scalar | Pseudo-log-likelihood score |
| predict_origin(seq) | dict | GL/NGL origin predictions per residue |
| score_mutations(wt, mut) | scalar | Log-likelihood ratio at mutated positions |
| logits(seq, head=...) | tensor | Raw logits from any head |

Column order for 20-AA outputs follows PrismModel.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY".
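The marginalized output combines the two 20-way heads with an elementwise logsumexp, as described above. A toy numpy illustration with synthetic arrays (not model outputs), including how AA_ORDER indexes a column:

```python
import numpy as np

AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # column order for 20-AA outputs

rng = np.random.default_rng(0)
gl = rng.standard_normal((5, 20))   # stand-in GL log-probs, shape [L, 20]
ngl = rng.standard_normal((5, 20))  # stand-in NGL log-probs, shape [L, 20]

# elementwise logsumexp over the two heads: log(exp(gl) + exp(ngl))
marg = np.logaddexp(gl, ngl)

# pick out the column for a given amino acid via AA_ORDER
ala_scores = marg[:, AA_ORDER.index("A")]
```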

Finetuning on Your Data

Finetune PRISM on your own antibody sequences with a single method call:

import prism

model = prism.pretrained("RomeroLab-Duke/prism-antibody")

best_checkpoint = model.finetune(
    data_path="my_antibodies.parquet",   # parquet, pickle, or csv
    output_dir="outputs/my_finetune",
    max_steps=5000,
    learning_rate=1e-4,
    batch_size=32,
)

# Model is now finetuned — use immediately
gl = model.extract_GL_logit("EVQLVESGGGLVQ...")

Data Format

Your data file needs at least one of these columns:

| Column | Required | Description |
| --- | --- | --- |
| HEAVY_CHAIN_AA_SEQUENCE | At least one | Heavy chain amino acid sequence |
| LIGHT_CHAIN_AA_SEQUENCE | At least one | Light chain amino acid sequence |

Optional columns (auto-detected):

| Column | Description |
| --- | --- |
| split | "train" / "valid" / "test" (auto-generated 90/5/5 if absent) |
| hc_mut_codes, lc_mut_codes | NGL mutation codes (e.g. "S30A;T52N") |
| v_gene_heavy, j_gene_heavy, v_gene_light, j_gene_light | Gene labels |
| region_mask_heavy, region_mask_light | Region annotations |

Accepted file formats: .parquet, .pkl/.pickle, .csv.
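The mutation codes follow the "&lt;wt&gt;&lt;pos&gt;&lt;mut&gt;" pattern shown above (e.g. "S30A" means S at position 30 mutated to A), joined by semicolons. A sketch of parsing them, assuming 1-based positions; the `parse_mut_codes` helper is illustrative, not part of the prism API:

```python
import re

AA = "ACDEFGHIKLMNPQRSTVWY"

def parse_mut_codes(codes):
    """Parse NGL mutation codes like "S30A;T52N" into (wt, position, mut) tuples.

    Illustrative helper (not part of the prism API); assumes 1-based
    positions in the "<wt><pos><mut>" format.
    """
    if not codes:
        return []
    out = []
    for code in codes.split(";"):
        m = re.fullmatch(rf"([{AA}])(\d+)([{AA}])", code.strip())
        if m is None:
            raise ValueError(f"unrecognized mutation code: {code!r}")
        out.append((m.group(1), int(m.group(2)), m.group(3)))
    return out

parse_mut_codes("S30A;T52N")  # [('S', 30, 'A'), ('T', 52, 'N')]
```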

Finetune Parameters

model.finetune(
    data_path="data.parquet",
    output_dir="prism_finetune_output",
    # Training
    max_steps=5000,             # total training steps
    learning_rate=1e-4,         # peak LR (cosine schedule with warmup)
    batch_size=32,              # per-device batch size
    warmup_steps=500,           # linear warmup steps
    weight_decay=0.01,          # AdamW weight decay
    mask_prob=0.15,             # MLM masking probability
    gradient_accumulation_steps=1,
    # Trainer
    devices=1,                  # number of GPUs
    precision="bf16-mixed",     # training precision
    val_check_interval=500,     # validate every N steps
    # Data
    num_workers=4,
    seed=42,
)
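The comment above describes learning_rate as the peak of a cosine schedule with linear warmup. A generic sketch of that common schedule, using the defaults above; the exact floor and decay shape used internally may differ:

```python
import math

def lr_at_step(step, peak_lr=1e-4, warmup_steps=500, max_steps=5000):
    """Linear warmup to peak_lr, then cosine decay to zero at max_steps.

    Generic sketch of the schedule described above, not PRISM's internals.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps          # linear warmup
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
```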

Part 2: Developer Guide

For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.

Development Installation

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody

# Install with dev + analysis extras
pip install -e ".[dev,analysis]"

Project Structure

prism/
├── src/prism/                    # Core Python package
│   ├── __init__.py               # Package exports
│   ├── api.py                    # High-level inference & finetune API
│   ├── model.py                  # SFT_ESM2 PyTorch Lightning module
│   ├── io_utils.py               # Dataset & DataModule classes
│   ├── multimodal_io.py          # Gene vocabulary & antibody dataset
│   └── utils.py                  # Utility functions
│
├── configs/                      # Training configuration files
│   ├── v34_pretrain.yaml         # Stage 1: Pretraining config
│   └── v34_1b_finetune.yaml      # Stage 2: Finetuning config
│
├── script/                       # Executable scripts
│   ├── train_esm.py              # Main training script
│   ├── inference_esm.py          # Basic inference
│   ├── inference_esm_with_logprobs.py
│   ├── upload_to_hub.py          # Upload checkpoint to HF Hub
│   │
│   ├── data/                     # Data processing pipeline (1-8)
│   │   ├── 1.processing_and_filtering_*.py
│   │   ├── 2.visualize_unpaired_statistics.py
│   │   ├── 3.filter_by_p90_and_save.py
│   │   ├── 4-6.cluster_*.py
│   │   ├── 7.extract_gene_information.py
│   │   └── 8.extract_data_for_probe.py
│   │
│   └── analyze/                  # Analysis & evaluation scripts
│       ├── 1.pppl_calculation/   # Perplexity calculations
│       ├── 2.gl-ngl_calculation/ # GL/NGL embeddings & linear probes
│       ├── 3.zero-shot/          # Zero-shot prediction benchmarks
│       └── 4.thera-sabdab/       # Controllable generation experiments
│
├── tests/                        # Test suite
│   └── test_api.py               # API tests (59 tests)
│
├── pyproject.toml                # Package configuration
├── requirements.txt              # Explicit dependencies
└── README.md

Core Modules

model.py — SFT_ESM2

The central PyTorch Lightning module. Key components:

  • Base: ESM2 transformer (HuggingFace) with optional SwiGLU activation
  • Multi-head architecture: AA head + Origin head + Alpha gating + Final head
  • Gene conditioning: V/J gene embeddings + region embeddings
  • Loss: Focal loss with region-balanced and CDR-boosted variants
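For orientation, here is a minimal numpy sketch of standard multi-class focal loss (the base form; PRISM's region-balanced and CDR-boosted variants additionally reweight positions, which is omitted here):

```python
import numpy as np

def focal_loss(logits, targets, gamma=2.0):
    """Multi-class focal loss: mean of -(1 - p_t)^gamma * log p_t.

    Standard-form sketch; with gamma=0 it reduces to cross-entropy.
    """
    logits = np.asarray(logits, dtype=float)
    # numerically stable log-softmax
    z = logits - logits.max(axis=-1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    log_pt = log_p[np.arange(len(targets)), targets]  # log-prob of true class
    pt = np.exp(log_pt)
    return float(np.mean(-((1.0 - pt) ** gamma) * log_pt))
```

Down-weighting well-classified positions (large p_t) focuses the loss on hard residues, which is the usual motivation for focal loss over plain cross-entropy.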

io_utils.py — Data Loading

  • SeqSeqDataset: Handles paired/unpaired antibody sequences with germline reconstruction
  • SFTDataModule: Standard PyTorch Lightning DataModule
  • LazyShardedDataModule: Memory-efficient sharded loading for large datasets
  • make_collate_fn_multihead: Collate with 80/10/10 masking, gene encoding, region IDs
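The 80/10/10 scheme mentioned above is standard BERT-style masking: of the positions selected with mask_prob, 80% become the mask token, 10% become a random token, and 10% are left unchanged, with labels set to -100 everywhere else so the loss ignores them. A self-contained numpy sketch (token ids here are illustrative, not PRISM's vocabulary):

```python
import numpy as np

MASK_ID = 32        # hypothetical mask-token id
NUM_AA_IDS = 20     # hypothetical amino-acid id range

def mask_tokens(ids, mask_prob=0.15, rng=None):
    """BERT-style 80/10/10 masking; returns (inputs, labels)."""
    if rng is None:
        rng = np.random.default_rng()
    ids = np.asarray(ids)
    inputs, labels = ids.copy(), np.full_like(ids, -100)
    selected = rng.random(ids.shape) < mask_prob
    labels[selected] = ids[selected]                 # supervise only selected
    roll = rng.random(ids.shape)
    inputs[selected & (roll < 0.8)] = MASK_ID        # 80% -> [MASK]
    random_pos = selected & (roll >= 0.8) & (roll < 0.9)
    inputs[random_pos] = rng.integers(0, NUM_AA_IDS, random_pos.sum())
    return inputs, labels                            # remaining 10% -> keep
```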

multimodal_io.py — Gene & Region Handling

  • GeneVocabulary: Maps V/J gene strings to integer IDs
  • AntibodyDataset / AntibodyDataCollator / AntibodyMLMCollator: Multi-modal data handling
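A gene vocabulary of this kind is essentially a string-to-id lookup with a reserved unknown id. A toy stand-in (the real GeneVocabulary's reserved ids and behavior may differ):

```python
class SimpleGeneVocab:
    """Toy stand-in for GeneVocabulary: maps V/J gene strings to integer ids.

    Reserves 0 for unknown genes; illustrative only, not the real class.
    """
    def __init__(self, genes):
        self.stoi = {"<unk>": 0}
        for g in sorted(set(genes)):
            self.stoi.setdefault(g, len(self.stoi))

    def encode(self, gene):
        return self.stoi.get(gene, 0)

vocab = SimpleGeneVocab(["IGHV3-23", "IGHJ4", "IGHV1-69"])
vocab.encode("IGHV3-23")   # known gene -> nonzero id
vocab.encode("IGHV9-99")   # unseen gene -> 0 (<unk>)
```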

Training from Scratch

Two-Stage Training Protocol

Stage 1 — Pretraining on a large unpaired OAS dataset (~60M sequences):

python script/train_esm.py --config configs/v34_pretrain.yaml

Stage 2 — Finetuning on paired antibody sequences (~764K):

# Set pretrained_checkpoint_path in the config to point to Stage 1 best checkpoint
python script/train_esm.py --config configs/v34_1b_finetune.yaml

Multi-GPU Training

CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml

Key Config Options

data:
  data_path: "path/to/data"
  batch_size: 256
  mask_prob: 0.15

model:
  model_identifier: "esm2_t12_35M_UR50D"
  use_multihead_architecture: true
  use_alpha_gating: true
  ngl_loss_alpha: 3.0

training:
  max_steps: 100000
  peak_learning_rate: 4e-4
  warmup_steps: 2000
  gradient_accumulation_steps: 8

trainer:
  devices: 4
  precision: "bf16-mixed"

Analysis Pipeline

Located in script/analyze/, numbered by experiment stage:

| Stage | Directory | Description |
| --- | --- | --- |
| 1 | pppl_calculation/ | Pseudo-perplexity comparison across models |
| 2 | gl-ngl_calculation/ | Embedding extraction, linear probes, UMAP |
| 3 | zero-shot/ | Binding affinity and developability benchmarks |
| 4 | thera-sabdab/ | Controllable generation and mutation recovery |

Data Processing Pipeline

Located in script/data/, processes OAS (Observed Antibody Space) data:

| Step | Purpose |
| --- | --- |
| 1 | Filter OAS data, identify NGL positions |
| 2 | Visualize data distribution |
| 3 | Quality filtering by sequence coverage |
| 4-6 | MMseqs2 sequence clustering & deduplication |
| 7 | Parse V/J gene annotations |
| 8 | Prepare data for linear probe training |
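Conceptually, identifying NGL positions (step 1) amounts to comparing a mature sequence against its reconstructed germline and marking mismatches, which also yields mutation codes in the "S30A" style used elsewhere. A sketch assuming pre-aligned, equal-length sequences (the real pipeline's alignment details are not shown here):

```python
def ngl_positions(germline, mature):
    """Mark mismatches between aligned germline and mature sequences as NGL.

    Returns (mask, codes): mask[i] is True at NGL positions; codes are
    1-based "<wt><pos><mut>" strings. Illustrative helper assuming
    pre-aligned, equal-length inputs.
    """
    assert len(germline) == len(mature)
    mask, codes = [], []
    for pos, (g, m) in enumerate(zip(germline, mature), start=1):
        is_ngl = g != m
        mask.append(is_ngl)
        if is_ngl:
            codes.append(f"{g}{pos}{m}")
    return mask, codes

mask, codes = ngl_positions("EVQLVE", "EVRLVD")  # codes: ['Q3R', 'E6D']
```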

Testing

# Run all tests
pytest tests/ -v

# Run specific test class
pytest tests/test_api.py::TestFinetune -v

Linting

black --line-length 100 src/
isort --profile black --line-length 100 src/
flake8 src/
mypy src/

License

MIT License — see LICENSE for details.
