Chemical reaction networks in JAX: GPU-parallel stochastic simulations.

crn-jax

License: MIT

Chemical reaction networks in JAX — a tiny, GPU-optimized Gillespie / Stochastic Simulation Algorithm (SSA) library.

[Figure: Birth-death benchmark: crn-jax on GPU vs GillesPy2 (C++) on CPU.]
[Figure: Linear-cascade benchmark: crn-jax on GPU vs GillesPy2 (C++) on CPU.]

Wall-time to simulate 1,000,000 independent stochastic trajectories — each a full Gillespie run of the reaction network from t=0 to t=20, sampled at 200 time points (CPU vs RTX 5090 GPU).
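
When reproducing wall-time numbers like these, JAX's asynchronous dispatch and one-off compilation cost can distort a naive measurement. The sketch below shows one common way to time a jitted function; `time_jitted` and `toy_workload` are hypothetical names used for illustration only and are not part of crn-jax or its benchmark scripts.

```python
import time
import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=3):
    """Return the best wall-time (seconds) of `fn(*args)` over `repeats` runs."""
    # Warm-up call: triggers compilation so it is excluded from the timing.
    jax.block_until_ready(fn(*args))
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # block_until_ready waits for the asynchronous computation to finish.
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best

# Toy stand-in workload (NOT a Gillespie simulation): one random walk per key.
@jax.jit
def toy_workload(keys):
    return jax.vmap(lambda k: jnp.cumsum(jax.random.normal(k, (200,))))(keys)

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
elapsed = time_jitted(toy_workload, keys)
print(f"best of 3: {elapsed:.6f} s")
```

The same helper could wrap a jitted, vmapped simulation function in place of `toy_workload`.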

Install

pip install crn-jax

# with NVIDIA GPU support:
pip install crn-jax "jax[cuda12]"

# with plotting helpers:
pip install "crn-jax[examples]"

# for local development (uses Poetry):
git clone https://github.com/robinhenry/crn-jax && cd crn-jax
poetry install                    # main deps + dev tools
poetry install --with gpu         # add jax[cuda12] on an NVIDIA host

crn-jax depends on jax / jaxlib only.
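
After installing, it can be worth checking which backend JAX actually picked up, since a missing CUDA install typically falls back to CPU. A minimal check using only standard JAX calls:

```python
import jax

# Name of the default backend, e.g. "cpu" or "gpu".
backend = jax.default_backend()
# Devices JAX can see on that backend.
devices = jax.devices()

print(f"JAX {jax.__version__} is using the '{backend}' backend: {devices}")
```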

Key features

  • 🎯 Exact SSA — pure-JAX implementation of the Gillespie algorithm for chemical reaction networks.
  • JIT-compiled — the entire loop compiles under jax.jit.
  • 🚀 GPU speedup — 1M+ independent trajectories on a single GPU under jax.vmap, with no Python overhead.
  • ⏱️ Discretization-safe — pending reaction times are preserved across simulation-interval boundaries, so trajectories are physically correct under discrete observations (or fixed-interval stepping).
  • 🎛️ Control-input aware — propensities take an optional input argument that can vary per-interval and per-replicate, so each of N parallel trajectories can follow its own control schedule (useful for RL-style rollouts, closed-loop experiments with per-replicate inputs, …).
  • 🎨 Pre-built motifs - canonical reaction networks (inducible, autoreg, cascade, ffl_and) available from crn_jax.motifs.
  • 🧩 Bring-your-own state — the loop operates on any PyTree (NamedTuple, Flax struct dataclass, Equinox module, …).
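
To make the "exact SSA" and "discretization-safe" points concrete, here is a minimal sketch of a Gillespie direct-method step for the birth-death process, written in plain JAX. It is an illustration under stated assumptions, not crn-jax's actual implementation: `SketchState`, `sketch_interval`, and the rate constants are hypothetical, and the inputs are assumed constant. The pending reaction time is stored in the state, so a reaction drawn in one interval but due in a later one is kept rather than redrawn.

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

BIRTH, DEATH = 3.0, 0.1  # hypothetical rates for this sketch

class SketchState(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array  # absolute time of the pending reaction (inf = none yet)

def _propensities(x):
    return jnp.array([BIRTH, DEATH * x])

def _draw_next_time(key, t, x):
    # Waiting time to the next reaction is exponential with rate a0 = sum of propensities.
    return t + jax.random.exponential(key) / _propensities(x).sum()

def sketch_interval(key, state, t_end):
    """Advance the state to t_end, firing every reaction due before then."""
    key, sub = jax.random.split(key)
    # Draw a first reaction time only if none is pending from a previous interval.
    pending = jnp.where(
        jnp.isinf(state.next_reaction_time),
        _draw_next_time(sub, state.time, state.x),
        state.next_reaction_time,
    )

    def cond(carry):
        return carry[3] <= t_end  # pending reaction falls inside this interval

    def body(carry):
        key, _, x, t_next = carry
        key, k_choice, k_time = jax.random.split(key, 3)
        # Pick which reaction fires, with probability proportional to its propensity.
        j = jax.random.categorical(k_choice, jnp.log(_propensities(x)))
        x = x + jnp.where(j == 0, 1.0, -1.0)
        # Jump to the firing time and schedule the next reaction from the new state.
        return key, t_next, x, _draw_next_time(k_time, t_next, x)

    _, _, x, t_next = jax.lax.while_loop(cond, body, (key, state.time, state.x, pending))
    # The pending time t_next > t_end is carried over to the next interval.
    return SketchState(jnp.asarray(t_end, dtype=state.time.dtype), x, t_next)

# Usage: step through five unit-length intervals.
key = jax.random.PRNGKey(0)
state = SketchState(jnp.zeros(()), jnp.zeros(()), jnp.full((), jnp.inf))
step = jax.jit(sketch_interval)
for i in range(1, 6):
    key, sub = jax.random.split(key)
    state = step(sub, state, float(i))
```

Because the state is an ordinary NamedTuple, the same function also works under `jax.vmap` over a batch of keys and states.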

Quickstart

A 1-species birth-death process, ∅ → X at rate λ and X → ∅ at rate μ·x, simulated for 100 independent replicates:

from typing import NamedTuple
import jax, jax.numpy as jnp
from crn_jax import simulate_trajectory, plot_trajectories

BIRTH_RATE, DEATH_RATE = 3.0, 0.1    # steady-state mean λ/μ = 30

# Define a state-holding object
class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array    # carried across intervals

# Return propensity equations as an array
# with an optional external input (unused here)
def propensities(s, _input):
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])

# Describe how the state changes when reaction `j` fires
def apply_reaction(s, j):
    return s._replace(x=s.x + jnp.where(j == 0, 1.0, -1.0))

# Initial state
state0 = State(jnp.array(0.0), jnp.array(0.0), jnp.array(jnp.inf))

@jax.jit
@jax.vmap
def run_one(key):
    return simulate_trajectory(
        key=key,
        initial_state=state0,
        timestep=1.0,
        n_steps=200,
        # Pass our 2 custom functions defined above
        compute_propensities_fn=propensities,
        apply_reaction_fn=apply_reaction,
    )

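# Simulate 100 independent Gillespie trajectories (one PRNG key per replicate)
states = run_one(jax.random.split(jax.random.PRNGKey(0), 100))
# Sampling times: the end of each of the 200 unit-length intervals
times = jnp.arange(1, 201) * 1.0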
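
A quick sanity check for a run like this is to compare the ensemble mean against the known analytic mean of a birth-death process started at x = 0, E[x(t)] = (λ/μ)(1 − exp(−μt)). The helpers below are a hypothetical sketch, not part of crn-jax; they assume the simulated counts are available as an array of shape (replicates, time points).

```python
import jax.numpy as jnp

def analytic_mean(times, birth_rate=3.0, death_rate=0.1):
    """Mean of a birth-death process started at x=0, evaluated at `times`."""
    return (birth_rate / death_rate) * (1.0 - jnp.exp(-death_rate * times))

def max_abs_mean_error(x_ensemble, times, birth_rate=3.0, death_rate=0.1):
    """Largest gap between the ensemble mean and the analytic mean.

    x_ensemble is assumed to have shape (n_replicates, n_times).
    """
    empirical = jnp.mean(x_ensemble, axis=0)
    return jnp.max(jnp.abs(empirical - analytic_mean(times, birth_rate, death_rate)))
```

With the quickstart outputs this would be called as `max_abs_mean_error(states.x, times)`, assuming `states.x` has shape (100, 200). With only 100 replicates the gap is dominated by sampling noise, so expect it to shrink as the number of replicates grows.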

See the examples folder for more detailed examples.

Standard motifs

crn_jax.motifs provides pre-built canonical reaction networks. Each motif exports a uniform surface: State, Params, propensities_fn(), apply_reaction(), plus a one-call simulate_dataset(), so generating time series from different systems is a one-line change:

import jax
from crn_jax.motifs import cascade

ds = cascade.simulate_dataset(jax.random.PRNGKey(0))

# Access X, Y, u, dX, dY observations
ds.X_t, ds.Y_t, ds.u_per_triple, ds.dX, ds.dY

The primitive functions (propensities_fn(), apply_reaction()) also plug into simulate_trajectory directly when the convenience helper simulate_dataset isn't enough (e.g., if you need custom u schedules, specific initial condition mixtures, etc.).

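| motif     | reactions | input | shape                                      |
|-----------|-----------|-------|--------------------------------------------|
| inducible | 2         | yes   | Hill-modulated birth-death                 |
| autoreg   | 2         | no    | negative autoregulation (Hill repressor)   |
| cascade   | 4         | yes   | u → X → Y two-stage cascade                |
| ffl_and   | 6         | yes   | C1 feed-forward loop with AND output gate  |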
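
As an illustration of what "Hill-modulated birth-death" means, the sketch below writes out the two propensities for such a system: a birth rate scaled by a Hill activation of the input u, and a linear death rate. The function names, parameter names, and default values are hypothetical and do not reflect the actual signatures or defaults in crn_jax.motifs or crn_jax.kinetics.

```python
import jax.numpy as jnp

def hill_activation(u, K, n):
    """Hill activation u^n / (K^n + u^n): 0 at u=0, 0.5 at u=K, approaches 1 for u >> K."""
    un = u ** n
    return un / (K ** n + un)

def inducible_propensities(x, u, k_max=5.0, K=1.0, n=2.0, gamma=0.1):
    """Propensities [birth, death] for an input-driven birth-death process."""
    birth = k_max * hill_activation(u, K, n)  # production, modulated by the input u
    death = gamma * x                          # first-order degradation
    return jnp.array([birth, death])
```

At u = K the birth propensity is half of k_max, which is a convenient way to check parameter choices by hand.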

API

# Main entry point: scan n_steps fixed-length intervals, stack the per-step states.
from crn_jax import simulate_trajectory

# Finer control: one interval at a time (RL-style), or until an absolute time.
from crn_jax.gillespie import simulate_interval, simulate_until

# Plotting helper: step-plots a single trajectory or an (N, T) ensemble.
from crn_jax import plot_trajectories

# Optional kinetic-law helpers.
from crn_jax.kinetics import hill_function, sample_lognormal

| function            | when to reach for it                                                          |
|---------------------|-------------------------------------------------------------------------------|
| simulate_trajectory | You want a full trajectory on a fixed sampling grid. Start here.              |
| simulate_interval   | You're driving the system yourself, one step at a time (e.g. an RL rollout).  |
| simulate_until      | You need a custom state shape or a non-uniform time grid. Fully generic.      |
| plot_trajectories   | Quick look at the output.                                                     |
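
The fixed-interval pattern described above (scan over intervals, stack the per-step states, one control input per interval and per replicate) can be sketched with `jax.lax.scan` and `jax.vmap`. This is a toy illustration, not the crn-jax API: `ToyState`, `toy_interval`, and `toy_trajectory` are hypothetical, and the per-interval update is a pure-birth process ∅ → X at rate u, for which the number of births in an interval of length dt is exactly Poisson(u·dt).

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class ToyState(NamedTuple):
    time: jax.Array
    x: jax.Array

def toy_interval(key, state, u, dt):
    """Advance a pure-birth process (rate u) by one interval of length dt."""
    births = jax.random.poisson(key, u * dt)
    return ToyState(state.time + dt, state.x + births)

def toy_trajectory(key, u_schedule, dt=1.0):
    """Scan over intervals; u_schedule has shape (T,), one input per interval."""
    keys = jax.random.split(key, u_schedule.shape[0])
    state0 = ToyState(jnp.zeros(()), jnp.zeros((), dtype=jnp.int32))

    def step(state, inputs):
        k, u = inputs
        new_state = toy_interval(k, state, u, dt)
        return new_state, new_state  # carry forward, and emit for stacking

    _, stacked = jax.lax.scan(step, state0, (keys, u_schedule))
    return stacked  # every field now has a leading time axis of length T

# N replicates, each with its own (here constant) input schedule of length T.
N, T = 4, 10
keys = jax.random.split(jax.random.PRNGKey(0), N)
u = jnp.linspace(0.5, 2.0, N)[:, None] * jnp.ones((N, T))
traj = jax.jit(jax.vmap(toy_trajectory))(keys, u)
print(traj.x.shape)  # (N, T)
```

For a closed-loop or RL-style rollout, the scan would be replaced by an explicit loop in which each interval's input is chosen from the previously observed state.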

See Also

  • GillesPy2: C++ optimized Gillespie simulations on CPU.
  • jax-smfsb: JAX implementations of algorithms from the Stochastic Modelling for Systems Biology book.
  • myriad-jax: RL-style decision making fully in JAX, powered by crn-jax at its core.
