Chemical reaction networks in JAX: GPU-parallel stochastic simulations.
crn-jax
Chemical reaction networks in JAX — a tiny, GPU-optimized Gillespie / Stochastic Simulation Algorithm (SSA) library.
Wall-time to simulate 1,000,000 independent stochastic trajectories — each a full Gillespie run of the reaction network from t=0 to t=20, sampled at 200 time points (CPU vs RTX 5090 GPU).
Install
pip install crn-jax
# with NVIDIA GPU support:
pip install crn-jax "jax[cuda12]"
# with plotting helpers:
pip install "crn-jax[examples]"
# for local development (uses Poetry):
git clone https://github.com/robinhenry/crn-jax && cd crn-jax
poetry install # main deps + dev tools
poetry install --with gpu # add jax[cuda12] on an NVIDIA host
crn-jax depends on jax / jaxlib only.
Key features
- 🎯 Exact SSA: pure-JAX implementation of the Gillespie algorithm for chemical reaction networks.
- ⚡ JIT-compiled: the entire loop compiles under `jax.jit`.
- 🚀 GPU speedup: 1M+ independent trajectories on a single GPU under `jax.vmap`, with no Python overhead.
- ⏱️ Discretization-safe: pending reaction times are preserved across simulation-interval boundaries, so trajectories are physically correct under discrete observations (or fixed-interval stepping).
- 🎛️ Control-input aware: propensities take an optional `input` argument that can vary per-interval and per-replicate, so each of N parallel trajectories can follow its own control schedule (useful for RL-style rollouts, closed-loop experiments with per-replicate inputs, …).
- 🎨 Pre-built motifs: a series of standard motifs implemented for convenience.
- 🧩 Bring-your-own state: the loop operates on any PyTree (NamedTuple, Flax struct dataclass, Equinox module, …).
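To make the "Exact SSA" and "Discretization-safe" points concrete, here is a minimal plain-Python sketch of a Gillespie loop that carries the pending reaction time across fixed-length intervals. It is an illustration under stated assumptions, not the crn-jax implementation: the helper names (`simulate_interval`, `run`) and the birth-death rates are chosen for the example, and it uses the standard library instead of JAX.

```python
import math
import random

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # illustrative rates (assumption)


def propensities(x):
    """Birth at a constant rate, death proportional to the copy number."""
    return [BIRTH_RATE, DEATH_RATE * x]


def simulate_interval(rng, t, x, pending, t_end):
    """Advance the state to t_end and return (t_end, x, pending).

    `pending` is the absolute time of the next reaction. If it falls beyond
    t_end it is returned unchanged, so the next interval reuses it instead
    of discarding it at the boundary.
    """
    while True:
        rates = propensities(x)
        total = sum(rates)
        if math.isinf(pending) and total > 0.0:
            pending = t + rng.expovariate(total)  # sample the next firing time
        if pending > t_end:
            return t_end, x, pending  # nothing fires before the boundary
        t = pending
        # Pick which reaction fires, with probability proportional to its rate.
        x += 1 if rng.random() * total < rates[0] else -1
        pending = math.inf  # the state changed, so a new time must be sampled


def run(seed, n_steps=200, dt=1.0):
    """Sample x at the end of each of n_steps intervals of length dt."""
    rng = random.Random(seed)
    t, x, pending = 0.0, 0, math.inf
    samples = []
    for k in range(n_steps):
        t, x, pending = simulate_interval(rng, t, x, pending, (k + 1) * dt)
        samples.append(x)
    return samples, t, pending
```

Calling `run(seed=0)` returns the sampled copy numbers, the final time, and the reaction time still pending at the end. That pending time is the piece of state that a field such as `next_reaction_time` would hold between intervals.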
Quickstart
A 1-species birth-death process, ∅ → X at rate λ and X → ∅ at rate μ·x, simulated for 100 independent replicates:
from typing import NamedTuple

import jax
import jax.numpy as jnp
from crn_jax import simulate_trajectory, plot_trajectories

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # steady-state mean λ/μ = 30

# Define a state-holding object
class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array  # carried across intervals

# Return propensity equations as an array
# with an optional external input (unused here)
def propensities(s, _input):
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])

# Describe how the state changes when reaction `j` fires
def apply_reaction(s, j):
    return s._replace(x=s.x + jnp.where(j == 0, 1.0, -1.0))

# Initial state
state0 = State(jnp.array(0.0), jnp.array(0.0), jnp.array(jnp.inf))

@jax.jit
@jax.vmap
def run_one(key):
    return simulate_trajectory(
        key=key,
        initial_state=state0,
        timestep=1.0,
        n_steps=200,
        # Pass our 2 custom functions defined above
        compute_propensities_fn=propensities,
        apply_reaction_fn=apply_reaction,
    )

# Simulate 100 Gillespie trajectories
states = run_one(jax.random.split(jax.random.PRNGKey(0), 100))
times = jnp.arange(1, 201) * 1.0
See the examples folder for more detailed examples.
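The steady-state mean λ/μ = 30 quoted in the Quickstart comment follows from the first-moment equation of the birth-death process. A short derivation, using the same λ and μ and the initial condition x(0) = 0:

$$
\frac{d\langle x\rangle}{dt} = \lambda - \mu\,\langle x\rangle,
\qquad \langle x\rangle(0) = 0
\quad\Longrightarrow\quad
\langle x\rangle(t) = \frac{\lambda}{\mu}\left(1 - e^{-\mu t}\right)
$$

$$
\lambda = 3,\ \mu = 0.1:\qquad
\langle x\rangle(20) = 30\left(1 - e^{-2}\right) \approx 25.9,
\qquad
\langle x\rangle(200) = 30\left(1 - e^{-20}\right) \approx 30
$$

With `timestep=1.0` and `n_steps=200` the trajectories run to t = 200, so the ensemble mean over the last samples should sit close to 30. The stationary distribution of this process is Poisson, so the variance there is also about 30.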
Standard motifs
crn_jax.motifs provides pre-built canonical reaction networks. Each motif exports a uniform surface: State, Params, propensities_fn(), apply_reaction(), plus a one-call simulate_dataset(), so generating time series from different systems is a one-line change:
import jax
from crn_jax.motifs import cascade
ds = cascade.simulate_dataset(jax.random.PRNGKey(0))
# Access X, Y, u, dX, dY observations
ds.X_t, ds.Y_t, ds.u_per_triple, ds.dX, ds.dY
The primitive functions (propensities_fn(), apply_reaction()) also plug into simulate_trajectory directly when the convenience helper simulate_dataset isn't enough (e.g., if you need custom u schedules, specific initial condition mixtures, etc.).
| motif | reactions | input | shape |
|---|---|---|---|
| `inducible` | 2 | yes | Hill-modulated birth-death |
| `autoreg` | 2 | no | negative autoregulation (Hill repressor) |
| `cascade` | 4 | yes | u → X → Y two-stage cascade |
| `ffl_and` | 6 | yes | C1 feed-forward loop with AND output gate |
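Several motifs in the table are described as Hill-modulated or Hill-repressed. As a reminder of what that usually means, here is a minimal plain-Python sketch of the standard Hill activation and repression terms. The function names, the parameters (`k`, `n`, `max_rate`, `death_rate`) and the example propensity are assumptions made for illustration; they are not the signature of `crn_jax.kinetics.hill_function`, nor necessarily the parametrisation the motifs use.

```python
def hill_activation(u, k, n):
    """Fraction of the maximal rate reached at activator level u >= 0.

    k is the half-saturation level and n the Hill coefficient.
    """
    return u**n / (k**n + u**n)


def hill_repression(x, k, n):
    """Fraction of the maximal rate left at repressor level x >= 0."""
    return k**n / (k**n + x**n)


def inducible_propensities(x, u, max_rate=5.0, k=2.0, n=2, death_rate=0.1):
    """Hypothetical two-reaction network: input-dependent birth, linear death."""
    return [max_rate * hill_activation(u, k, n), death_rate * x]
```

At `u == k` the activation term is exactly one half, so `inducible_propensities(x=10, u=2.0)` gives a birth propensity of half of `max_rate` and a death propensity of `death_rate * 10`.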
API
# Main entry point: scan n_steps fixed-length intervals, stack the per-step states.
from crn_jax import simulate_trajectory
# Finer control: one interval at a time (RL-style), or until an absolute time.
from crn_jax.gillespie import simulate_interval, simulate_until
# Plotting helper: step-plots a single trajectory or an (N, T) ensemble.
from crn_jax import plot_trajectories
# Optional kinetic-law helpers.
from crn_jax.kinetics import hill_function, sample_lognormal
| function | when to reach for it |
|---|---|
| `simulate_trajectory` | You want a full trajectory on a fixed sampling grid. Start here. |
| `simulate_interval` | You're driving the system yourself, one step at a time (e.g. an RL rollout). |
| `simulate_until` | You need a custom state shape or a non-uniform time grid. Fully generic. |
| `plot_trajectories` | Quick look at the output. |
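The "one step at a time" pattern in the table can be sketched as a closed loop: observe the state, choose an input, simulate one interval, repeat. The example below is plain Python and does not call `simulate_interval`, whose signature is not shown here. The helpers `step_interval`, `controller` and `rollout`, the rates, and the target are all hypothetical. Because the input changes at interval boundaries, this sketch simply resamples the waiting time at the start of each interval, which is exact for exponential waiting times but differs from the carry-over approach described under Key features.

```python
import random

MAX_BIRTH_RATE, DEATH_RATE = 5.0, 0.1  # illustrative rates (assumption)


def step_interval(rng, x, u, dt):
    """Exact SSA over one interval of length dt with the input u held constant."""
    t = 0.0
    while True:
        rates = [MAX_BIRTH_RATE * u, DEATH_RATE * x]
        total = sum(rates)
        if total == 0.0:
            return x  # no reaction can fire
        t += rng.expovariate(total)
        if t > dt:
            return x  # the next reaction falls outside this interval
        x += 1 if rng.random() * total < rates[0] else -1


def controller(x, target=20):
    """Bang-bang policy: switch production on whenever x is below the target."""
    return 1.0 if x < target else 0.0


def rollout(seed, n_steps=300, dt=1.0):
    """Alternate between choosing an input and simulating one interval."""
    rng = random.Random(seed)
    x, xs, us = 0, [], []
    for _ in range(n_steps):
        u = controller(x)  # decide from the latest observation
        x = step_interval(rng, x, u, dt)
        xs.append(x)
        us.append(u)
    return xs, us
```

`rollout(seed=0)` returns the observed copy numbers and the inputs that were applied. After an initial transient the copy number should hover around the controller's target.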
See Also
- GillesPy2: C++ optimized Gillespie simulations on CPU.
- jax-smfsb: JAX implementations of algorithms from the Stochastic Modelling for Systems Biology book.
- myriad-jax: RL-style decision making fully in JAX, powered by `grn-jax` at its core.