Cryo-EM image simulation and analysis powered by JAX
cryoJAX
Summary
CryoJAX is a library that simulates cryo-electron microscopy (cryo-EM) images in JAX. Its purpose is to provide the tools for building downstream data analysis in external workflows and libraries that leverage the statistical inference and machine learning resources of the JAX scientific computing ecosystem. To achieve this, image simulation in cryoJAX is built for reliability and flexibility; it implements a variety of established models and algorithms as well as a framework for implementing new models and algorithms downstream. If your application uses cryo-EM image simulation and it cannot be built downstream, open a pull request.
Documentation
See the documentation at https://michael-0brien.github.io/cryojax/. It is a work-in-progress, so thank you for your patience!
Installation
Installing cryojax is simple. To start, I recommend creating a new virtual environment. For example, you could do this with conda.
conda create -n cryojax-env -c conda-forge python=3.11
conda activate cryojax-env
Note that python>=3.10 is required. After creating a new environment, install JAX with either CPU or GPU support. Then, install cryojax. For the latest stable release, install using pip.
python -m pip install cryojax
To install the latest commit, you can build the repository directly.
git clone https://github.com/michael-0brien/cryojax
cd cryojax
python -m pip install .
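Whichever installation route you use, it can help to confirm that JAX sees the hardware you expect before going further. The following is a minimal sketch that uses only standard JAX calls (`jax.default_backend`, `jax.devices`). It does not touch cryojax itself.

```python
import jax
import jax.numpy as jnp

# Report the default backend (for example "cpu" or "gpu") and the visible devices
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX {jax.__version__} using backend: {backend}")
print("Devices:", devices)

# Run a tiny computation to confirm that the backend actually executes code
x = jnp.arange(4.0)
result = float(jnp.dot(x, x))  # 0 + 1 + 4 + 9
print("x . x =", result)
```

If you installed JAX with GPU support but `backend` reports `"cpu"`, revisit the JAX installation step before installing cryojax.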
The jax-finufft package is an optional dependency used for non-uniform fast Fourier transforms, which are one of the available options for computing image projections. If you plan to use it, we recommend first following the jax-finufft installation instructions and then installing cryojax.
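Why can Fourier transforms compute projections at all? The sketch below checks the discrete Fourier slice theorem in plain JAX: the 2D DFT of a volume summed along one axis equals the central slice of the volume's 3D DFT. It uses a random volume and an axis-aligned projection only, so it illustrates the principle rather than reproducing cryojax's projection code. For arbitrary orientations the slice no longer falls on grid points, which is where interpolation or non-uniform FFTs come in.

```python
import jax
import jax.numpy as jnp

# A small random volume standing in for a voxel grid (axes ordered z, y, x)
key = jax.random.PRNGKey(0)
volume = jax.random.normal(key, (16, 16, 16))

# Real-space projection: integrate (sum) the volume along the z axis
projection = volume.sum(axis=0)

# Fourier space: take the 3D DFT and extract the central kz = 0 slice
central_slice = jnp.fft.fftn(volume)[0, :, :]

# The Fourier slice theorem says these two agree (up to float32 rounding)
projection_ft = jnp.fft.fft2(projection)
max_error = float(jnp.max(jnp.abs(projection_ft - central_slice)))
print("max |FFT2(projection) - central slice| =", max_error)
```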
Simulating an image
The following is a basic workflow to simulate an image.
import jax
import jax.numpy as jnp
import cryojax.simulator as cxs
from cryojax.io import read_array_from_mrc
# Instantiate the voxel grid representation of a volume. See the documentation
# for how to generate voxel grids from a PDB file.
filename = "example_volume.mrc"
real_voxel_grid, voxel_size = read_array_from_mrc(filename, loads_grid_spacing=True)
volume = cxs.FourierVoxelGridVolume.from_real_voxel_grid(real_voxel_grid)
# The pose. Angles are given in degrees.
pose = cxs.EulerAnglePose(
    offset_x_in_angstroms=5.0,
    offset_y_in_angstroms=-3.0,
    phi_angle=20.0,
    theta_angle=80.0,
    psi_angle=-10.0,
)
# The model for the CTF
ctf = cxs.AstigmaticCTF(
    defocus_in_angstroms=9800.0, astigmatism_in_angstroms=200.0, astigmatism_angle=10.0
)
transfer_theory = cxs.ContrastTransferTheory(ctf, amplitude_contrast_ratio=0.1)
# The image configuration
image_config = cxs.BasicImageConfig(
    shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0
)
# Instantiate a cryoJAX `image_model` using the `make_image_model` function
image_model = cxs.make_image_model(volume, image_config, pose, transfer_theory)
# Simulate an image
image = image_model.simulate(outputs_real_space=True)
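The pose above is specified with Euler angles in degrees. To build intuition for what three such angles encode, here is a plain-JAX sketch that converts them into a 3x3 rotation matrix. It assumes an intrinsic z-y-z convention composed as Rz(phi) Ry(theta) Rz(psi). This is one common choice in cryo-EM, but cryojax's exact convention (including whether rotations are active or passive) should be checked in the documentation. `euler_zyz_to_matrix`, `rot_z`, and `rot_y` are hypothetical helpers, not cryojax functions.

```python
import jax.numpy as jnp


def rot_z(angle):
    """Rotation by `angle` radians about the z axis."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(angle):
    """Rotation by `angle` radians about the y axis."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def euler_zyz_to_matrix(phi_deg, theta_deg, psi_deg):
    """Hypothetical helper: intrinsic z-y-z Euler angles (degrees) to a matrix."""
    phi, theta, psi = jnp.deg2rad(jnp.array([phi_deg, theta_deg, psi_deg]))
    # Intrinsic rotations about z, then the new y, then the new z compose left to right
    return rot_z(phi) @ rot_y(theta) @ rot_z(psi)


# Same angles as the `EulerAnglePose` in the example above
R = euler_zyz_to_matrix(20.0, 80.0, -10.0)
print(R)
```

A quick sanity check is that `R` is orthonormal with determinant 1, and that with `theta = 0` the two z rotations merge into a single rotation by `phi + psi`.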
For more advanced image simulation examples and to understand the many features in this library, see the documentation.
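For intuition about the CTF parameters in the example, the sketch below evaluates a textbook astigmatic CTF on an FFT frequency grid in plain JAX, reusing the defocus, astigmatism, amplitude contrast ratio, and voltage values from above. It is not cryojax's implementation, and several inputs are assumptions: `defocus` is treated as the mean defocus, `astigmatism` as the difference between the major and minor defocus, the spherical aberration is set to 2.7 mm, the pixel size is taken as 1 Å, and the phase convention follows the common form CTF = -sin(chi + arcsin(A)). `astigmatic_ctf_sketch` and `electron_wavelength_angstroms` are hypothetical helpers.

```python
import jax.numpy as jnp


def electron_wavelength_angstroms(voltage_kv):
    """Relativistic electron wavelength (Angstroms) for a voltage in kilovolts."""
    volts = voltage_kv * 1e3
    return 12.2643247 / jnp.sqrt(volts * (1.0 + 0.978466e-6 * volts))


def astigmatic_ctf_sketch(
    shape,
    pixel_size,
    defocus,
    astigmatism,
    astigmatism_angle_deg,
    voltage_kv=300.0,
    spherical_aberration_mm=2.7,  # assumed value, not taken from the README
    amplitude_contrast_ratio=0.1,
):
    """Hypothetical textbook CTF, evaluated on an unshifted FFT frequency grid."""
    ny, nx = shape
    ky = jnp.fft.fftfreq(ny, d=pixel_size)  # spatial frequencies in 1/Angstrom
    kx = jnp.fft.fftfreq(nx, d=pixel_size)
    k_y, k_x = jnp.meshgrid(ky, kx, indexing="ij")
    k_squared = k_x**2 + k_y**2
    azimuth = jnp.arctan2(k_y, k_x)
    # Defocus oscillates around its mean value as a function of azimuth
    astig_angle = jnp.deg2rad(astigmatism_angle_deg)
    local_defocus = defocus + 0.5 * astigmatism * jnp.cos(2.0 * (azimuth - astig_angle))
    wavelength = electron_wavelength_angstroms(voltage_kv)
    cs = spherical_aberration_mm * 1e7  # millimeters -> Angstroms
    chi = (
        jnp.pi * wavelength * local_defocus * k_squared
        - 0.5 * jnp.pi * cs * wavelength**3 * k_squared**2
    )
    # Amplitude contrast enters as a constant phase shift
    return -jnp.sin(chi + jnp.arcsin(amplitude_contrast_ratio))


wavelength_300kv = electron_wavelength_angstroms(300.0)  # about 0.0197 Angstroms
ctf = astigmatic_ctf_sketch(
    (320, 320), pixel_size=1.0, defocus=9800.0, astigmatism=200.0, astigmatism_angle_deg=10.0
)
```

At zero frequency (index `[0, 0]` of the unshifted grid) the phase `chi` vanishes, so the sketch returns `-0.1`: only the amplitude-contrast term survives.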
JAX transformations
CryoJAX is built on JAX to make use of JIT-compilation, automatic differentiation, and vectorization for cryo-EM data analysis. JAX implements these operations as function transformations. If you aren't familiar with this concept, see the JAX documentation.
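As a minimal illustration in plain JAX (independent of cryojax), each of `jax.jit`, `jax.grad`, and `jax.vmap` takes a Python function and returns a new function:

```python
import jax
import jax.numpy as jnp


def loss(x):
    """A toy scalar function of a vector."""
    return jnp.sum(jnp.sin(x) ** 2)


fast_loss = jax.jit(loss)  # compiled version of the same function
grad_loss = jax.grad(loss)  # returns d(loss)/dx, which here equals sin(2x)
batched_loss = jax.vmap(loss)  # maps `loss` over a leading batch axis

x = jnp.array([0.0, 0.5, 1.0])
value = fast_loss(x)
gradient = grad_loss(x)
batched = batched_loss(jnp.stack([x, 2.0 * x]))  # one loss value per row
print(value, gradient, batched.shape)
```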
Below are examples of these transformations using equinox, a popular JAX library for PyTorch-like classes that integrate smoothly with JAX functional programming. To learn more about how equinox assists with JAX transformations, see the equinox documentation.
Your first JIT compiled function
import equinox as eqx
# Define image simulation function using `equinox.filter_jit`
@eqx.filter_jit
def simulate_fn(image_model):
    """Simulate an image with JIT compilation."""
    return image_model.simulate()
# Simulate an image
image = simulate_fn(image_model)
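Per the equinox documentation, `eqx.filter_jit` traces JAX/NumPy arrays and treats every other argument as static. Calling `simulate_fn` again with arrays of the same shape and dtype therefore reuses the compiled code, while new shapes or changed static fields trigger recompilation. The plain-JAX sketch below (using `jax.jit` and a hypothetical toy function in place of `image_model.simulate`) makes the tracing step visible by counting how often the Python body runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def simulate_toy(image):
    """Hypothetical stand-in for `image_model.simulate()`: an FFT round trip."""
    global trace_count
    trace_count += 1  # Python side effects run only while JAX traces the function
    return jnp.fft.ifft2(jnp.fft.fft2(image)).real


simulate_toy(jnp.ones((32, 32)))  # first call: trace and compile
simulate_toy(jnp.zeros((32, 32)))  # same shape and dtype: reuse compiled code
count_after_same_shape = trace_count  # 1
simulate_toy(jnp.ones((64, 64)))  # new shape: trace and compile again
print(count_after_same_shape, trace_count)  # 1 2
```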
Computing gradients of a loss function
import equinox as eqx
import jax
import jax.numpy as jnp
# Load observed data
observed_image = ...
# Split the `image_model` by differentiated and non-differentiated
# arguments. Here, differentiate with respect to the pose.
is_pose = lambda x: isinstance(x, cxs.AbstractPose)
filter_spec = jax.tree.map(is_pose, image_model, is_leaf=is_pose)
model_grad, model_nograd = eqx.partition(image_model, filter_spec)
@eqx.filter_grad
def gradient_fn(model_grad, model_nograd, observed_image):
    """Compute gradients with respect to the pose."""
    image_model = eqx.combine(model_grad, model_nograd)
    return jnp.sum((image_model.simulate() - observed_image) ** 2)
# Compute gradients
gradients = gradient_fn(model_grad, model_nograd, observed_image)
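The same idea, differentiating with respect to a chosen subset of parameters while holding the rest fixed, can be seen in a toy problem without cryojax. In this sketch, a hypothetical 1D Gaussian "image" stands in for `image_model.simulate()`, and `jax.grad(..., argnums=0)` plays the role of `eqx.partition`: only `offset` is differentiated. A few steps of plain gradient descent (with a step size chosen for this toy problem) then recover the offset used to generate synthetic data.

```python
import jax
import jax.numpy as jnp

grid = jnp.linspace(-10.0, 10.0, 201)


def simulate_toy(offset, width):
    """Hypothetical 1D 'image': a Gaussian blob centered at `offset`."""
    return jnp.exp(-0.5 * ((grid - offset) / width) ** 2)


def loss_fn(offset, width, observed):
    """Sum of squared residuals, as in the README's `gradient_fn`."""
    return jnp.sum((simulate_toy(offset, width) - observed) ** 2)


# Differentiate with respect to `offset` only; `width` and `observed` are held
# fixed, playing the role of the non-differentiated partition
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0))

width = 1.5
observed = simulate_toy(2.0, width)  # synthetic "data" with true offset 2.0

offset = 1.0  # initial guess
for _ in range(200):
    offset = offset - 0.05 * grad_fn(offset, width, observed)
print(float(offset))  # close to 2.0
```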
Vectorizing image simulation
import equinox as eqx
# Vectorize model instantiation
@eqx.filter_vmap(in_axes=(0, None, None, None), out_axes=(eqx.if_array(0), None))
def make_image_model_vmap(wxyz, volume, image_config, transfer_theory):
    pose = cxs.QuaternionPose(wxyz=wxyz)
    image_model = cxs.make_image_model(
        volume, image_config, pose, transfer_theory, normalizes_signal=True
    )
    is_pose = lambda x: isinstance(x, cxs.AbstractPose)
    filter_spec = jax.tree.map(is_pose, image_model, is_leaf=is_pose)
    model_vmap, model_novmap = eqx.partition(image_model, filter_spec)
    return model_vmap, model_novmap
# Define image simulation function
@eqx.filter_vmap(in_axes=(eqx.if_array(0), None))
def simulate_fn_vmap(model_vmap, model_novmap):
    image_model = eqx.combine(model_vmap, model_novmap)
    return image_model.simulate()
# Batch image simulation over poses
wxyz = ... # ... load quaternions
model_vmap, model_novmap = make_image_model_vmap(wxyz, volume, image_config, transfer_theory)
images = simulate_fn_vmap(model_vmap, model_novmap)
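For intuition about what the batch of `wxyz` quaternions represents, the plain-JAX sketch below converts each quaternion into a rotation matrix and uses `jax.vmap` to map that single-quaternion function over a leading batch axis, the same mechanism `eqx.filter_vmap` builds on. It assumes scalar-first (w, x, y, z) unit quaternions, as the argument name suggests, and the standard Hamilton active-rotation formula; cryojax's own conventions may differ. `quaternion_to_matrix` is a hypothetical helper, and normalized Gaussian samples serve as stand-in data because they are uniformly distributed over rotations.

```python
import jax
import jax.numpy as jnp


def quaternion_to_matrix(wxyz):
    """Rotation matrix from a scalar-first (w, x, y, z) quaternion."""
    w, x, y, z = wxyz / jnp.linalg.norm(wxyz)  # normalize defensively
    return jnp.array([
        [1 - 2 * (y**2 + z**2), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x**2 + z**2), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x**2 + y**2)],
    ])


# Random unit quaternions standing in for the README's `wxyz = ...`
key = jax.random.PRNGKey(42)
wxyz = jax.random.normal(key, (8, 4))
wxyz = wxyz / jnp.linalg.norm(wxyz, axis=-1, keepdims=True)

# vmap maps the single-quaternion function over the leading batch axis
rotations = jax.vmap(quaternion_to_matrix)(wxyz)
print(rotations.shape)  # (8, 3, 3)
```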
Acknowledgements
cryojax's implementations of several models and algorithms, such as the CTF, Fourier slice extraction, and electrostatic potential computations, have been informed by the open-source cryo-EM software cisTEM. cryojax is built using equinox, a popular JAX library for PyTorch-like classes that integrate smoothly with JAX functional programming. We highly recommend learning about equinox to make full use of the power of jax.
Download files
File details
Details for the file cryojax-0.5.0rc4.tar.gz.
File metadata
- Download URL: cryojax-0.5.0rc4.tar.gz
- Size: 3.9 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 08af30ccf6660e9a8a35dadff49b0c7a9bc2a49b04928fb31881ee049be8dff7 |
| MD5 | 5ce7a754ad9f7499658f823572b1473c |
| BLAKE2b-256 | 72b8437efe82d9cd92e0f99bc8994196541f7ae9b94efe24cc4b9c2eaced8dc8 |
Provenance
The following attestation bundles were made for cryojax-0.5.0rc4.tar.gz:
Publisher: release.yml on michael-0brien/cryojax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cryojax-0.5.0rc4.tar.gz
- Subject digest: 08af30ccf6660e9a8a35dadff49b0c7a9bc2a49b04928fb31881ee049be8dff7
- Sigstore transparency entry: 590188410
- Permalink: michael-0brien/cryojax@98c017fc405ec3faaa64cb4bb36346ae15834148
- Branch / Tag: refs/heads/main
- Owner: https://github.com/michael-0brien
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@98c017fc405ec3faaa64cb4bb36346ae15834148
- Trigger Event: workflow_dispatch
File details
Details for the file cryojax-0.5.0rc4-py3-none-any.whl.
File metadata
- Download URL: cryojax-0.5.0rc4-py3-none-any.whl
- Size: 176.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ee9273ea05fe8506fb1607d5f93c4e31539c4a930000ae106fdc62abc32f241e |
| MD5 | cb8d3c659a24bea1f328a49dcbc0e5d8 |
| BLAKE2b-256 | 8b780c73ba3a7929671dc8bb17d82f6ffb91a56c423d268986048ca9cf3021cc |
Provenance
The following attestation bundles were made for cryojax-0.5.0rc4-py3-none-any.whl:
Publisher: release.yml on michael-0brien/cryojax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cryojax-0.5.0rc4-py3-none-any.whl
- Subject digest: ee9273ea05fe8506fb1607d5f93c4e31539c4a930000ae106fdc62abc32f241e
- Sigstore transparency entry: 590188594
- Permalink: michael-0brien/cryojax@98c017fc405ec3faaa64cb4bb36346ae15834148
- Branch / Tag: refs/heads/main
- Owner: https://github.com/michael-0brien
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@98c017fc405ec3faaa64cb4bb36346ae15834148
- Trigger Event: workflow_dispatch