
Integrative Refinement Engine using differentiable forward models


⚙️ Diff-Integrator: The Integrative Refinement Engine


Diff-Integrator is a JAX-accelerated optimization engine designed for integrative structural biology. It acts as the "orchestrator" that combines differentiable observables from diff-biophys into multi-objective loss functions.

By separating the optimization loop from the underlying biophysical kernels, diff-integrator enables joint refinement of protein structures against several kinds of experimental data at once (e.g., SAXS, NMR chemical shifts, and NMR RDCs).


🎯 Vision

The goal of diff-integrator is to provide a seamless optax-based refinement pipeline that handles:

  1. Multi-Objective Optimization: Easily weight and combine multiple experimental constraints via JointLoss.
  2. Abstract Parameterization: Optimize arbitrary parameter spaces—from Cartesian coordinates to internal backbone angles (phi/psi)—via user-defined kinematics_fn mappers.
  3. Dynamic Fitting: Analytically refit nuisance parameters (such as Saupe alignment tensors or SAXS scaling factors) during gradient descent.
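Refitting a nuisance parameter analytically means solving for its optimum in closed form each time the loss is evaluated, rather than adding it to the optimized parameters. A minimal sketch for a single SAXS scale factor, assuming a plain least-squares misfit (the function name is hypothetical, and SAXSLoss may weight residuals differently, e.g. by experimental errors):

```python
import jax
import jax.numpy as jnp


def scaled_saxs_loss(i_calc: jax.Array, i_exp: jax.Array) -> jax.Array:
    """Mean squared residual after refitting the scale factor c in closed form.

    Minimizing sum((c * i_calc - i_exp) ** 2) over c gives
    c* = <i_calc, i_exp> / <i_calc, i_calc>.
    """
    c_opt = jnp.dot(i_calc, i_exp) / jnp.dot(i_calc, i_calc)
    residual = c_opt * i_calc - i_exp
    return jnp.mean(residual**2)


# Toy profiles: the "experimental" curve is exactly 3x the calculated one,
# so the refitted scale absorbs the mismatch and the loss is ~0.
i_calc = jnp.array([10.0, 5.0, 2.0, 1.0])
i_exp = 3.0 * i_calc
loss = scaled_saxs_loss(i_calc, i_exp)

# Gradients flow through c* as an ordinary function of the calculated profile.
grad_wrt_profile = jax.grad(scaled_saxs_loss)(i_calc, i_exp + 0.1)
```

Because c* minimizes the misfit, the loss is stationary in c at that point, so differentiating straight through the closed-form expression adds no extra gradient term from c (the envelope theorem) and needs no special handling.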

📚 Interactive Tutorials

Experience Diff-Integrator directly in your browser with our Colab tutorials:

  • 📉 Results Dashboard (graduate / researcher): Visualizes the loss descent, Q-factors, chemical shift accuracy, NeRF drift, and a Cartesian vs. NeRF comparison across all four benchmarks (2KZV, GmR58A, HR2876B NeRF, HR2876B Cartesian).
  • 🧪 Refinement Concepts (student / researcher): Educational notebook explaining NMR observables, NeRF coordinate parameterization, RDC tensor degeneracy, and the fixed-tensor protocol.
  • ⚠️ Method Limitations (reviewer / scientist): Honest quantitative assessment of the current method's failure modes: NeRF geometric drift, RDC overfitting on PEG data, and degrees-of-freedom imbalance.

⚡ Core Components

IntegrativeRefiner

The core optimization engine. Built on optax (defaulting to the Adam optimizer), it manages the training loop, gradient calculation, and loss tracking.

  • Abstract Support: Accepts arbitrary parameter sets (init_params) and maps them to Cartesian space using an optional kinematics_fn.
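The params to kinematics_fn to coords to loss flow can be sketched with jax.value_and_grad. This is a minimal sketch, not IntegrativeRefiner itself: it uses plain gradient descent instead of optax Adam, and the one-angle circle_kinematics mapper and the target point are hypothetical stand-ins for a real kinematics_fn such as a NeRF builder.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp

LossFn = Callable[[jax.Array, jax.Array], jax.Array]


def refine(
    loss_fn: LossFn,
    init_params: jax.Array,
    kinematics_fn: Optional[Callable[[jax.Array], jax.Array]] = None,
    epochs: int = 200,
    learning_rate: float = 0.1,
) -> tuple[jax.Array, list[float]]:
    """Minimal refinement loop: params -> kinematics_fn -> coords -> scalar loss."""
    to_coords = kinematics_fn if kinematics_fn is not None else (lambda p: p)

    def objective(params: jax.Array) -> jax.Array:
        coords = to_coords(params)       # identity map for Cartesian refinement
        return loss_fn(params, coords)   # same (params, coords) signature as LossTerm

    step = jax.jit(jax.value_and_grad(objective))
    params, history = init_params, []
    for _ in range(epochs):
        loss, grad = step(params)
        params = params - learning_rate * grad   # plain gradient descent, not Adam
        history.append(float(loss))
    return params, history


# Hypothetical kinematics: a single angle places one atom on the unit circle.
def circle_kinematics(theta: jax.Array) -> jax.Array:
    return jnp.array([jnp.cos(theta[0]), jnp.sin(theta[0]), 0.0])


target = jnp.array([0.0, 1.0, 0.0])
final_theta, history = refine(
    lambda params, coords: jnp.sum((coords - target) ** 2),
    init_params=jnp.array([0.0]),
    kinematics_fn=circle_kinematics,
)
```

The optimizer only ever sees the angle; the loss is written in Cartesian space, which is the point of the kinematics_fn indirection.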

JointLoss

A container for combining multiple LossTerm objects. It computes the total weighted loss by evaluating each term on the current parameters and generated coordinates.

LossTerm (Interface)

An abstract base class for defining differentiable constraints. All terms implement __call__(self, params, coords).
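The LossTerm / JointLoss contract described above can be sketched as a typing.Protocol plus a weighted sum. This sketch assumes each term returns a scalar and that weights are plain floats; LossTermLike, HarmonicAnchor and WeightedSumLoss are hypothetical names, not the package's classes.

```python
from typing import Protocol, Sequence

import jax
import jax.numpy as jnp


class LossTermLike(Protocol):
    """Structural type: anything callable as term(params, coords) -> scalar."""

    def __call__(self, params: jax.Array, coords: jax.Array) -> jax.Array: ...


class HarmonicAnchor:
    """Toy term: mean squared per-atom displacement from a reference structure."""

    def __init__(self, target_coords: jax.Array) -> None:
        self.target_coords = target_coords

    def __call__(self, params: jax.Array, coords: jax.Array) -> jax.Array:
        return jnp.mean(jnp.sum((coords - self.target_coords) ** 2, axis=-1))


class WeightedSumLoss:
    """total = sum_i w_i * term_i(params, coords), all evaluated on the same coords."""

    def __init__(self, terms: Sequence[tuple[LossTermLike, float]]) -> None:
        self.terms = list(terms)

    def __call__(self, params: jax.Array, coords: jax.Array) -> jax.Array:
        total = jnp.zeros(())
        for term, weight in self.terms:
            total = total + weight * term(params, coords)
        return total


coords = jnp.zeros((2, 3))
anchor = HarmonicAnchor(target_coords=jnp.ones((2, 3)))
joint = WeightedSumLoss([(anchor, 2.0), (anchor, 0.5)])
total = joint(coords, coords)                     # Cartesian case: params are the coords
grads = jax.grad(lambda c: joint(c, c))(coords)   # one gradient through all terms
```

Passing both params and coords lets a term read whichever representation it needs (for example torsion angles for a shift predictor, Cartesian positions for distance restraints).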

Included Observables

  • GeometryLoss: Implements basic structural priors, including harmonic restraints to a target Cartesian structure.
  • SAXSLoss: Dynamically scales and fits theoretical SAXS profiles against experimental data using Debye kernels.
  • FixedTensorRDCLoss: RDC loss that holds the Saupe alignment tensor fixed during backpropagation (via jax.lax.stop_gradient) and re-fits it every update_interval epochs. Supports a held-out cross-validation split (cv_fraction) and provides suggested_weight() to auto-scale the term weight by the overdetermination ratio.
  • CAShiftLoss: Wraps the ring-current and secondary-structure shift predictor to compute Cα chemical-shift RMSDs from backbone torsion angles.
  • CartesianCAShiftLoss: Cartesian-space variant that extracts φ/ψ on-the-fly from raw coordinates via compute_phi_psi, enabling chemical shift refinement without a NeRF builder.
  • BondLengthPenalty / BondAnglePenalty: Harmonic restraints on backbone bond lengths and angles to Engh & Huber ideal values. Used in Cartesian refinement to replace the hard geometric constraints of the NeRF builder.
  • RamachandranLoss: Sequence-aware Ramachandran prior with residue-specific Gaussian wells. Handles GLY ε-basin (φ > 0) and PRO ring constraint correctly.
  • NOELoss: Flat-bottomed harmonic NOE distance restraints (standard XPLOR/CNS convention). Accepts atom_pairs, d_upper, and optional d_lower arrays; reports count_violations() and rms_violation() diagnostics. Use make_noe_restraints() to map (res_id, atom_name) pairs directly to atom indices.
  • ChiralityPenalty: Half-harmonic Cα chirality guard for Cartesian refinement. Prevents silent L→D inversion during gradient descent using the signed scalar triple product chi = dot(cross(N−CA, C−CA), C_prev−CA). Use make_backbone_chirality(n_residues) for the standard N–CA–C layout.
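The flat-bottomed NOE restraint and the half-harmonic chirality guard listed above can be sketched in a few lines of JAX. This is an illustration, not NOELoss or ChiralityPenalty: the function names, the 1e-12 distance epsilon and the expected_sign parameter are assumptions, and which sign of chi corresponds to an L center is not stated here, so expected_sign is left as a hypothetical knob.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def noe_flat_bottom(
    coords: jax.Array,
    atom_pairs: jax.Array,
    d_upper: jax.Array,
    d_lower: Optional[jax.Array] = None,
) -> jax.Array:
    """Flat-bottomed harmonic NOE energy: zero while d_lower <= d <= d_upper."""
    diff = coords[atom_pairs[:, 0]] - coords[atom_pairs[:, 1]]
    d = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)  # epsilon keeps the gradient finite at d = 0
    upper = jnp.maximum(d - d_upper, 0.0) ** 2
    lower = 0.0 if d_lower is None else jnp.maximum(d_lower - d, 0.0) ** 2
    return jnp.sum(upper + lower)


def chirality_half_harmonic(
    n: jax.Array, ca: jax.Array, c: jax.Array, c_prev: jax.Array, expected_sign: float = 1.0
) -> jax.Array:
    """Penalize only centers whose signed triple product has the wrong sign."""
    chi = jnp.sum(jnp.cross(n - ca, c - ca) * (c_prev - ca), axis=-1)
    return jnp.sum(jnp.maximum(-expected_sign * chi, 0.0) ** 2)


coords = jnp.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
pairs = jnp.array([[0, 1], [0, 2]])
# Pair 0 is 1 Å beyond its upper bound; pair 1 is 0.5 Å short of its lower bound.
noe = noe_flat_bottom(coords, pairs, d_upper=jnp.array([4.0, 4.0]), d_lower=jnp.array([0.0, 2.5]))

ca = jnp.zeros((1, 3))
n = jnp.array([[1.0, 0.0, 0.0]])
c = jnp.array([[0.0, 1.0, 0.0]])
ok = chirality_half_harmonic(n, ca, c, jnp.array([[0.0, 0.0, 1.0]]))        # chi = +1
flipped = chirality_half_harmonic(n, ca, c, jnp.array([[0.0, 0.0, -1.0]]))  # chi = -1
```

Both penalties are exactly zero inside the allowed region, so they do not pull on well-behaved geometry; the half-harmonic form only pushes back once a violation appears.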

🚀 Usage Example

from diff_integrator.loss import JointLoss
from diff_integrator.optimizer import EarlyStopping, IntegrativeRefiner
from diff_integrator.schedules import ExponentialDecaySchedule
from diff_integrator.terms.geometry import GeometryLoss
from diff_integrator.terms.nmr import FixedTensorRDCLoss, make_rdc_cv_refinement_fns

# 1. Build loss terms
# (starting_coords, rdc_res_ids, exp_rdcs and struct_res_ids are assumed to be loaded already)
geom_term = GeometryLoss(target_coords=starting_coords)

# FixedTensorRDCLoss holds the Saupe tensor fixed during backprop,
# preventing the degeneracy exploit that drives Q→0 unphysically.
# cv_fraction=0.2 holds out 20% of the RDCs for validation.
rdc_loss_fn, q_eval, tensor_fn, val_q_fn, n_train, n_val = make_rdc_cv_refinement_fns(
    rdc_res_ids, exp_rdcs, struct_res_ids, cv_fraction=0.2
)
rdc_term = FixedTensorRDCLoss(
    rdc_loss_fn, tensor_fn, update_interval=50,
    n_rdcs=n_train, val_q_eval_fn=val_q_fn
)
# Auto-scale weight by overdetermination ratio (ideal = 10×)
rdc_weight = rdc_term.suggested_weight(base_weight=1.0)

# 2. Combine into a joint loss
joint_loss = JointLoss([
    (geom_term, 5.0),
    (rdc_term, rdc_weight),
])

# 3. Annealed geometry weight: strong early, relaxed late
anchor_schedule = ExponentialDecaySchedule(
    initial_weight=10.0, final_weight=0.1, decay_epochs=100
)

# 4. Refine — result is a RefinementResult dataclass
refiner = IntegrativeRefiner(loss_fn=joint_loss)
result = refiner.run(
    init_params=starting_coords,
    epochs=2000,
    learning_rate=0.005,
    weight_schedules={0: anchor_schedule},      # anneal geometry anchor
    early_stopping=EarlyStopping(              # stop when RDC Q plateaus
        term_index=1, patience=50, min_delta=1e-4
    ),
)
print(f"Best checkpoint: epoch {result.best_epoch}")
print(f"Stopped early:   {result.stopped_early} ({result.early_stopping_triggered_by})")
refined_coords = result.best_params
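The fixed-tensor protocol behind FixedTensorRDCLoss can be illustrated with a self-contained sketch: fit the five independent Saupe-tensor elements by linear least squares, treat them as constants during backpropagation with jax.lax.stop_gradient, and re-fit only every update_interval steps. The helper names, the absorption of D_max into the fitted elements, and the Q-factor definition rms(D_calc − D_exp) / rms(D_exp) are assumptions for illustration; the package's internals may differ.

```python
import jax
import jax.numpy as jnp


def design_matrix(bond_vectors: jax.Array) -> jax.Array:
    """One row per RDC; columns multiply the 5 independent elements of a traceless S."""
    u = bond_vectors / jnp.linalg.norm(bond_vectors, axis=-1, keepdims=True)
    x, y, z = u[:, 0], u[:, 1], u[:, 2]
    return jnp.stack([y**2 - x**2, z**2 - x**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=-1)


def fit_tensor(bond_vectors: jax.Array, d_exp: jax.Array) -> jax.Array:
    """Least-squares Saupe elements, with D_max absorbed into the fitted values."""
    saupe, *_ = jnp.linalg.lstsq(design_matrix(bond_vectors), d_exp)
    return saupe


def fixed_tensor_rdc_loss(bond_vectors, d_exp, saupe):
    saupe = jax.lax.stop_gradient(saupe)  # no gradient ever reaches the tensor
    d_calc = design_matrix(bond_vectors) @ saupe
    return jnp.mean((d_calc - d_exp) ** 2)


def q_factor(bond_vectors, d_exp, saupe):
    d_calc = design_matrix(bond_vectors) @ saupe
    return jnp.sqrt(jnp.mean((d_calc - d_exp) ** 2) / jnp.mean(d_exp**2))


def refine_vectors(vecs, d_exp, epochs=100, update_interval=50, lr=1e-4):
    saupe = fit_tensor(vecs, d_exp)
    grad_fn = jax.jit(jax.grad(fixed_tensor_rdc_loss))  # gradient w.r.t. vectors only
    for epoch in range(1, epochs + 1):
        vecs = vecs - lr * grad_fn(vecs, d_exp, saupe)
        if epoch % update_interval == 0:
            saupe = fit_tensor(vecs, d_exp)  # periodic re-fit, outside autodiff
    return vecs, saupe


key_true, key_noise = jax.random.split(jax.random.PRNGKey(0))
true_vecs = jax.random.normal(key_true, (20, 3))
true_saupe = jnp.array([3.0, -5.0, 1.0, 2.0, -1.5])
d_exp = design_matrix(true_vecs) @ true_saupe
start_vecs = true_vecs + 0.3 * jax.random.normal(key_noise, (20, 3))
```

Freezing the tensor between re-fits means a gradient step cannot lower the misfit by rotating the tensor and the bond vectors together, which is the kind of co-adaptation the usage example's comment warns can drive Q toward zero unphysically.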

Multi-phase refinement with freeze_term

Freeze experimental terms for an initial geometry-only phase, then thaw them for full joint refinement — without rebuilding the loss:

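from diff_integrator.terms.chirality import make_backbone_chirality
from diff_integrator.terms.noe import make_noe_restraints

# Reuses geom_term, rdc_term and rdc_weight from the example above.
# bond_pen and angle_pen are BondLengthPenalty / BondAnglePenalty instances (construction not shown).
chirality_pen = make_backbone_chirality(n_residues)
noe_term      = make_noe_restraints(noe_observations, struct_res_ids)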

joint_loss = JointLoss([
    (geom_term,      5.0),   # 0 — position anchor
    (bond_pen,      50.0),   # 1 — bond lengths
    (angle_pen,     10.0),   # 2 — bond angles
    (chirality_pen, 20.0),   # 3 — chirality guard (always on)
    (rdc_term,  rdc_weight), # 4 — RDC
    (noe_term,       5.0),   # 5 — NOE
])

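# Rebuild the refiner so it optimizes this new joint loss
refiner = IntegrativeRefiner(loss_fn=joint_loss)

# Phase 1: geometry + chirality only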
joint_loss.freeze_term(4)   # freeze RDC
joint_loss.freeze_term(5)   # freeze NOE
result_p1 = refiner.run(init_params=starting_coords, epochs=200)

# Phase 2: add experimental restraints
joint_loss.unfreeze_term(4)
joint_loss.unfreeze_term(5)
result = refiner.run(init_params=result_p1.final_params, epochs=1000)
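One way to get freezing and annealed weights without rebuilding the loss is to keep the term list fixed and compute an effective weight per term at each epoch. The sketch below is an assumption about, not a description of, ExponentialDecaySchedule and JointLoss.freeze_term: it assumes geometric interpolation from initial_weight to final_weight over decay_epochs (then held constant), assumes a schedule replaces the term's base weight, and models freezing as a zero multiplier. ExpDecay and TermWeights are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class ExpDecay:
    """Geometric interpolation from initial_weight to final_weight, then held."""

    initial_weight: float
    final_weight: float
    decay_epochs: int

    def __call__(self, epoch: int) -> float:
        frac = min(epoch, self.decay_epochs) / self.decay_epochs
        return self.initial_weight * (self.final_weight / self.initial_weight) ** frac


@dataclass
class TermWeights:
    """Per-epoch effective weights; the list of terms itself never changes."""

    base: list[float]
    schedules: dict[int, ExpDecay] = field(default_factory=dict)
    frozen: set[int] = field(default_factory=set)

    def freeze_term(self, index: int) -> None:
        self.frozen.add(index)

    def unfreeze_term(self, index: int) -> None:
        self.frozen.discard(index)

    def at(self, epoch: int) -> list[float]:
        weights = []
        for i, w in enumerate(self.base):
            if i in self.frozen:
                weights.append(0.0)                       # frozen: contributes nothing
            elif i in self.schedules:
                weights.append(self.schedules[i](epoch))  # schedule overrides base weight
            else:
                weights.append(w)
        return weights


plan = TermWeights(base=[5.0, 50.0, 1.0], schedules={0: ExpDecay(10.0, 0.1, 100)})
plan.freeze_term(2)
phase1 = plan.at(0)      # [10.0, 50.0, 0.0]
plan.unfreeze_term(2)
phase2 = plan.at(100)    # approximately [0.1, 50.0, 1.0]
```

In a jitted JAX loss, passing these weights in as a traced array argument (rather than baking them in as Python constants) lets them change between epochs without retracing the compiled function.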

🔬 Scientific Validation

diff-integrator is being validated against several experimental NMR datasets:

  • 2KZV (CvR118A): Joint refinement against Cα chemical shifts and dual-medium (PAG/PEG) RDCs, lowering the Cα chemical-shift RMSD and driving RDC Q-factors close to zero.
  • GmR58A & HR2876B: Successful gradient-based minimization of the Cα chemical-shift RMSD using internal coordinates (dihedrals).
  • HR2876B Cartesian (Sprint 2): Over 2000 epochs, Cartesian refinement with bond-geometry and chirality-guard terms achieved an 11× larger improvement in Cα chemical-shift RMSD than the NeRF approach (−0.123 ppm vs −0.011 ppm), with 12× less structural drift (0.545 Å vs 6.4 Å). RDC Q-factors dropped by 63% in both alignment media (PEG: 0.440→0.163; Pf1: 0.443→0.162). The ChiralityPenalty corrected all 5 D-inverted Cα centers present in model 1 of the raw PDB entry; the pre-Sprint-2 run had silently introduced a 6th.
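The headline ratios for the HR2876B Cartesian run follow directly from the numbers quoted above; a quick arithmetic check (plain Python, no package code involved):

```python
def ratio(a: float, b: float) -> float:
    return a / b


def percent_drop(before: float, after: float) -> float:
    return 100.0 * (before - after) / before


shift_gain = ratio(0.123, 0.011)      # Cartesian vs NeRF Cα shift-RMSD improvement
drift_ratio = ratio(6.4, 0.545)       # NeRF vs Cartesian structural drift
peg_drop = percent_drop(0.440, 0.163)
pf1_drop = percent_drop(0.443, 0.162)
print(f"{shift_gain:.1f}x, {drift_ratio:.1f}x, PEG {peg_drop:.0f}%, Pf1 {pf1_drop:.0f}%")
# prints: 11.2x, 11.7x, PEG 63%, Pf1 63%
```

The quoted 11× and 12× figures are these values rounded to whole numbers.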

📂 Repository Structure

diff-integrator/
├── diff_integrator/       # Core package
│   ├── loss.py            # JointLoss and LossTerm interface
│   ├── optimizer.py       # IntegrativeRefiner engine
│   └── terms/             # Concrete loss implementations
│       ├── geometry.py    # Harmonic restraints, RMSD
│       ├── bond_geometry.py # Cartesian bond/angle penalties (Engh & Huber)
│       ├── chirality.py   # Cα chirality penalty (L→D inversion guard)
│       ├── saxs.py        # Debye scattering loss
│       ├── nmr.py         # RDC and Q-factor loss
│       ├── noe.py         # NOE flat-bottomed distance restraints
│       └── chemical_shifts.py # C-alpha shift loss
├── benchmarks/            # Real-world optimization tests
├── tests/                 # Unit tests (100% coverage)
├── docs/                  # MkDocs documentation
└── pyproject.toml         # Build configuration

🚀 Installation

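Ensure JAX is installed for your platform (CPU or GPU), then install diff-integrator from PyPI:

pip install diff-integrator

or, for development, from a local clone of the repository:

pip install -e .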

🤝 Contributing

Contributions are welcome! Please run the test suite and make sure mypy type checking passes before submitting PRs:

pytest --cov=diff_integrator
mypy .

⚖️ License

MIT License — see LICENSE for details.



Download files


Source Distribution

diff_integrator-0.1.2.tar.gz (66.9 kB)

Uploaded Source

Built Distribution


diff_integrator-0.1.2-py3-none-any.whl (41.6 kB)

Uploaded Python 3

File details

Details for the file diff_integrator-0.1.2.tar.gz.

File metadata

  • Download URL: diff_integrator-0.1.2.tar.gz
  • Upload date:
  • Size: 66.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for diff_integrator-0.1.2.tar.gz
  • SHA256: 7aee3f088df071f174b8a542a6beccb5f69d30c804143a52e3bebd3356dcec94
  • MD5: 741aafa8addc20c46cf5e0d9a980d2e1
  • BLAKE2b-256: 36a1ceaf046f44bc356df90052862b004aa204e41f7bef5fbe9b730333c5a37f


Provenance

The following attestation bundles were made for diff_integrator-0.1.2.tar.gz:

Publisher: publish.yml on elkins-lab/diff-integrator

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file diff_integrator-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: diff_integrator-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 41.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for diff_integrator-0.1.2-py3-none-any.whl
  • SHA256: acec55545918e6e607358ebb37857b8875a2209891388fe0687a257029efb815
  • MD5: 77a54783504acda2b692a943fdc107e3
  • BLAKE2b-256: 001cd4cf7d87107569e71750c6cede795b818e45127cbbe6f871feb734e7af6a


Provenance

The following attestation bundles were made for diff_integrator-0.1.2-py3-none-any.whl:

Publisher: publish.yml on elkins-lab/diff-integrator

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
