Chemical reaction networks in JAX: GPU-parallel stochastic simulations.
crn-jax
Chemical reaction networks in JAX — a tiny, GPU-optimized Gillespie / Stochastic Simulation Algorithm (SSA) library.
Wall-time to simulate 1,000,000 independent stochastic trajectories — each a full Gillespie run of the reaction network from t=0 to t=20, sampled at 200 time points (CPU vs RTX 5090 GPU).
Install
```shell
pip install crn-jax
# with NVIDIA GPU support:
pip install crn-jax "jax[cuda12]"
# with plotting helpers:
pip install "crn-jax[examples]"
# for local development (uses Poetry):
git clone https://github.com/robinhenry/crn-jax && cd crn-jax
poetry install              # main deps + dev tools
poetry install --with gpu   # add jax[cuda12] on an NVIDIA host
```
crn-jax depends on jax / jaxlib only.
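After installing with `"jax[cuda12]"`, you can check which backend JAX will actually use. This snippet uses only plain JAX calls and nothing crn-jax-specific:

```python
import jax

# Backend JAX selected by default, e.g. "cpu" or "gpu".
backend = jax.default_backend()
# All devices visible to this process (one entry per GPU on a CUDA host).
devices = jax.devices()

print(backend, devices)
```

If this reports `cpu` on an NVIDIA machine, the CUDA-enabled jaxlib is not being picked up, and vmapped simulations will run on the CPU.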
Key features
- 🎯 Exact SSA: pure-JAX implementation of the Gillespie algorithm for chemical reaction networks.
- ⚡ JIT-compiled: the entire simulation loop compiles under `jax.jit`.
- 🚀 GPU speedup: 1M+ independent trajectories on a single GPU under `jax.vmap`, with no Python overhead.
- ⏱️ Discretization-safe: pending reaction times are preserved across simulation-interval boundaries, so trajectories remain physically correct under discrete observations (or fixed-interval stepping).
- 🎛️ Control-input aware: propensities take an optional `input` argument that can vary per interval and per replicate, so each of N parallel trajectories can follow its own control schedule (useful for RL-style rollouts, closed-loop experiments with per-replicate inputs, …).
- 🧩 Bring-your-own state: the loop operates on any PyTree (NamedTuple, Flax struct dataclass, Equinox module, …).
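To make the "Exact SSA" and "Discretization-safe" points concrete, here is a minimal pure-JAX sketch of the Gillespie direct method run over fixed-length intervals. It is not crn-jax's implementation: `interval_sketch` and `trajectory_sketch` are hypothetical names, and `propensities` drops the input argument for brevity. When an interval ends before the pending reaction, the sketch stops the clock at the boundary and keeps the pending absolute time in `next_reaction_time` (with `inf` meaning "none drawn", a convention assumed here) instead of firing it early or discarding it. Because the state does not change at the boundary, reusing that time in the next interval is exact. The loop is a `lax.while_loop` inside a `lax.scan`, so it compiles under `jax.jit` and batches under `jax.vmap`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # same birth-death model as the Quickstart


class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array  # absolute time of the pending reaction; inf = none drawn


def propensities(s):
    # Input argument omitted in this sketch.
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])


def apply_reaction(s, j):
    return s._replace(x=s.x + jnp.where(j == 0, 1.0, -1.0))


def interval_sketch(key, s, t_end):
    """Hypothetical helper: Gillespie direct method on [s.time, t_end]."""

    def step(carry):
        key, s, _ = carry
        key, k_tau, k_j = jax.random.split(key, 3)
        a = propensities(s)
        a0 = a.sum()
        # Draw a waiting time only if no reaction is pending from a previous interval.
        tau = jnp.where(a0 > 0, jax.random.exponential(k_tau) / a0, jnp.inf)
        pending = jnp.where(jnp.isinf(s.next_reaction_time), s.time + tau, s.next_reaction_time)
        fires = pending <= t_end
        # Pick reaction j with probability a_j / a0 (state unchanged since the draw).
        j = jax.random.categorical(k_j, jnp.log(a))
        fired = apply_reaction(s, j)._replace(time=pending, next_reaction_time=jnp.inf)
        # Otherwise stop the clock at t_end and keep the pending time for later.
        stopped = s._replace(time=t_end, next_reaction_time=pending)
        s = jax.tree_util.tree_map(lambda f, p: jnp.where(fires, f, p), fired, stopped)
        return key, s, ~fires

    _, s, _ = jax.lax.while_loop(lambda c: ~c[2], step, (key, s, jnp.array(False)))
    return s


def trajectory_sketch(key, s0, timestep, n_steps):
    """Hypothetical helper: scan n_steps intervals, stacking the state at each end."""

    def body(s, inputs):
        k, i = inputs
        s = interval_sketch(k, s, s0.time + (i + 1) * timestep)
        return s, s

    keys = jax.random.split(key, n_steps)
    _, states = jax.lax.scan(body, s0, (keys, jnp.arange(n_steps)))
    return states


state0 = State(jnp.float32(0.0), jnp.float32(0.0), jnp.float32(jnp.inf))
run = jax.jit(jax.vmap(lambda k: trajectory_sketch(k, state0, 1.0, 50)))
states = run(jax.random.split(jax.random.PRNGKey(0), 1000))  # each field: (1000, 50)
```

Every stored state satisfies `next_reaction_time > time`: the reaction that would have overshot the boundary is still pending, not lost. With 1,000 replicates started at x = 0, the ensemble mean of x at t = 50 should sit close to λ/μ = 30.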
Quickstart
A 1-species birth-death process, ∅ → X at rate λ and X → ∅ at rate μ·x, simulated for 100 independent replicates:
```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

from crn_jax import simulate_trajectory, plot_trajectories

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # steady-state mean λ/μ = 30


# Define a state-holding object
class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array  # carried across intervals


# Return propensity equations as an array,
# with an optional external input (unused here)
def propensities(s, _input):
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])


# Describe how the state changes when reaction `j` fires
def apply_reaction(s, j):
    return s._replace(x=s.x + jnp.where(j == 0, 1.0, -1.0))


# Initial state
state0 = State(jnp.array(0.0), jnp.array(0.0), jnp.array(jnp.inf))


@jax.jit
@jax.vmap
def run_one(key):
    return simulate_trajectory(
        key=key,
        initial_state=state0,
        timestep=1.0,
        n_steps=200,
        # Pass our 2 custom functions defined above
        compute_propensities_fn=propensities,
        apply_reaction_fn=apply_reaction,
    )


# Simulate 100 Gillespie trajectories
states = run_one(jax.random.split(jax.random.PRNGKey(0), 100))
# Sampling times: the end of each of the 200 intervals of length 1.0
times = jnp.arange(1, 201) * 1.0
# plot_trajectories can step-plot the resulting (N, T) ensemble, e.g. states.x
```
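As a sanity check on the Quickstart output, this linear birth-death model has well-known closed-form moments (standard results for this process, not something crn-jax computes). Starting from $x(0) = 0$, with propensities $a_1(x) = \lambda$ and $a_2(x) = \mu x$:

$$
\frac{d}{dt}\,\mathbb{E}[x(t)] = \lambda - \mu\,\mathbb{E}[x(t)]
\quad\Longrightarrow\quad
\mathbb{E}[x(t)] = \frac{\lambda}{\mu}\left(1 - e^{-\mu t}\right)
$$

$$
x(t) \sim \mathrm{Poisson}\!\left(\frac{\lambda}{\mu}\left(1 - e^{-\mu t}\right)\right),
\qquad
\lim_{t \to \infty} \mathbb{E}[x(t)] = \lim_{t \to \infty} \operatorname{Var}[x(t)] = \frac{\lambda}{\mu} = \frac{3.0}{0.1} = 30
$$

With μ = 0.1 the transient decays on a timescale of 1/μ = 10, so by t = 200 the mean and variance of `states.x` across the 100 replicates should both be near 30, up to sampling noise.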
See the examples folder for more detailed examples.
API
```python
# Main entry point: scan n_steps fixed-length intervals, stack the per-step states.
from crn_jax import simulate_trajectory

# Finer control: one interval at a time (RL-style), or until an absolute time.
from crn_jax.gillespie import simulate_interval, simulate_until

# Plotting helper: step-plots a single trajectory or an (N, T) ensemble.
from crn_jax import plot_trajectories

# Optional kinetic-law helpers.
from crn_jax.kinetics import hill_function, sample_lognormal
```
| function | when to reach for it |
|---|---|
| `simulate_trajectory` | You want a full trajectory on a fixed sampling grid. Start here. |
| `simulate_interval` | You're driving the system yourself, one step at a time (e.g. an RL rollout). |
| `simulate_until` | You need a custom state shape or a non-uniform time grid. Fully generic. |
| `plot_trajectories` | Quick look at the output. |
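The `simulate_interval` row describes driving the system yourself, one interval at a time. Its exact signature is not shown here, so the following is a self-contained, hypothetical sketch of that pattern in plain JAX: `interval_with_input`, `policy` and the input-scaled birth rate are illustrative assumptions, not crn-jax API. A jitted, vmapped interval function advances N replicates by one step, and a Python-side feedback policy picks a per-replicate input `u` before each step, as in an RL rollout. Because `u` may change at every boundary, this sketch does not carry a pending reaction time; it draws each waiting time from the current propensities, which stays exact for piecewise-constant inputs because exponential waiting times are memoryless.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

BIRTH_RATE, DEATH_RATE = 3.0, 0.1


class State(NamedTuple):
    time: jax.Array
    x: jax.Array


def propensities(s, u):
    # Hypothetical control input: u in [0, 1] scales the birth rate.
    return jnp.array([u * BIRTH_RATE, DEATH_RATE * s.x])


def interval_with_input(key, s, u, dt):
    """Hypothetical helper: exact SSA on [s.time, s.time + dt] with u held constant."""
    t_end = s.time + dt

    def step(carry):
        key, s, _ = carry
        key, k_tau, k_j = jax.random.split(key, 3)
        a = propensities(s, u)
        a0 = a.sum()
        # Fresh exponential waiting time from the current propensities.
        t_next = s.time + jnp.where(a0 > 0, jax.random.exponential(k_tau) / a0, jnp.inf)
        fires = t_next <= t_end
        # Reaction 0 (birth) or 1 (death), chosen with probability a_j / a0.
        j = jax.random.categorical(k_j, jnp.log(a))
        x = jnp.where(fires, s.x + jnp.where(j == 0, 1.0, -1.0), s.x)
        # If nothing fires before t_end, stop the clock exactly at the boundary.
        return key, State(jnp.where(fires, t_next, t_end), x), ~fires

    _, s, _ = jax.lax.while_loop(lambda c: ~c[2], step, (key, s, jnp.array(False)))
    return s


# Batch over replicates: one key, state and input per replicate, shared dt.
step_all = jax.jit(jax.vmap(interval_with_input, in_axes=(0, 0, 0, None)))


def policy(x):
    # Hypothetical bang-bang feedback: births on while x < 20, off otherwise.
    return jnp.where(x < 20.0, 1.0, 0.0)


n = 64
states = State(jnp.zeros(n), jnp.zeros(n))
key = jax.random.PRNGKey(0)
history = []
for _ in range(100):
    key, sub = jax.random.split(key)
    u = policy(states.x)  # per-replicate input, recomputed before each interval
    states = step_all(jax.random.split(sub, n), states, u, 1.0)
    history.append(states.x)
xs = jnp.stack(history, axis=1)  # shape (n, 100): replicates × intervals
```

Keeping the policy in Python makes it easy to plug in any controller; for a fixed, precomputed (N, T) input schedule the loop could instead move inside a `jax.lax.scan`.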
See Also
- GillesPy2: C++ optimized Gillespie simulations on CPU.
- myriad-jax: RL-style decision making fully in JAX, powered by `grn-jax` at its core.
File details
Details for the file crn_jax-0.1.2.tar.gz.
File metadata
- Download URL: crn_jax-0.1.2.tar.gz
- Size: 11.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `f82b04a5aeff869395220a65b9acca4f9589c07fcc905e268b2ac355de790612` |
| MD5 | `79ccd32d9a3824027e156571cf7c6bcf` |
| BLAKE2b-256 | `7cb2c833fbf74d98441fb69c4592c4d6a1ce4233dfb963cd78d5819fb61755b3` |
Provenance
The following attestation bundles were made for crn_jax-0.1.2.tar.gz:
- Publisher: release.yml on robinhenry/crn-jax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: crn_jax-0.1.2.tar.gz
- Subject digest: f82b04a5aeff869395220a65b9acca4f9589c07fcc905e268b2ac355de790612
- Sigstore transparency entry: 1361707053
- Permalink: robinhenry/crn-jax@05a8a234047370070df0a56dffce9be40a467cb0
- Branch / Tag: refs/tags/v0.1.2
- Owner: https://github.com/robinhenry
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@05a8a234047370070df0a56dffce9be40a467cb0
- Trigger Event: push
File details
Details for the file crn_jax-0.1.2-py3-none-any.whl.
File metadata
- Download URL: crn_jax-0.1.2-py3-none-any.whl
- Size: 11.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `a4245cece09d5180cc1109d49e2678970c364ac46ef9d47aad9009949e8afd2d` |
| MD5 | `e49b51a971b45f12f20f22ed6d8494b8` |
| BLAKE2b-256 | `79568e3dfa331eec8e0010687353e1e7bec50e034ea9db946ade39e0edf46bb5` |
Provenance
The following attestation bundles were made for crn_jax-0.1.2-py3-none-any.whl:
- Publisher: release.yml on robinhenry/crn-jax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: crn_jax-0.1.2-py3-none-any.whl
- Subject digest: a4245cece09d5180cc1109d49e2678970c364ac46ef9d47aad9009949e8afd2d
- Sigstore transparency entry: 1361707058
- Permalink: robinhenry/crn-jax@05a8a234047370070df0a56dffce9be40a467cb0
- Branch / Tag: refs/tags/v0.1.2
- Owner: https://github.com/robinhenry
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@05a8a234047370070df0a56dffce9be40a467cb0
- Trigger Event: push