
PRISM

Partitioning Residue Identity in Somatic Maturation

This is the official repository for the paper:

Explicit representation of germline and non-germline residues improves antibody language modeling

PRISM is a PyTorch Lightning-based framework for supervised fine-tuning of ESM2 protein language models on antibody sequences. It features a multi-head architecture that jointly learns amino acid identity prediction and germline/non-germline (GL/NGL) position classification.


Part 1: User Guide

Everything you need to run inference with PRISM on your own antibody data.

Installation

pip install prism-antibody

Or install from source:

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e .

Verify Installation

import prism
print(prism.__version__)

Quick Start

import prism

model     = prism.pretrained("RomeroLab-Duke/prism-antibody")
tokenizer = model.get_tokenizer()

# Tokenize → model (standard HuggingFace-style pipeline)
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
result = model.forward(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

Tokenizer

PrismTokenizer wraps the ESM2 tokenizer with PRISM's 53-token vocabulary (33 ESM2 base + 20 lowercase NGL tokens).

tokenizer = prism.PrismTokenizer()          # standalone (no model needed)
tokenizer = model.get_tokenizer()           # or from a loaded model

# Paired heavy + light chain
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
# inputs["input_ids"]       -> [1, L_H+L_L+4]  (CLS + VH + CLS + CLS + VL + EOS)
# inputs["attention_mask"]  -> [1, L_H+L_L+4]

# Batch
inputs = tokenizer(
    ["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chain=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
    return_tensors="pt",
)

# Unpaired (single chain)
inputs = tokenizer("EVQLVESGGGLVQ", return_tensors="pt")

# Encode / decode (paired)
ids = tokenizer.encode_paired("EVQLV", "DIQMT")
heavy, light = tokenizer.decode_paired(ids)    # ("EVQLV", "DIQMT")

# Encode / decode (unpaired)
ids = tokenizer.encode("EVQLV")               # [CLS, E, V, Q, L, V, EOS]
seq = tokenizer.decode(ids)                    # "EVQLV"

NGL-Aware Tokenization

By default, all amino acids are tokenized as uppercase (GL) tokens --- this is the standard mode and matches the training format. Use preserve_case=True only when you need the exact mode of pseudo_log_likelihood(), which scores each position using its actual GL or NGL log-probability.

# For exact PLL: lowercase = NGL (somatic mutation) positions
inputs = tokenizer("EvQLvESGGglvq", preserve_case=True, return_tensors="pt")
# 'v', 'g', 'l', 'v' → NGL token IDs;  'E', 'Q', 'L', ... → GL token IDs

result = model.pseudo_log_likelihood(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
# result["exact"] now uses NGL log-prob at lowercase positions

GL/NGL Token Mappings

tokenizer.gl_token_ids     # {"A": 5, "C": 23, ...}  — 20 uppercase (germline)
tokenizer.ngl_token_ids    # {"a": 33, "c": 34, ...}  — 20 lowercase (non-germline)
tokenizer.gl_to_ngl        # {5: 33, 23: 34, ...}     — GL→NGL token ID mapping
tokenizer.vocab_size       # 53
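
As a minimal sketch (using only the mappings above, with an illustrative position), these attributes let you flag which positions of a preserve_case-tokenized input carry NGL tokens, or relabel a GL token as its NGL counterpart:

inputs = tokenizer("EvQLvESGGglvq", preserve_case=True, return_tensors="pt")
ids = inputs["input_ids"][0]

ngl_ids = set(tokenizer.ngl_token_ids.values())
is_ngl = [int(i) in ngl_ids for i in ids.tolist()]   # True at lowercase (NGL) positions

# Relabel position 3 as NGL if it currently holds a GL token (note: includes CLS offset)
ids[3] = tokenizer.gl_to_ngl.get(int(ids[3]), int(ids[3]))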

API Overview

PRISM has 4 core methods. All accept pre-tokenized input_ids (recommended) or raw strings.

| Method | Cost | Returns |
|---|---|---|
| forward() | 1 forward pass | logits, embeddings, origin, alpha |
| pseudo_log_likelihood() | ceil(L / batch_size) forward passes | PLL, perplexity, per-position log-probs (4 modes) |
| score_mutations() | 2 × ceil(M / batch_size) forward passes | masked marginal mutation scores (4 modes) |
| generate() | L + N forward passes | PLL-guided antibody variants |

forward() --- Logits, Embeddings, Everything

Single forward pass through the model. Returns all outputs as numpy arrays.

import prism
import numpy as np

model     = prism.pretrained("RomeroLab-Duke/prism-antibody")
tokenizer = model.get_tokenizer()

# Standard: tokenize → forward (paired)
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
result = model.forward(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
# result["final_logits"]  -> [L, 53]  alpha-gated combined logits
# result["aa_logits"]     -> [L, 33]  AA head logits (pre-gating)
# result["origin_logits"] -> [L]      GL/NGL classification logits
# result["alpha"]         -> [L]      gating values
# result["embedding"]     -> [L, H]   per-residue hidden states

# GL/NGL log-probabilities (slice from 53-vocab)
gl_logits  = result["final_logits"][:, model.GL_INDICES]   # [L, 20]
ngl_logits = result["final_logits"][:, model.NGL_INDICES]  # [L, 20]
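
As a sketch of how the marginalized view described later relates to these slices: combine the GL and NGL slots elementwise with logsumexp, then normalize over the 20 amino acids. This ignores special-token mass, so it approximates (rather than reproduces) the model's own marginalized mode:

import numpy as np

marg = np.logaddexp(gl_logits, ngl_logits)          # [L, 20] elementwise logsumexp
probs = np.exp(marg - marg.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)          # per-position AA distribution
best = "".join(model.AA_ORDER[i] for i in probs.argmax(axis=-1))  # most likely AA per position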

Batch, Unpaired (string convenience)

# Batch (returns list of {"heavy": {...}, "light": {...}})
results = model.forward(
    heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)

# String convenience (paired)
result = model.forward(heavy_chains="EVQLVESGGGLVQ", light_chains="DIQMTQSPSSLSA")

# Unpaired (single chain)
result = model.forward("EVQLVESGGGLVQPGGSLRL")

pseudo_log_likelihood() --- PLL and Perplexity

Masks each position one at a time and accumulates log P(true token). Returns all 4 scoring modes from a single call.

# Standard: tokenize → PLL
inputs = tokenizer("EVQLVESGGGLVQ", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
result = model.pseudo_log_likelihood(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
# {
#   "marginalized": {"pll": -45.3, "perplexity": 2.34, "per_position": [L]},
#   "gl":           {"pll": -50.1, "perplexity": 2.71, "per_position": [L]},
#   "ngl":          {"pll": -48.2, "perplexity": 2.56, "per_position": [L]},
#   "exact":        {"pll": -50.1, "perplexity": 2.71, "per_position": [L]},
# }

ppl = result["marginalized"]["perplexity"]
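
The per-position arrays make it easy to locate the residues the model finds least likely; a small sketch (illustrative only):

import numpy as np

per_pos = np.asarray(result["marginalized"]["per_position"])  # [L]
worst = np.argsort(per_pos)[:5]     # indices of the 5 lowest log-prob positions
print(worst, per_pos[worst])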

NGL-Aware Scoring with exact Mode

When the input contains NGL tokens (lowercase via preserve_case=True --- see NGL-Aware Tokenization), the exact mode scores each position using its actual token: uppercase log-prob for GL positions, lowercase log-prob for NGL positions.

Scoring Modes

All modes are computed from the 53-vocab alpha-gated logits. gl and ngl extract the uppercase and lowercase slots respectively; marginalized combines both back into a single 20-AA probability per position.

| Mode | What it scores | Use case |
|---|---|---|
| marginalized | logsumexp(GL, NGL) per AA | General-purpose scoring |
| gl | Uppercase (GL) token log-prob | Germline likeness |
| ngl | Lowercase (NGL) token log-prob | Somatic mutation preference |
| exact | Actual input token log-prob | NGL-aware scoring (with preserve_case=True) |
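
One way to combine modes, as a hedged sketch: subtracting the gl from the ngl per-position scores highlights positions where the model prefers a somatic identity over the germline one (using the pseudo_log_likelihood() result from above):

import numpy as np

gl_pp  = np.asarray(result["gl"]["per_position"])
ngl_pp = np.asarray(result["ngl"]["per_position"])
ngl_preference = ngl_pp - gl_pp     # [L]; > 0 where an NGL identity is preferred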

Batch Processing

batch_size controls how many masked positions are processed in a single forward pass. Higher values use more GPU memory but run faster.

# Fast: 64 positions per forward pass (needs ~2x memory vs default)
result = model.pseudo_log_likelihood(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    batch_size=64,
)

For multiple sequences, pass a list --- they are scored sequentially, each with the same batch_size parallelism:

# Multiple sequences (processed one at a time, results in order)
results = model.pseudo_log_likelihood(
    heavy_chains=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# results[0] → first pair, results[1] → second pair
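
Since results come back in input order, ranking a candidate set by perplexity is straightforward; a minimal sketch using the batch call above:

ranked = sorted(range(len(results)),
                key=lambda i: results[i]["marginalized"]["perplexity"])
# ranked[0] is the index of the lowest-perplexity (most antibody-like) pair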

String Convenience

# Paired
result = model.pseudo_log_likelihood(
    heavy_chains="EVQLVESGGGLVQ",
    light_chains="DIQMTQSPSSLSA",
)

# Unpaired
result = model.pseudo_log_likelihood("EVQLVESGGGLVQPGGSLRL")

score_mutations() --- Mutation Effect Prediction

Masked marginal scoring at mutation positions. For each mutated position, masks that position in both WT and mutant, runs a forward pass, and computes the log-likelihood difference. Returns all 4 scoring modes.

# Standard: tokenize → score (paired)
wt_inputs  = tokenizer("EVQLVESGGGLVQPGGSLRL", light_chain="DIQMTQSPSSLSA", return_tensors="pt")
mut_inputs = tokenizer("EVQLVASGGGLVQPGGSLRL", light_chain="DIQMTQSPSSLSA", return_tensors="pt")  # E6A
result = model.score_mutations(
    wt_input_ids=wt_inputs["input_ids"],
    mut_input_ids=mut_inputs["input_ids"],
)
# {
#   "positions": [5],  # 0-indexed mutation positions (detected from token diff)
#   "marginalized": {"score": 0.42, "per_position": [1]},
#   "gl":           {"score": 0.31, "per_position": [1]},
#   "ngl":          {"score": 0.55, "per_position": [1]},
#   "exact":        {"score": 0.31, "per_position": [1]},
# }
# score > 0 = mutant preferred over WT
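
score_mutations() also composes naturally into a single-site scan. The sketch below loops over all substitutions at heavy-chain position 6 using the string-convenience form shown further down; the position and sequences are illustrative:

wt = "EVQLVESGGGLVQPGGSLRL"
lc = "DIQMTQSPSSLSA"
scan = {}
for aa in model.AA_ORDER:
    if aa == wt[5]:                      # skip the wild-type residue
        continue
    mut = wt[:5] + aa + wt[6:]           # substitute at 0-indexed position 5
    r = model.score_mutations(wt=wt, mutant=mut,
                              wt_light_chains=lc, mut_light_chains=lc)
    scan[f"{wt[5]}6{aa}"] = r["marginalized"]["score"]
# Positive scores mark substitutions the model prefers over WT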

Batch Processing

batch_size controls how many mutation positions are masked per forward pass. For sequences with many mutations, higher values are faster.

result = model.score_mutations(
    wt_input_ids=wt_inputs["input_ids"],
    mut_input_ids=mut_inputs["input_ids"],
    batch_size=64,
)

For multiple WT/mutant pairs, pass lists --- they are scored sequentially:

results = model.score_mutations(
    wt=["EVQLVESGGGLVQ", "QVQLVQSGAEVKK"],
    mutant=["EVQLVASGGGLVQ", "QVQLVQSGAEVAK"],
    wt_light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
    mut_light_chains=["DIQMTQSPSSLSA", "EIVLTQSPGTLSL"],
)
# results[0] → first pair, results[1] → second pair

String Convenience

# Paired
result = model.score_mutations(
    wt="EVQLVESGGGLVQPGGSLRL",
    mutant="EVQLVASGGGLVQPGGSLRL",
    wt_light_chains="DIQMTQSPSSLSA",
    mut_light_chains="DIQMTQSPSSLSA",
)

# Unpaired
result = model.score_mutations(
    wt="EVQLVESGGGLVQPGGSLRL",
    mutant="EVQLVASGGGLVQPGGSLRL",
)

generate() --- PLL-Guided Variant Generation

Generates antibody variants using pseudo-log-likelihood guided sampling:

  1. Collect --- mask each position one at a time, collect pre-gating logits (L forward passes, cached and reusable)
  2. Select positions --- rank by WT log-probability, sample via Gumbel-Top-k with controllable temperature
  3. Sample amino acids --- draw from GL, NGL, marginalized, or region-specific logits with temperature, top-k, and nucleus sampling

# Standard: tokenize → generate
inputs = tokenizer(
    "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS",
    light_chain="DIQMTQSPSSLSASVGDRVTITCRASQSISSYLN",
    return_tensors="pt",
)

variants = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    n_samples=100,                # number of variants to generate
    n_mutations=5,                # mutations per variant
    mode="full",                  # gl | ngl | full | region_specific
    seed=42,
)
# List of 100 dicts:
# [
#   {"sequence": "EVQLVE...", "mutations": "S7A,G10D,...", "positions": [6, 9, ...],
#    "mode": "full", "n_mut": 5},
#   ...
# ]
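
Generated variants can be re-scored with pseudo_log_likelihood() to pick the best candidates. A sketch, assuming "sequence" is the concatenated VH+VL string (an assumption about the output format --- adapt the split if your variants differ):

vh_len = len("EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS")   # heavy-chain length from the input above
scored = []
for v in variants:
    vh, vl = v["sequence"][:vh_len], v["sequence"][vh_len:]   # assumed VH+VL layout
    pll = model.pseudo_log_likelihood(heavy_chains=vh, light_chains=vl)
    scored.append((pll["marginalized"]["pll"], v))
top10 = sorted(scored, key=lambda t: t[0], reverse=True)[:10]  # highest PLL first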

Sampling Modes

| Mode | Position scoring | AA sampling | Use case |
|---|---|---|---|
| "full" | Marginalized log P(wt) | logsumexp(GL, NGL) logits | General-purpose diversification |
| "gl" | GL log P(wt) | GL (germline) logits only | Germline reversion / humanization |
| "ngl" | NGL log P(wt) | NGL (non-germline) logits only | Affinity maturation mimicry |
| "region_specific" | FR: GL, CDR: NGL | FR: GL logits, CDR: NGL logits | Targeted: conserve FRs, diversify CDRs |

Controlling Generation

variants = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    n_samples=100,
    n_mutations=5,

    # --- Position selection ---
    pool_size=30,                 # candidate pool (top-30 worst positions)
    position_temperature=0.5,     # lower = more deterministic position choice
    exclude_positions=np.array([0, 1, 2]),  # never mutate these (0-indexed)

    # --- Amino acid sampling ---
    temperature=0.8,              # lower = more conservative AA choices
    top_k=10,                     # only consider top-10 AAs per position
    top_p=0.9,                    # nucleus sampling threshold

    # --- Variation ---
    randomize_n_mutations=True,   # n_mut ~ Beta(2,1) in [1, n_mutations]
    seed=42,                      # reproducibility
)

Region-Specific Mode

The "region_specific" mode uses framework region (FR) and complementarity-determining region (CDR) annotations to apply different sampling strategies: GL logits for FR positions (conserve structure) and NGL logits for CDR positions (diversify binding).

Regions are auto-detected using ANARCI (IMGT numbering). Pass heavy_chain_length so VH and VL are numbered separately:

variants = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    n_samples=100,
    n_mutations=5,
    mode="region_specific",
    heavy_chain_length=len("EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS"),
)

Or provide region labels manually (0 = FR, 1 = CDR):

import numpy as np

# vh_seq / vl_seq: your VH and VL amino-acid strings
L = len(vh_seq) + len(vl_seq)
region_labels = np.zeros(L, dtype=np.int32)
region_labels[26:34] = 1   # CDR1 (illustrative index ranges)
region_labels[51:57] = 1   # CDR2
region_labels[93:102] = 1  # CDR3
# ... repeat for VL

variants = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    n_samples=100,
    n_mutations=5,
    mode="region_specific",
    region_labels=region_labels,
)

Caching Masked Logits Across Modes

The most expensive step (L forward passes) can be computed once and reused across different modes:

# First call: collect masked logits + generate
variants_full, cache = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    n_samples=100, n_mutations=5, mode="full", seed=42,
    return_masked_data=True,
)

# Subsequent calls: skip L forward passes (instant)
variants_gl = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    n_samples=100, n_mutations=5, mode="gl", seed=42,
    masked_data=cache,
)

variants_ngl = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    n_samples=100, n_mutations=5, mode="ngl", seed=42,
    masked_data=cache,
)

String Convenience

variants = model.generate(
    heavy_chains="EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS",
    light_chains="DIQMTQSPSSLSASVGDRVTITCRASQSISSYLN",
    n_samples=100, n_mutations=5, mode="full", seed=42,
)

Reference

forward() Return Dict

| Key | Shape | Description |
|---|---|---|
| final_logits | [L, 53] | Alpha-gated combined logits (53-vocab) |
| aa_logits | [L, 33] | AA head logits, before gating |
| origin_logits | [L] | GL/NGL binary classification logits |
| alpha | [L] | Per-position gating values |
| embedding | [L, H] | Per-residue hidden states from backbone |

When paired (string API), returns {"heavy": {dict}, "light": {dict}}.

Index Constants

  • model.GL_INDICES --- 20 uppercase AA token indices in the 53-vocab
  • model.NGL_INDICES --- 20 lowercase AA token indices in the 53-vocab
  • model.AA_ORDER = "ACDEFGHIKLMNPQRSTVWY" --- column order for the 20 AA indices

Part 2: Developer Guide

For researchers and developers who want to train from scratch, run analysis pipelines, or extend the codebase.

Development Installation

git clone https://github.com/RomeroLab-Duke/prism-antibody.git
cd prism-antibody
pip install -e ".[dev,analysis]"

Project Structure

prism/
├── src/prism/                    # Core Python package
│   ├── __init__.py               # Package exports
│   ├── api.py                    # High-level inference & finetune API
│   ├── tokenizer.py              # PrismTokenizer (53-vocab, paired support)
│   ├── model.py                  # SFT_ESM2 PyTorch Lightning module
│   ├── io_utils.py               # Dataset & DataModule classes
│   ├── multimodal_io.py          # Gene vocabulary & antibody dataset
│   └── utils.py                  # Utility functions
│
├── configs/                      # Training configuration files
├── script/                       # Training, inference, analysis scripts
├── tests/                        # Test suite
├── pyproject.toml                # Package configuration
└── README.md

Training from Scratch

Two-Stage Training Protocol

Stage 1 --- Pretraining on a large unpaired OAS dataset (60M+ sequences):

python script/train_esm.py --config configs/v34_pretrain.yaml

Stage 2 --- Finetuning on paired antibody sequences (~764K pairs):

python script/train_esm.py --config configs/v34_1b_finetune.yaml

Multi-GPU Training

CUDA_VISIBLE_DEVICES=0,1,2,3 python script/train_esm.py --config configs/v34_pretrain.yaml

Testing

pytest tests/ -v

License

MIT License --- see LICENSE for details.
