
JAX implementation of Protein Strain Analysis (per-site weighted finite strain, PSA)


psax

JAX implementation of Protein Strain Analysis (PSA): per-site weighted finite strain and deformation gradients, consistent with the reference implementation Sartori-Lab/PSA and the original paper:

Pablo Sartori and Stanislas Leibler, Evolutionary conservation of mechanical strain distributions in functional transitions of protein structures, Phys. Rev. X 14, 011042 (2024). DOI: 10.1103/PhysRevX.14.011042 (APS: https://link.aps.org/doi/10.1103/PhysRevX.14.011042).

BibTeX and preprint links: references/sartori_leibler_2024_prx.md.

Install

pip install psax
# editable dev
pip install -e ".[dev]"
# optional: parity tests against upstream PSA (install the pinned git revision separately; it is not a PyPI extra)
pip install "psa @ git+https://github.com/Sartori-Lab/PSA.git@a55f44eea3c8165d618cfd607f1c3cebe7535cbb"

Requires Python 3.11+.

Structure I/O (proxide)

proxide is not on PyPI. To use the psax run command or run_pairwise_psa_from_structures, install proxide from a local checkout into the same environment:

pip install /path/to/proxide

If this repo lives in a uv workspace next to a local proxide member, add a workspace source mapping (see uv workspace sources):

[tool.uv.sources]
proxide = { workspace = true }

PyPI and releases

  • .github/workflows/ci.yml: tests, docs, upstream parity, and a wheel smoke test on main.
  • .github/workflows/publish.yml: builds with python -m build and uploads to PyPI on a GitHub Release (OIDC trusted publishing).

Configure trusted publishing on PyPI for this repository and workflow, and add a GitHub Environment named pypi if you use protection rules. Tag releases (e.g. v0.1.0) and publish a GitHub Release to trigger the workflow (or run it manually via workflow_dispatch).

StableHLO / jax.export (AOT)

For deployment and compiler toolchains, psax exposes a thin export layer in psax.stablehlo. Shapes are fixed and static (coordinates of shape (N, 3) and E directed edges, i.e. edge arrays of shape (E,)), and method, ridge_eps, and rcond are fixed at export time. Serialized artifacts use JAX's export format (VHLO / calling-convention version) and are tied to the JAX version you built with; see the JAX export documentation and the OpenXLA StableHLO + JAX tutorial.

CLI:

psax emit-stablehlo --n 128 --edges 4096 --out psa_strain.bin

(run_pairwise_psa is not the export entrypoint: it builds weights outside of jit.)
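The export pattern itself (trace once at fixed shapes, serialize, deserialize, call) can be sketched with plain jax.export. This is a minimal sketch, not psax.stablehlo: per_site_weight_sum is a hypothetical toy kernel standing in for the PSA strain pipeline, and N = 8, E = 24 are illustrative sizes.

```python
import jax
import jax.numpy as jnp
from jax import export

N, E = 8, 24  # static sizes baked into the artifact (illustrative values)

def per_site_weight_sum(x, edges_i, edges_j, w):
    # Hypothetical toy kernel: weighted bond-length sum per site.
    d = x[edges_j] - x[edges_i]                       # (E, 3) bond vectors
    contrib = w * jnp.linalg.norm(d, axis=-1)         # (E,)
    return jax.ops.segment_sum(contrib, edges_i, num_segments=N)  # (N,)

arg_specs = (
    jax.ShapeDtypeStruct((N, 3), jnp.float32),  # coordinates
    jax.ShapeDtypeStruct((E,), jnp.int32),      # edges_i
    jax.ShapeDtypeStruct((E,), jnp.int32),      # edges_j
    jax.ShapeDtypeStruct((E,), jnp.float32),    # edge weights
)
exported = export.export(jax.jit(per_site_weight_sum))(*arg_specs)
blob = exported.serialize()              # bytearray, tied to this JAX version
stablehlo_text = exported.mlir_module()  # human-readable StableHLO module

restored = export.deserialize(blob)      # later, possibly in another process
```

Calling restored.call(x, edges_i, edges_j, w) with arrays of exactly these shapes and dtypes reproduces the jitted function; any other N or E needs a new export.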

Batching ref / design pairs (vmap)

  • Aligned batch: B independent pairs (x_ref[b], x_def[b]) that share the same graph (shared edges_i, edges_j, edge_weights). Use psax.batch.deformation_gradient_per_site_vmap_pairs, or run_pairwise_psa_batched_shared_graph when the dense (N, N) weight mask is shared.
  • Ref × design grid (all combinations (b_ref, b_design) with fixed shared edges): psax.batch.deformation_gradient_per_site_grid_ref_design.
  • Per-sample edges (same batch size and the same edge count E for every sample): deformation_gradient_per_site_batched_edges / edge_batch_mode="per_pair".
  • Samples with different E need padding and masking, or a Python loop with psax.utils.safe_map (see the docs and the psax.bucket placeholder).
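The aligned-batch case boils down to vmapping a per-site kernel over the coordinate axis while holding the edge arrays fixed. A minimal sketch, assuming the per-site least-squares form F_i = A_i D_i^{-1} with D_i = Σ_j w_ij d0 d0ᵀ and A_i = Σ_j w_ij d d0ᵀ; edges_from_dense and per_site_F are hypothetical helpers, not the psax.batch API.

```python
import jax
import jax.numpy as jnp
import numpy as np

def edges_from_dense(W):
    # Outside jit: directed edges (i -> j) wherever W[i, j] != 0 and i != j.
    W = np.asarray(W)
    mask = (W != 0) & ~np.eye(W.shape[0], dtype=bool)
    i, j = np.nonzero(mask)
    return jnp.asarray(i), jnp.asarray(j), jnp.asarray(W[i, j])

def per_site_F(x_ref, x_def, ei, ej, w):
    # F_i minimises sum_j w_ij |d_ij - F_i d0_ij|^2, i.e. F_i = A_i D_i^{-1}.
    n = x_ref.shape[0]                     # static even under vmap
    d0 = x_ref[ej] - x_ref[ei]             # (E, 3) reference bond vectors
    d = x_def[ej] - x_def[ei]              # (E, 3) deformed bond vectors
    wD = w[:, None, None] * d0[:, :, None] * d0[:, None, :]
    wA = w[:, None, None] * d[:, :, None] * d0[:, None, :]
    D = jax.ops.segment_sum(wD, ei, num_segments=n)   # (N, 3, 3)
    A = jax.ops.segment_sum(wA, ei, num_segments=n)   # (N, 3, 3)
    # D is symmetric, so F^T = D^{-1} A^T.
    return jnp.swapaxes(jnp.linalg.solve(D, jnp.swapaxes(A, -1, -2)), -1, -2)

# B pairs sharing one graph: map over axis 0 of the coordinates only.
per_site_F_pairs = jax.vmap(per_site_F, in_axes=(0, 0, None, None, None))
```

An affine deformation x_def = G x_ref + t gives F_i = G at every site, which makes a convenient sanity check.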

Quickstart (two coordinate sets → per-site F → strain)

import jax.numpy as jnp
from psax.run.pipeline import run_pairwise_psa

x_ref = jnp.array(...)  # (N, 3)
x_def = jnp.array(...)  # (N, 3)
out = run_pairwise_psa(x_ref, x_def, r_inner=6.0, r_outer=8.0)
F = out.deformation_gradient       # (N, 3, 3)
E = out.green_lagrange_strain      # (N, 3, 3)
lam = out.principal_strain_eigenvalues  # (N, 3)
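The last two steps of that pipeline (F → Green–Lagrange strain → principal values) are short enough to write out. A minimal sketch of E = ½(FᵀF − I) and its eigensystem; the eigenvector sign rule below (largest-magnitude component positive) is an assumed convention for illustration and may differ from the one principal_strain_eigensystem uses.

```python
import jax.numpy as jnp

def green_lagrange(F):
    # E = 1/2 (F^T F - I), batched over leading axes of F with shape (..., 3, 3)
    C = jnp.swapaxes(F, -1, -2) @ F            # right Cauchy-Green tensor
    return 0.5 * (C - jnp.eye(3, dtype=F.dtype))

def principal_strains(E):
    # eigh returns ascending eigenvalues; eigenvectors are the columns of vecs
    lam, vecs = jnp.linalg.eigh(E)
    # Assumed deterministic sign rule: largest-magnitude component of each
    # eigenvector (column) is made positive.
    idx = jnp.argmax(jnp.abs(vecs), axis=-2, keepdims=True)   # (..., 1, 3)
    sign = jnp.sign(jnp.take_along_axis(vecs, idx, axis=-2))
    return lam, vecs * sign
```

A pure rotation gives E = 0, so a rigidly moving site reports zero strain; a stretch of 1.1 along x gives a principal value of ½(1.1² − 1) = 0.105.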

With structures on disk (requires proxide):

from psax.run.pipeline import run_pairwise_psa_from_structures

out = run_pairwise_psa_from_structures("ref.pdb", "def.pdb", align_kabsch=False)

CLI

psax --version   # or -V
psax version
psax run --help
psax emit-stablehlo --help
psax run ref.pdb def.pdb -o out.npz --bfactors-out colored.pdb --template-pdb ref.pdb
psax export template.pdb out.pdb --values "1.0,2.0,3.0"

The CLI dispatches to psax.run.pipeline, psax.stablehlo, and psax.io.export.
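psax export writes per-site values into the B-factor column of a template PDB without BioPython. The fixed-column PDB format makes that a plain string edit; the sketch below assumes one value per residue in template order, and write_bfactors is a hypothetical helper, not the psax.io.export API.

```python
def write_bfactors(pdb_lines, values):
    """Hypothetical helper: put one value per residue into the B-factor field.

    Assumes `values` follows the residue order of the template; this is an
    illustration only, not psax.io.export.
    """
    out, residue_index, last_key = [], -1, None
    for line in pdb_lines:
        line = line.rstrip("\n")
        if line.startswith(("ATOM  ", "HETATM")):
            key = line[21:27]  # chain ID, residue number, insertion code (cols 22-27)
            if key != last_key:
                residue_index, last_key = residue_index + 1, key
            line = line.ljust(66)
            # tempFactor occupies columns 61-66, conventionally written as %6.2f
            line = line[:60] + f"{values[residue_index]:6.2f}" + line[66:]
        out.append(line)
    return out
```

Every atom of a residue receives the same value, which is how per-residue strain usually shows up when coloring by B-factor in a molecular viewer.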

Parity & limitations

What is validated, and how:
  • JAX vs the in-repo NumPy PSA loop: tests/test_core_parity.py (marker parity_numpy).
  • JAX vs the installed psa.elastic.deformation_gradient (dense weights, fp64): tests/parity_upstream/ (marker parity_upstream); requires psa installed from git (see Install).
  • Symmetrized D + solvers vs the NumPy reference: covered indirectly, because the ridge_eps=0 path matches the upstream dense path.

Not fully cross-checked here: the upstream sparse/Numba dict fast paths, the energy and rotation pipelines in PSA, and every PSA strain helper name-for-name. Only the dense deformation-gradient path and the Green–Lagrange strain built from F are validated against upstream.

See tests/parity_upstream/README.md for dense vs sparse scope.
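The parity tests follow a simple pattern: an explicit per-edge loop serves as ground truth for a vectorized scatter-add. A minimal NumPy-only sketch of that pattern for the per-site matrix D_i = Σ_j w_ij d0 d0ᵀ; D_loop and D_vectorized are hypothetical names, not the repo's test code.

```python
import numpy as np

def D_loop(x_ref, ei, ej, w, n):
    # Reference: explicit Python loop over directed edges (slow but obvious).
    D = np.zeros((n, 3, 3))
    for i, j, wij in zip(ei, ej, w):
        d0 = x_ref[j] - x_ref[i]
        D[i] += wij * np.outer(d0, d0)
    return D

def D_vectorized(x_ref, ei, ej, w, n):
    # Same sum via an unbuffered scatter-add over the edge axis.
    d0 = x_ref[ej] - x_ref[ei]
    D = np.zeros((n, 3, 3))
    np.add.at(D, ei, w[:, None, None] * np.einsum("ea,eb->eab", d0, d0))
    return D
```

Running both in float64 and comparing with tight tolerances mirrors why the upstream checks enable fp64.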

Documentation (Sphinx)

pip install -e ".[docs]"
sphinx-build -b html docs docs/_build/html

Testing & coverage

pytest
pytest --cov=psax --cov-report=term-missing

Markers: unit, parity_numpy, parity_upstream, structure_io, slow (see pyproject.toml).

API sketch

  • Build a directed edge list from a dense weight matrix outside jit (psax.graph.edges_from_dense_weight_matrix).
  • deformation_gradient_per_site: computes \(\mathbf{F}_i\) from weighted bond vectors with jax.ops.segment_sum; method selects solve, lstsq, or svd, and the optional return_diagnostics returns condition numbers.
  • Strain / kinematics (psax.core): Green–Lagrange and small-strain Lagrange, Euler strains, Cauchy–Green invariants, principal stretches, rotation axis/angle (polar decomposition).
  • Energy (psax.core.energy): Saint Venant–Kirchhoff density and unit helpers.
  • Alignment (psax.alignment): vendored soft Smith–Waterman / Needleman–Wunsch (sequence), Kabsch rigid superposition (structure).
  • Spatial (psax.spatial): cylindrical/spherical coordinates, tensor rotation into cylindrical frame, rough effective volumes.
  • I/O (psax.io): load_structure via proxide (optional), B-factor export without BioPython.
  • AOT (psax.stablehlo): jax.export of the PSA strain pipeline at fixed shapes.
  • Synthetic tests (psax.testing.forms): rods, twist/spin/radial deformations.
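The rotation axis/angle helpers rest on the polar decomposition F = RU. A minimal NumPy sketch via the SVD, assuming det F > 0 (no reflections); polar_rotation and rotation_axis_angle are hypothetical names, not the psax.core API.

```python
import numpy as np

def polar_rotation(F):
    # Right polar decomposition F = R U from the SVD F = W S V^T:
    # R = W V^T (rotation), U = V S V^T (symmetric stretch). Assumes det F > 0.
    W, S, Vt = np.linalg.svd(F)
    R = W @ Vt
    U = Vt.T @ np.diag(S) @ Vt
    return R, U

def rotation_axis_angle(R):
    # Angle from the trace; axis from the skew-symmetric part of R.
    angle = np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0))
    axis = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    norm = np.linalg.norm(axis)
    # The axis is ill-defined near angle 0 or pi; this sketch falls back to z.
    return (axis / norm if norm > 1e-12 else np.array([0.0, 0.0, 1.0])), angle
```

For F = R_z(0.3) · diag(1.2, 1.0, 0.8) this recovers the z axis, an angle of 0.3 rad, and U = diag(1.2, 1.0, 0.8).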

The legacy global least-squares \(\mathbf{F}\) (from the prxteinmpnn snapshot) lives under old_psa for regression tests only; it is not the per-site PSA method published in Phys. Rev. X 14, 011042 (2024).
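For contrast with the per-site method, a single global F fits one linear map to all coordinates at once. A minimal sketch of that kind of fit; centering both structures is an assumption here, and global_F is a hypothetical name that may not match the old_psa code.

```python
import numpy as np

def global_F(x_ref, x_def):
    # One F for the whole structure: minimise sum_k |y_k - F x_k|^2
    # over centred coordinates (centring is an assumption of this sketch).
    X = x_ref - x_ref.mean(axis=0)
    Y = x_def - x_def.mean(axis=0)
    # Normal equations F (X^T X) = Y^T X, solved for F^T then transposed.
    return np.linalg.solve(X.T @ X, X.T @ Y).T
```

Because every site shares one F, a hinge motion that is rigid within each domain still produces a nonzero global "strain", which the per-site method avoids.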

JAX notes

  • safe_map (psax.utils.safe_map): when batching many independent structures, pass a Python int batch_size into safe_map (not a traced value). Inside jit, prefer fixed-shape kernels and bucketing (see psax.bucket) instead of compiling over Python loops.
  • Float precision: enable jax_enable_x64 if you need to match double-precision NumPy/SciPy baselines (tests set this in conftest.py).
  • Eigenvectors: principal_strain_eigensystem fixes eigenvector signs deterministically (no random hemisphere flips).
  • RNG: core PSA is deterministic; if you add stochastic workflows, derive subkeys with jax.random.fold_in(key, index).
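The batching and RNG notes can be illustrated together. A minimal sketch of a chunked map with a Python int batch_size plus fold_in-derived subkeys; chunked_vmap is a hypothetical stand-in, not the signature of psax.utils.safe_map.

```python
import jax
import jax.numpy as jnp

def chunked_vmap(fn, xs, batch_size: int):
    # batch_size is a plain Python int, so chunk boundaries are static and each
    # chunk shape compiles once (a shorter final chunk triggers one extra compile).
    vfn = jax.jit(jax.vmap(fn))
    chunks = [vfn(xs[start:start + batch_size])
              for start in range(0, xs.shape[0], batch_size)]
    return jnp.concatenate(chunks, axis=0)

# Deterministic per-item subkeys for any stochastic add-on workflow.
key = jax.random.PRNGKey(0)
subkeys = [jax.random.fold_in(key, i) for i in range(4)]
```

Using fold_in with the item index means item i always gets the same subkey, regardless of how the items are chunked.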

Prior art

The old_psa package is adapted from prxteinmpnn's PSA helpers (global \(\mathbf{F}\) and dense linear weights). The per-site method implemented in psax.core follows Sartori-Lab/PSA and Sartori & Leibler, Phys. Rev. X 14, 011042 (2024).

Citation

If you use this software in published work, please cite the PSA paper (and this codebase if appropriate):

@article{PhysRevX.14.011042,
  title = {Evolutionary Conservation of Mechanical Strain Distributions in Functional Transitions of Protein Structures},
  author = {Sartori, Pablo and Leibler, Stanislas},
  journal = {Phys. Rev. X},
  volume = {14},
  issue = {1},
  pages = {011042},
  year = {2024},
  publisher = {American Physical Society},
  doi = {10.1103/PhysRevX.14.011042},
  url = {https://link.aps.org/doi/10.1103/PhysRevX.14.011042}
}

Reference implementation: https://github.com/Sartori-Lab/PSA.
More detail: references/sartori_leibler_2024_prx.md.



Download files


Source Distribution

psax-0.1.1.tar.gz (43.0 kB)

Uploaded Source

Built Distribution


psax-0.1.1-py3-none-any.whl (42.7 kB)

Uploaded Python 3

File details

Details for the file psax-0.1.1.tar.gz.

File metadata

  • Download URL: psax-0.1.1.tar.gz
  • Size: 43.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for psax-0.1.1.tar.gz
Algorithm Hash digest
SHA256 deb8f090a8bccdd25ec8b1699ebee5dd3082392e8b0018e89ddad0d906f06f09
MD5 6d518c76e1e2d51c3c00f29de68dd87d
BLAKE2b-256 98d66b55091c698132ba44e50b485fb88146b1549343fd1d2ded044093dddcca


Provenance

The following attestation bundles were made for psax-0.1.1.tar.gz:

Publisher: publish.yml on maraxen/psax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file psax-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: psax-0.1.1-py3-none-any.whl
  • Size: 42.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for psax-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 8c648df25fc8cac619de1cbd053f1b44e8963b50ca6475381b7e28b0b3d5166f
MD5 97c503d41dc12e6b29b9f317d5970021
BLAKE2b-256 43a3c49ff151bf76c77369a9cf9dd7d7007fada1dbc9b096b3656f327d93b918


Provenance

The following attestation bundles were made for psax-0.1.1-py3-none-any.whl:

Publisher: publish.yml on maraxen/psax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
