Cryo-EM image simulation and analysis powered by JAX

cryoJAX

cryoJAX is a library that provides tools for simulating and analyzing cryo-electron microscopy (cryo-EM) images. It is built on jax.

Summary

Specifically, cryoJAX aims to provide three things in the cryo-EM image-to-structure pipeline:

  1. Physical modeling of image formation
  2. Statistical modeling of the distributions from which images are drawn
  3. Easy-to-use utilities for working with real data

With these tools, cryojax aims to appeal to two different communities. Experimentalists can use cryojax to push the boundaries of what they can extract from their data by interfacing with the JAX scientific computing ecosystem. Method developers, in turn, can use cryojax as a backend for algorithmic research, such as cryo-EM structure determination. Both aims are possible because cryojax is written to be fully interoperable with the rest of the JAX ecosystem.
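To illustrate what interoperability with JAX means in practice, here is a minimal sketch that uses only JAX itself. The hypothetical `toy_image` function (a Gaussian blob, not part of cryojax) stands in for a simulator: it is vectorized over a batch of offsets with `jax.vmap`, and the hypothetical `loss` is differentiated with `jax.grad` and compiled with `jax.jit`. These transformations apply to any pure JAX function of arrays, which is what lets a JAX-based simulator plug into optimizers and samplers from the wider ecosystem.

```python
import jax
import jax.numpy as jnp


def toy_image(offset_xy, shape=(32, 32), pixel_size=1.0, width=2.0):
    """Hypothetical stand-in for a simulator: a 2D Gaussian blob centred at offset_xy."""
    ny, nx = shape
    # Pixel coordinates (same units as offset_xy), with the origin at the grid centre
    y = (jnp.arange(ny) - ny // 2) * pixel_size
    x = (jnp.arange(nx) - nx // 2) * pixel_size
    yy, xx = jnp.meshgrid(y, x, indexing="ij")
    r_squared = (xx - offset_xy[0]) ** 2 + (yy - offset_xy[1]) ** 2
    return jnp.exp(-r_squared / (2.0 * width**2))


def loss(offset_xy, observed):
    """Sum of squared residuals between a simulated and an observed image."""
    return jnp.sum((toy_image(offset_xy) - observed) ** 2)


# Synthetic "data" rendered at a known offset
true_offset = jnp.array([1.5, -0.5])
observed = toy_image(true_offset)

# vmap: render a whole batch of offsets in one call
batch_of_offsets = jnp.array([[0.0, 0.0], [1.5, -0.5], [3.0, 2.0]])
images = jax.vmap(toy_image)(batch_of_offsets)  # shape (3, 32, 32)

# grad + jit: a compiled gradient of the loss with respect to the offset
grad_fn = jax.jit(jax.grad(loss))
gradient_at_guess = grad_fn(jnp.array([0.0, 0.0]), observed)
gradient_at_truth = grad_fn(true_offset, observed)  # close to zero at the true offset
```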

Dig a little deeper and you'll find that cryojax aims to be a fully extensible modeling language for cryo-EM image formation. It implements a collection of abstract interfaces that aim to be general enough to support any level of modeling complexity, from simple linear image formation to the most realistic physical models in the field. Best of all, these interfaces are all part of the public API, so users can create their own extensions to cryojax tailored to their specific use case!

Documentation

See the documentation at https://mjo22.github.io/cryojax/. It is a work-in-progress, so thank you for your patience!

Installation

Installing cryojax is simple. To start, I recommend creating a new virtual environment. For example, you could do this with conda.

conda create -n cryojax-env -c conda-forge python=3.11

Note that Python >= 3.10 is required. After creating and activating the new environment, install JAX with either CPU or GPU support, following the JAX installation instructions. Then, install cryojax. For the latest stable release, install using pip.

python -m pip install cryojax

To install the latest commit, you can build the repository directly.

git clone https://github.com/mjo22/cryojax
cd cryojax
python -m pip install .

The jax-finufft package is an optional dependency used for non-uniform fast Fourier transforms (NUFFTs). These provide an alternative method for computing image projections of real-space, voxel-based scattering potential representations. If you need this feature, we recommend first following the jax-finufft installation instructions and then installing cryojax.
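For intuition about what a NUFFT accelerates: a type-1 transform maps strengths c_j at non-uniform points x_j onto a uniform grid of Fourier modes, f_k = sum_j c_j exp(-i k x_j). The direct sum below costs O(NM) for N modes and M points, whereas a NUFFT approximates the same sum to a chosen tolerance in near-linear time. This is a reference sketch in plain jax.numpy, not the jax-finufft API; the function name `direct_nudft_type1` is hypothetical, and sign and normalization conventions differ between libraries.

```python
import jax.numpy as jnp


def direct_nudft_type1(x, c, n_modes):
    """Direct type-1 non-uniform DFT: f[k] = sum_j c[j] * exp(-1j * k * x[j]).

    x: real points in radians (e.g. in [0, 2*pi)); c: complex strengths at those points.
    Modes k = 0, ..., n_modes - 1, matching the ordering of jnp.fft.fft output.
    Cost is O(n_modes * len(x)); a NUFFT approximates the same sum much faster.
    """
    k = jnp.arange(n_modes)
    # One phase factor for every (mode, point) pair
    phases = jnp.exp(-1j * jnp.outer(k, x))
    return phases @ c


# Non-uniform sample points with unit strengths
x = jnp.array([0.1, 1.3, 2.9, 4.4, 5.8])
c = jnp.ones(5, dtype=jnp.complex64)
f = direct_nudft_type1(x, c, n_modes=8)  # 8 Fourier modes
```

On a uniform grid, x_j = 2*pi*j/n with n_modes = n, the sum reduces to the ordinary DFT, which makes a convenient correctness check against `jnp.fft.fft`.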

Simulating an image

The following is a basic workflow to simulate an image.

First, instantiate the scattering potential, its pose, and the structural ensemble that combines them.

import jax
import jax.numpy as jnp
import cryojax.simulator as cxs
from cryojax.io import read_array_with_spacing_from_mrc

# Instantiate the scattering potential
filename = "example_scattering_potential.mrc"
real_voxel_grid, voxel_size = read_array_with_spacing_from_mrc(filename)
potential = cxs.FourierVoxelGridPotential.from_real_voxel_grid(real_voxel_grid, voxel_size)
# ... now, instantiate the pose. Angles are given in degrees
pose = cxs.EulerAnglePose(
    offset_x_in_angstroms=5.0,
    offset_y_in_angstroms=-3.0,
    view_phi=20.0,
    view_theta=80.0,
    view_psi=-10.0,
)
# ... now, build the ensemble. In this case, the ensemble is just a single structure
structural_ensemble = cxs.SingleStructureEnsemble(potential, pose)

Here, the 3D scattering potential array is read from filename (see the documentation for an example of how to generate the potential). The scattering potential is then loaded in Fourier space into a FourierVoxelGridPotential. Next, the representation of the biological specimen (the structural ensemble) is instantiated; in general, this includes a pose and a model of conformational heterogeneity. Here, the SingleStructureEnsemble class takes a pose but has no heterogeneity.
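The pose angles are Euler angles given in degrees. Below is a minimal sketch of turning such a triple into a 3x3 rotation matrix, assuming a ZYZ convention with R = Rz(phi) Ry(theta) Rz(psi). This is a common choice in cryo-EM, but whether cryojax uses exactly this convention (and whether it rotates the object or the frame) is an assumption here, so check the documentation. The helper names are hypothetical.

```python
import jax.numpy as jnp


def rotation_z(angle):
    """Right-handed rotation by `angle` radians about the z axis."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotation_y(angle):
    """Right-handed rotation by `angle` radians about the y axis."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def euler_zyz_to_matrix(phi_deg, theta_deg, psi_deg):
    """Assumed ZYZ convention: R = Rz(phi) @ Ry(theta) @ Rz(psi), angles in degrees."""
    phi, theta, psi = jnp.deg2rad(jnp.array([phi_deg, theta_deg, psi_deg]))
    return rotation_z(phi) @ rotation_y(theta) @ rotation_z(psi)


# Same angles as the EulerAnglePose in the example above
rotation = euler_zyz_to_matrix(20.0, 80.0, -10.0)
# Where the rotation sends the z axis: (cos(phi) sin(theta), sin(phi) sin(theta), cos(theta))
view_direction = rotation @ jnp.array([0.0, 0.0, 1.0])
```

With theta = 0 the two z rotations merge into a single rotation by phi + psi, which is one reason ZYZ angles are degenerate near the poles.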

Next, build the scattering theory. The simplest scattering_theory is the WeakPhaseScatteringTheory. This represents the usual image formation pipeline in cryo-EM: it computes projections of the potential and then applies a contrast transfer function, which is a multiplication in Fourier space (equivalently, a convolution in real space with the point spread function).
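Both steps of this pipeline can be checked numerically with plain jax.numpy on a toy grid. The sketch below uses random data and a made-up, radially symmetric filter standing in for a CTF (none of it is cryojax API). It verifies the discrete Fourier slice theorem, namely that the 2D FFT of a projection along z equals the kz = 0 central slice of the 3D FFT, and that multiplying by a filter in Fourier space equals circular convolution with the filter's real-space kernel. Fourier slice extraction relies on the first identity; for rotated slices the plane no longer falls on grid points, so values must be interpolated, which is presumably what interpolation_order controls.

```python
import jax
import jax.numpy as jnp

# Toy "potential": random values on an n^3 grid (not a real map)
n = 16
potential = jax.random.normal(jax.random.PRNGKey(0), (n, n, n))

# 1) Projection. Real-space route: integrate along z (axis 0 in this layout)
projection = potential.sum(axis=0)
# Fourier route: the kz = 0 central slice of the 3D FFT is the 2D FFT of the projection
central_slice = jnp.fft.fftn(potential)[0]
projection_from_slice = jnp.fft.ifft2(central_slice).real

# 2) Filtering. A made-up, radially symmetric stand-in for a CTF on the FFT grid
ky = jnp.fft.fftfreq(n)[:, None]
kx = jnp.fft.fftfreq(n)[None, :]
toy_filter = -jnp.sin(40.0 * (kx**2 + ky**2))
# Fourier route: multiply the central slice by the filter
image_fourier = jnp.fft.ifft2(central_slice * toy_filter).real

# Real-space route: circular convolution with the filter's kernel (the point spread function)
psf = jnp.fft.ifft2(toy_filter).real  # real because the filter is real and even
shift = (jnp.arange(n)[:, None] - jnp.arange(n)[None, :]) % n  # shift[y, u] = (y - u) mod n
# image_real[y, x] = sum_{u, v} projection[u, v] * psf[(y - u) mod n, (x - v) mod n]
kernel = psf[shift[:, :, None, None], shift[None, None, :, :]]  # indexed [y, u, x, v]
image_real = jnp.einsum("uv,yuxv->yx", projection, kernel)
```

The Fourier route avoids both the sum over z and the O(n^4) convolution, which is why image formation is usually done in Fourier space.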

from cryojax.image import operators as op

# Initialize the scattering theory. First, instantiate fourier slice extraction
potential_integrator = cxs.FourierSliceExtraction(interpolation_order=1)
# ... next, the contrast transfer theory
ctf = cxs.ContrastTransferFunction(
    defocus_in_angstroms=9800.0,
    astigmatism_in_angstroms=200.0,
    astigmatism_angle=10.0,
    amplitude_contrast_ratio=0.1
)
transfer_theory = cxs.ContrastTransferTheory(ctf, envelope=op.FourierGaussian(b_factor=5.0))
# ... now for the scattering theory
scattering_theory = cxs.WeakPhaseScatteringTheory(structural_ensemble, potential_integrator, transfer_theory)

The ContrastTransferFunction follows the CTFFIND4 parameterization; any parameters not set explicitly here take their default values. Finally, we can instantiate the imaging_pipeline (the highest level of imaging abstraction in cryojax) and simulate an image. Here, we choose a ContrastImagingPipeline, which simulates image contrast from a linear scattering theory.

# Finally, build the image formation model
# ... first instantiate the instrument configuration
instrument_config = cxs.InstrumentConfig(shape=(320, 320), pixel_size=voxel_size, voltage_in_kilovolts=300.0)
# ... now the imaging pipeline
imaging_pipeline = cxs.ContrastImagingPipeline(instrument_config, scattering_theory)
# ... finally, simulate an image and return in real-space!
image_without_noise = imaging_pipeline.render(get_real=True)
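For reference, here is a minimal sketch of the CTF in the form commonly written for CTFFIND4, evaluated on the unshifted FFT frequency grid of a (320, 320) image at 300 kV. Parameter names mirror the example above, but this is not cryojax's implementation: the astigmatism convention (defocus oscillating between defocus plus and minus astigmatism / 2), the 2.7 mm spherical aberration default, the sign convention, and `pixel_size=1.0` are assumptions, and `sketch_ctf` is a hypothetical name. The electron wavelength uses the standard relativistic formula.

```python
import jax.numpy as jnp


def electron_wavelength_in_angstroms(voltage_in_kilovolts):
    """Relativistic electron wavelength, lambda = h / sqrt(2 m e V (1 + e V / (2 m c^2)))."""
    voltage = voltage_in_kilovolts * 1e3  # kV -> V
    # 12.2643 = h / sqrt(2 m e) in angstrom * V^(1/2); 0.97848e-6 = e / (2 m c^2) in 1/V
    return 12.2643 / jnp.sqrt(voltage * (1.0 + 0.97848e-6 * voltage))


def sketch_ctf(
    shape,
    pixel_size,
    voltage_in_kilovolts=300.0,
    defocus_in_angstroms=9800.0,
    astigmatism_in_angstroms=200.0,
    astigmatism_angle=10.0,  # degrees
    amplitude_contrast_ratio=0.1,
    spherical_aberration_in_mm=2.7,  # assumed default
    phase_shift=0.0,  # degrees
):
    """CTF on the unshifted FFT grid in a CTFFIND4-style form (a sketch, not cryojax's code)."""
    wavelength = electron_wavelength_in_angstroms(voltage_in_kilovolts)
    cs = spherical_aberration_in_mm * 1e7  # mm -> angstroms
    # Spatial frequencies in 1/angstrom
    ky = jnp.fft.fftfreq(shape[0], d=pixel_size)[:, None]
    kx = jnp.fft.fftfreq(shape[1], d=pixel_size)[None, :]
    k_squared = kx**2 + ky**2
    azimuth = jnp.arctan2(ky, kx)
    # Assumed astigmatism convention: defocus varies between defocus +/- astigmatism / 2
    local_defocus = defocus_in_angstroms + 0.5 * astigmatism_in_angstroms * jnp.cos(
        2.0 * (azimuth - jnp.deg2rad(astigmatism_angle))
    )
    # Phase aberration chi(k): defocus term minus spherical-aberration term, plus phase shift
    chi = (
        jnp.pi * wavelength * local_defocus * k_squared
        - 0.5 * jnp.pi * cs * wavelength**3 * k_squared**2
        + jnp.deg2rad(phase_shift)
    )
    a = amplitude_contrast_ratio
    # Weighted mix of phase contrast (sin) and amplitude contrast (cos)
    return -(jnp.sqrt(1.0 - a**2) * jnp.sin(chi) + a * jnp.cos(chi))


# Same settings as the example above; pixel_size=1.0 is a hypothetical value
ctf = sketch_ctf((320, 320), pixel_size=1.0)
```

At zero frequency the phase aberration vanishes, so the CTF reduces to minus the amplitude contrast ratio; the weights are chosen so the CTF magnitude never exceeds 1.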

cryojax also defines a library of distributions that serve as the stochastic model from which images are drawn. For example, instantiate an IndependentGaussianFourierModes distribution and either sample from it or compute its log-likelihood.

from cryojax.image import rfftn, operators as op
from cryojax.inference import distributions as dist

# Passing the imaging pipeline and a variance function, instantiate the distribution
distribution = dist.IndependentGaussianFourierModes(
    imaging_pipeline, variance_function=op.Constant(1.0), is_signal_normalized=True
)
# ... then, either simulate an image from this distribution
key = jax.random.PRNGKey(seed=0)
image_with_noise = distribution.sample(key)
# ... or compute the log-likelihood of observed data
observed = rfftn(...)  # for this example, read in observed data and take its FFT
log_likelihood = distribution.log_likelihood(observed)
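As an illustration of what a model like IndependentGaussianFourierModes computes, the sketch below treats each Fourier mode as observed = model + noise, with independent circular complex Gaussian noise of constant variance (the analogue of `variance_function=op.Constant(1.0)`). The helper names are hypothetical, and cryojax's actual normalization, its handling of `is_signal_normalized`, and the set of modes it includes may differ. For a real image the rfft modes are also not all independent (for example, the zero-frequency mode is purely real), which this simple sketch ignores.

```python
import jax
import jax.numpy as jnp


def gaussian_fourier_log_likelihood(observed, model, variance):
    """Log-likelihood of Fourier modes with independent circular complex Gaussian noise.

    Each mode is observed = model + noise with noise ~ CN(0, variance): the real and
    imaginary parts are independent, each with variance / 2.
    """
    residual = observed - model
    return jnp.sum(-jnp.abs(residual) ** 2 / variance - jnp.log(jnp.pi * variance))


def sample_gaussian_fourier_modes(key, model, variance):
    """Draw a noisy realization observed = model + CN(0, variance) noise, mode by mode."""
    key_real, key_imag = jax.random.split(key)
    scale = jnp.sqrt(variance / 2.0)
    noise = scale * (
        jax.random.normal(key_real, model.shape) + 1j * jax.random.normal(key_imag, model.shape)
    )
    return model + noise


# Toy stand-in for a rendered image, in the rfft (half-plane) layout used by rfftn
model = jnp.fft.rfftn(jnp.ones((8, 8)))  # shape (8, 5)
observed = sample_gaussian_fourier_modes(jax.random.PRNGKey(0), model, variance=1.0)
log_likelihood = gaussian_fourier_log_likelihood(observed, model, variance=1.0)
```

When the observation equals the model, only the normalization term remains, giving -N log(pi * variance) for N modes; any residual lowers the log-likelihood from there.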

For more advanced image simulation examples and to understand the many features in this library, see the documentation.

Acknowledgements

  • cryojax has been greatly informed by the open-source cryo-EM software packages cisTEM and BioEM.
  • cryojax relies heavily on and has taken great inspiration from equinox. We think that equinox has great design principles and highly recommend learning about it to fully make use of the power of jax.

Download files

Source Distribution

cryojax-0.3.3rc1.tar.gz (4.4 MB)

Uploaded Source

Built Distribution

cryojax-0.3.3rc1-py3-none-any.whl (154.8 kB)

Uploaded Python 3

File details

Details for the file cryojax-0.3.3rc1.tar.gz.

File metadata

  • Download URL: cryojax-0.3.3rc1.tar.gz
  • Upload date:
  • Size: 4.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for cryojax-0.3.3rc1.tar.gz
Algorithm Hash digest
SHA256 72e19ad9737708d63851f5bd98bf3763cd872745a286a1a41ea890d70a03706d
MD5 b8e7ea232d8e150a292b51bcdd99394d
BLAKE2b-256 1c5199622899ce10be26f8c35bb81f80753cb8ee35e7a6ed69b2c65790b2ea0e

Provenance

The following attestation bundles were made for cryojax-0.3.3rc1.tar.gz:

Publisher: release.yml on mjo22/cryojax

File details

Details for the file cryojax-0.3.3rc1-py3-none-any.whl.

File metadata

  • Download URL: cryojax-0.3.3rc1-py3-none-any.whl
  • Upload date:
  • Size: 154.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for cryojax-0.3.3rc1-py3-none-any.whl
Algorithm Hash digest
SHA256 4e70b0008fc15c5a4f6d99d9d38f2cd3ab7b273df502a0311064c7f66c307848
MD5 a8b2f682d1aa2a7a374726c2b47e4aac
BLAKE2b-256 00b9d0f2f66cf08e3b7fa7bd368a42e35cd3194c64e266044a0565df808482a1

Provenance

The following attestation bundles were made for cryojax-0.3.3rc1-py3-none-any.whl:

Publisher: release.yml on mjo22/cryojax
