Chemical reaction networks in JAX: GPU-parallel stochastic simulations.

crn-jax

License: MIT

Chemical reaction networks in JAX — a tiny, GPU-optimized Gillespie / Stochastic Simulation Algorithm (SSA) library.

[Figure: Birth-death benchmark: crn-jax on GPU vs GillesPy2 (C++) on CPU.]
[Figure: Linear-cascade benchmark: crn-jax on GPU vs GillesPy2 (C++) on CPU.]

Wall-time to simulate 1,000,000 independent stochastic trajectories — each a full Gillespie run of the reaction network from t=0 to t=20, sampled at 200 time points (CPU vs RTX 5090 GPU).

Install

pip install crn-jax

# with NVIDIA GPU support:
pip install crn-jax "jax[cuda12]"

# with plotting helpers:
pip install "crn-jax[examples]"

# for local development (uses Poetry):
git clone https://github.com/robinhenry/crn-jax && cd crn-jax
poetry install                    # main deps + dev tools
poetry install --with gpu         # add jax[cuda12] on an NVIDIA host

crn-jax depends on jax / jaxlib only.

Key features

  • 🎯 Exact SSA — pure-JAX implementation of the Gillespie algorithm for chemical reaction networks.
  • JIT-compiled — the entire loop compiles under jax.jit.
  • 🚀 GPU speedup — 1M+ independent trajectories on a single GPU under jax.vmap, with no Python overhead.
  • ⏱️ Discretization-safe — pending reaction times are preserved across simulation-interval boundaries, so trajectories are physically correct under discrete observations (or fixed-interval stepping).
  • 🎛️ Control-input aware — propensities take an optional input argument that can vary per-interval and per-replicate, so each of N parallel trajectories can follow its own control schedule (useful for RL-style rollouts, closed-loop experiments with per-replicate inputs, …).
  • 🧩 Bring-your-own state — the loop operates on any PyTree (NamedTuple, Flax struct dataclass, Equinox module, …).
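
The exactness and discretization-safety points above can be illustrated without the library. The following is a minimal, self-contained sketch of the Gillespie direct method for a birth-death process in plain JAX. It is not crn-jax's implementation, and the names `wait`, `advance`, and `trajectory` are hypothetical. The pending reaction time is part of the scan carry, so an event scheduled past one interval boundary fires in a later interval instead of being redrawn.

```python
import jax
import jax.numpy as jnp

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # illustrative rates, steady-state mean 30


def wait(key, x):
    # Waiting time to the next event ~ Exponential(total propensity).
    return jax.random.exponential(key) / (BIRTH_RATE + DEATH_RATE * x)


def advance(state, t_end):
    """Fire every pending reaction scheduled at or before t_end."""

    def cond(s):
        return s[2] <= t_end  # s = (key, x, t_next)

    def body(s):
        key, x, t_next = s
        key, k_choice, k_wait = jax.random.split(key, 3)
        # Pick the reaction with probability proportional to its propensity.
        total = BIRTH_RATE + DEATH_RATE * x
        is_birth = jax.random.uniform(k_choice) * total < BIRTH_RATE
        x = x + jnp.where(is_birth, 1.0, -1.0)
        # Schedule the following event from the updated state.
        return key, x, t_next + wait(k_wait, x)

    state = jax.lax.while_loop(cond, body, state)
    return state, state[1]  # carry (incl. pending time), observed x at t_end


def trajectory(key, n_steps=200, timestep=1.0):
    key, k0 = jax.random.split(key)
    x0 = jnp.zeros(())
    state0 = (key, x0, wait(k0, x0))
    t_ends = timestep * jnp.arange(1, n_steps + 1, dtype=jnp.float32)
    _, xs = jax.lax.scan(advance, state0, t_ends)
    return xs


# One compiled call simulates all replicates in parallel.
keys = jax.random.split(jax.random.PRNGKey(0), 256)
xs = jax.jit(jax.vmap(trajectory))(keys)
```

Here `xs` has shape (256, 200): one row per replicate, one column per sampling time.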

Quickstart

A 1-species birth-death process, ∅ → X at rate λ and X → ∅ at rate μ·x, simulated for 100 independent replicates and plotted:

from typing import NamedTuple
import jax, jax.numpy as jnp
from crn_jax import simulate_trajectory, plot_trajectories

BIRTH_RATE, DEATH_RATE = 3.0, 0.1    # steady-state mean λ/μ = 30

# Define a state-holding object
class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array    # carried across intervals

# Return propensity equations as an array
# with an optional external input (unused here)
def propensities(s, _input):
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])

# Describe how the state changes when reaction `j` fires
def apply_reaction(s, j):
    return s._replace(x=s.x + jnp.where(j == 0, 1.0, -1.0))

# Initial state
state0 = State(jnp.array(0.0), jnp.array(0.0), jnp.array(jnp.inf))

@jax.jit
@jax.vmap
def run_one(key):
    return simulate_trajectory(
        key=key,
        initial_state=state0,
        timestep=1.0,
        n_steps=200,
        # Pass our 2 custom functions defined above
        compute_propensities_fn=propensities,
        apply_reaction_fn=apply_reaction,
    )

# Simulate 100 Gillespie trajectories
states = run_one(jax.random.split(jax.random.PRNGKey(0), 100))
times = jnp.arange(1, 201) * 1.0
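
As a sanity check on the output, the mean of this birth-death process is known in closed form: E[X(t)] = λ/μ + (x0 − λ/μ)·exp(−μ·t). The sketch below evaluates it on the same sampling grid using only the standard library; `expected_count` is a hypothetical helper, not part of crn-jax.

```python
import math

BIRTH_RATE, DEATH_RATE = 3.0, 0.1


def expected_count(t, x0=0.0):
    """Analytical mean E[X(t)] of the birth-death process started at x0."""
    steady = BIRTH_RATE / DEATH_RATE
    return steady + (x0 - steady) * math.exp(-DEATH_RATE * t)


# Means on the quickstart's sampling grid, t = 1, 2, ..., 200.
expected = [expected_count(float(t)) for t in range(1, 201)]
```

Assuming `states.x` has shape (100, 200), as the vmapped call suggests, `states.x.mean(axis=0)` should track `expected`. The stationary distribution is Poisson with mean λ/μ = 30, so with 100 replicates the sample mean fluctuates by roughly sqrt(30/100) ≈ 0.55 around it.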

See the examples folder for more detailed examples.

API

# Main entry point: scan n_steps fixed-length intervals, stack the per-step states.
from crn_jax import simulate_trajectory

# Finer control: one interval at a time (RL-style), or until an absolute time.
from crn_jax.gillespie import simulate_interval, simulate_until

# Plotting helper: step-plots a single trajectory or an (N, T) ensemble.
from crn_jax import plot_trajectories

# Optional kinetic-law helpers.
from crn_jax.kinetics import hill_function, sample_lognormal

| Function | When to reach for it |
| --- | --- |
| simulate_trajectory | You want a full trajectory on a fixed sampling grid. Start here. |
| simulate_interval | You're driving the system yourself, one step at a time (e.g. an RL rollout). |
| simulate_until | You need a custom state shape or a non-uniform time grid. Fully generic. |
| plot_trajectories | Quick look at the output. |
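
The propensity callback shown in the quickstart receives an external input as its second argument. Below is a sketch of an input-dependent propensity in which the birth rate is gated by a Hill function of the input. The signature of `crn_jax.kinetics.hill_function` is not shown in this README, so the example uses a hand-written `hill` instead; `K_HALF` and `HILL_N` are hypothetical parameters chosen for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

BIRTH_RATE, DEATH_RATE = 3.0, 0.1
K_HALF, HILL_N = 1.0, 2.0  # hypothetical half-saturation level and Hill exponent


class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array


def hill(u, k, n):
    """Activating Hill function u^n / (k^n + u^n), written out for this sketch."""
    return u**n / (k**n + u**n)


def propensities(s, u):
    # Birth is scaled by the input level u; death stays linear in x.
    return jnp.array([BIRTH_RATE * hill(u, K_HALF, HILL_N), DEATH_RATE * s.x])


s = State(jnp.array(0.0), jnp.array(10.0), jnp.array(jnp.inf))

# A single input level.
a = propensities(s, jnp.array(1.0))

# A different input per replicate: map over the input, share the state.
inputs = jnp.array([0.0, 1.0, 2.0])
a_batch = jax.vmap(propensities, in_axes=(None, 0))(s, inputs)
```

`a_batch` has one row of propensities per input level, which is the shape of computation a per-replicate control schedule would need.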

See Also

  • GillesPy2: C++-optimized Gillespie simulations on CPU.
  • myriad-jax: RL-style decision making fully in JAX, powered by crn-jax at its core.

Download files

Source Distribution

crn_jax-0.1.2.tar.gz (11.9 kB view details)

Uploaded Source

Built Distribution

crn_jax-0.1.2-py3-none-any.whl (11.5 kB view details)

Uploaded Python 3

File details

Details for the file crn_jax-0.1.2.tar.gz.

File metadata

  • Download URL: crn_jax-0.1.2.tar.gz
  • Upload date:
  • Size: 11.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for crn_jax-0.1.2.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | f82b04a5aeff869395220a65b9acca4f9589c07fcc905e268b2ac355de790612 |
| MD5 | 79ccd32d9a3824027e156571cf7c6bcf |
| BLAKE2b-256 | 7cb2c833fbf74d98441fb69c4592c4d6a1ce4233dfb963cd78d5819fb61755b3 |

Provenance

The following attestation bundles were made for crn_jax-0.1.2.tar.gz:

Publisher: release.yml on robinhenry/crn-jax

File details

Details for the file crn_jax-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: crn_jax-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 11.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for crn_jax-0.1.2-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | a4245cece09d5180cc1109d49e2678970c364ac46ef9d47aad9009949e8afd2d |
| MD5 | e49b51a971b45f12f20f22ed6d8494b8 |
| BLAKE2b-256 | 79568e3dfa331eec8e0010687353e1e7bec50e034ea9db946ade39e0edf46bb5 |

Provenance

The following attestation bundles were made for crn_jax-0.1.2-py3-none-any.whl:

Publisher: release.yml on robinhenry/crn-jax
