
Chemical reaction networks in JAX: GPU-parallel stochastic simulations.


crn-jax

License: MIT

Chemical reaction networks in JAX — a tiny, GPU-optimized Gillespie / Stochastic Simulation Algorithm (SSA) library.

Figure: birth-death and linear-cascade benchmarks, crn-jax on GPU vs GillesPy2 (C++) on CPU.

Wall-time to simulate 1,000,000 independent stochastic trajectories, each a full Gillespie run of the reaction network from t=0 to t=20 sampled at 200 time points (GillesPy2 on CPU vs crn-jax on an RTX 5090 GPU).
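
The benchmark harness itself is not shown here. If you want to measure wall-times like these yourself, a generic JAX timing pattern is sketched below (an assumption, not the project's benchmark script): call the jitted function once so compilation is excluded, and wait on the result with jax.block_until_ready, because JAX dispatches work asynchronously. The names best_wall_time and toy_workload are hypothetical; substitute your own jitted, vmapped simulation.

```python
import time

import jax


def best_wall_time(fn, *args, repeats=3):
    """Best-of-`repeats` wall-time in seconds for fn(*args), excluding compilation."""
    # The first call traces and compiles; run it once so compile time is not counted.
    jax.block_until_ready(fn(*args))
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # JAX dispatches asynchronously: wait for the device before stopping the clock.
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best


@jax.jit
def toy_workload(keys):
    # Hypothetical stand-in for a batched simulation: one exponential draw per replicate.
    return jax.vmap(jax.random.exponential)(keys)


keys = jax.random.split(jax.random.PRNGKey(0), 1_000)
elapsed = best_wall_time(toy_workload, keys)
print(f"best of 3: {elapsed * 1e3:.3f} ms")

# Output size matters too: one float32 field sampled at 200 points for 1,000,000 replicates.
bytes_per_field = 1_000_000 * 200 * 4  # 800,000,000 bytes, i.e. 800 MB per state field
```

The memory arithmetic is worth keeping in mind at this scale: every state field you record costs about 800 MB of device memory for 1,000,000 trajectories at 200 time points.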

Install

pip install crn-jax

# with NVIDIA GPU support:
pip install crn-jax "jax[cuda12]"

# with plotting helpers:
pip install "crn-jax[examples]"

# for local development (uses Poetry):
git clone https://github.com/robinhenry/crn-jax && cd crn-jax
poetry install                    # main deps + dev tools
poetry install --with gpu         # add jax[cuda12] on an NVIDIA host

crn-jax itself depends only on jax / jaxlib; the optional examples extra adds what the plotting helpers need.
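
To confirm that a CUDA install is actually being picked up, a quick check with plain JAX (not part of crn-jax) is:

```python
import jax

# Usually "gpu" when jax[cuda12] is installed on a host with a working NVIDIA driver, else "cpu".
backend = jax.default_backend()
devices = jax.devices()  # devices of the default backend
print(f"JAX backend: {backend}, devices: {devices}")
```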

Key features

  • 🎯 Exact SSA — pure-JAX implementation of the Gillespie algorithm for chemical reaction networks.
  • JIT-compiled — the entire loop compiles under jax.jit.
  • 🚀 GPU speedup — 1M+ independent trajectories on a single GPU under jax.vmap, with no Python overhead.
  • ⏱️ Discretization-safe — pending reaction times are preserved across simulation-interval boundaries, so trajectories are physically correct under discrete observations (or fixed-interval stepping).
  • 🎛️ Control-input aware — propensities take an optional input argument that can vary per-interval and per-replicate, so each of N parallel trajectories can follow its own control schedule (useful for RL-style rollouts, closed-loop experiments with per-replicate inputs, …).
  • 🧩 Bring-your-own state — the loop operates on any PyTree (NamedTuple, Flax struct dataclass, Equinox module, …).
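
To make the exact-SSA, jit/vmap and discretization-safe points concrete, here is a from-scratch sketch in plain JAX of interval-by-interval Gillespie stepping for the quickstart's birth-death model. It is not crn-jax's implementation: run_interval, simulate and waiting_time are hypothetical names, and the State layout simply mirrors the quickstart. The detail that matters is that a reaction scheduled after an interval boundary stays pending in next_reaction_time instead of being fired early at the boundary.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # λ and μ, as in the quickstart


class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array  # absolute time of the pending reaction; inf = none


def propensities(s):
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])


def waiting_time(key, s):
    # Exponential waiting time with rate a0 = sum of propensities (inf if a0 == 0).
    a0 = propensities(s).sum()
    return jnp.where(a0 > 0, jax.random.exponential(key) / a0, jnp.inf)


def run_interval(key, s, t_end):
    """Fire every reaction scheduled at or before t_end; keep the next one pending."""
    key, k0 = jax.random.split(key)
    # Draw a waiting time only if nothing is pending (e.g. on the very first interval).
    pending = jnp.where(jnp.isinf(s.next_reaction_time),
                        s.time + waiting_time(k0, s), s.next_reaction_time)
    s = s._replace(next_reaction_time=pending)

    def cond(carry):
        return carry[1].next_reaction_time <= t_end

    def body(carry):
        key, s = carry
        key, k_pick, k_wait = jax.random.split(key, 3)
        # Pick the reaction that fires, with probability a_j / a0.
        j = jax.random.categorical(k_pick, jnp.log(propensities(s)))
        s = s._replace(time=s.next_reaction_time,
                       x=s.x + jnp.where(j == 0, 1.0, -1.0))
        # Schedule the following reaction from the updated state.
        s = s._replace(next_reaction_time=s.time + waiting_time(k_wait, s))
        return key, s

    key, s = lax.while_loop(cond, body, (key, s))
    # Move the clock to the boundary but carry the pending reaction time forward.
    return key, s._replace(time=t_end)


def simulate(key, s0, dt, n_steps):
    def one_interval(carry, i):
        key, s = carry
        t_end = s0.time + (i + 1).astype(jnp.float32) * dt
        key, s = run_interval(key, s, t_end)
        return (key, s), s  # record the state at every interval boundary

    _, states = lax.scan(one_interval, (key, s0), jnp.arange(n_steps))
    return states


state0 = State(jnp.zeros((), jnp.float32), jnp.zeros((), jnp.float32),
               jnp.full((), jnp.inf, jnp.float32))
run = jax.jit(jax.vmap(lambda key: simulate(key, state0, 1.0, 200)))
states = run(jax.random.split(jax.random.PRNGKey(0), 100))
```

Each leaf of states has shape (100, 200): one row per replicate, one column per interval boundary. Because the whole loop is lax.while_loop inside lax.scan, the function traces once under jax.jit and vmap batches all replicates without Python-level iteration.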

Quickstart

A 1-species birth-death process, ∅ → X at rate λ and X → ∅ at rate μ·x, simulated for 100 independent replicates:

from typing import NamedTuple
import jax, jax.numpy as jnp
from crn_jax import simulate_trajectory, plot_trajectories

BIRTH_RATE, DEATH_RATE = 3.0, 0.1    # steady-state mean λ/μ = 30

# Define a state-holding object
class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array    # carried across intervals

# Return propensity equations as an array
# with an optional external input (unused here)
def propensities(s, _input):
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])

# Describe how the state changes when reaction `j` fires
def apply_reaction(s, j):
    return s._replace(x=s.x + jnp.where(j == 0, 1.0, -1.0))

# Initial state
state0 = State(jnp.array(0.0), jnp.array(0.0), jnp.array(jnp.inf))

@jax.jit
@jax.vmap
def run_one(key):
    return simulate_trajectory(
        key=key,
        initial_state=state0,
        timestep=1.0,
        n_steps=200,
        # Pass our 2 custom functions defined above
        compute_propensities_fn=propensities,
        apply_reaction_fn=apply_reaction,
    )

# Simulate 100 Gillespie trajectories
states = run_one(jax.random.split(jax.random.PRNGKey(0), 100))
times = jnp.arange(1, 201) * 1.0    # t = 1, 2, ..., 200: the end of each 1.0-long interval
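
As a sanity check on the quickstart output: for this birth-death process the mean obeys d⟨x⟩/dt = λ − μ⟨x⟩, so starting from x(0) = 0 it follows ⟨x⟩(t) = (λ/μ)(1 − e^(−μt)) and approaches λ/μ = 30. (Starting from zero, x(t) is in fact Poisson-distributed with this mean, so its variance equals its mean.) The helper below is a hypothetical addition, not part of crn-jax.

```python
import jax.numpy as jnp

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # λ and μ, as in the quickstart


def birth_death_mean(t, lam=BIRTH_RATE, mu=DEATH_RATE, x0=0.0):
    """Exact <x>(t) for ∅ -> X (rate lam), X -> ∅ (rate mu*x), starting from x0."""
    # Solves d<x>/dt = lam - mu * <x>; relaxes to the steady state lam / mu.
    return lam / mu + (x0 - lam / mu) * jnp.exp(-mu * t)


times = jnp.arange(1, 201) * 1.0         # same sampling grid as the quickstart
expected_mean = birth_death_mean(times)  # rises from about 2.85 at t=1 towards 30
```

Assuming states.x has shape (100, 200), one row per replicate, states.x.mean(axis=0) should track expected_mean to within roughly sqrt(30/100) ≈ 0.55 at late times, the standard error of a mean over 100 replicates with variance about 30.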

See the examples folder for more detailed examples.

API

# Main entry point: scan n_steps fixed-length intervals, stack the per-step states.
from crn_jax import simulate_trajectory

# Finer control: one interval at a time (RL-style), or until an absolute time.
from crn_jax.gillespie import simulate_interval, simulate_until

# Plotting helper: step-plots a single trajectory or an (N, T) ensemble.
from crn_jax import plot_trajectories

# Optional kinetic-law helpers.
from crn_jax.kinetics import hill_function, sample_lognormal

| function | when to reach for it |
|---|---|
| simulate_trajectory | You want a full trajectory on a fixed sampling grid. Start here. |
| simulate_interval | You're driving the system yourself, one step at a time (e.g. an RL rollout). |
| simulate_until | You need a custom state shape or a non-uniform time grid. Fully generic. |
| plot_trajectories | Quick look at the output. |
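
The table mentions RL-style stepping, and the feature list mentions per-replicate control inputs and a hill_function helper. The sketch below combines those ideas in plain JAX: production of X is switched on by an input u through an activating Hill function, each replicate gets its own input schedule, and lax.scan advances one interval at a time. It does not call crn-jax; hill, run_interval, rollout and every parameter value are hypothetical, and the real hill_function / simulate_interval signatures may differ. One deliberate simplification: because u can change at a boundary, this sketch redraws the waiting time at the start of every interval instead of carrying it over, which is still exact for a Markov jump process because exponential waiting times are memoryless.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical parameters: max production rate, degradation rate, Hill constants.
K_MAX, DEATH_RATE = 5.0, 0.1
HILL_K, HILL_N = 0.5, 2.0


class State(NamedTuple):
    time: jax.Array
    x: jax.Array


def hill(u, k=HILL_K, n=HILL_N):
    """Activating Hill function: 0 at u = 0, 1/2 at u = k, approaching 1 for large u."""
    return u**n / (k**n + u**n)


def propensities(s, u):
    # ∅ -> X at K_MAX * hill(u); X -> ∅ at DEATH_RATE * x.
    return jnp.array([K_MAX * hill(u), DEATH_RATE * s.x])


def run_interval(key, s, u, t_end):
    """Exact SSA from s.time to t_end with the input u held constant."""

    def draw(key, s):
        key, k_wait, k_pick = jax.random.split(key, 3)
        a = propensities(s, u)
        a0 = a.sum()
        tau = jnp.where(a0 > 0, jax.random.exponential(k_wait) / a0, jnp.inf)
        j = jax.random.categorical(k_pick, jnp.log(a))  # P(j) = a_j / a0
        return key, tau, j

    def cond(carry):
        _, s, tau, _ = carry
        return s.time + tau <= t_end

    def body(carry):
        key, s, tau, j = carry
        s = State(time=s.time + tau, x=s.x + jnp.where(j == 0, 1.0, -1.0))
        key, tau, j = draw(key, s)
        return key, s, tau, j

    # Fresh waiting time at the boundary: exact because waiting times are memoryless.
    key, tau, j = draw(key, s)
    key, s, _, _ = lax.while_loop(cond, body, (key, s, tau, j))
    return key, s._replace(time=t_end)


def rollout(key, inputs, dt=1.0):
    """Step through len(inputs) intervals, applying inputs[i] during interval i."""
    s0 = State(time=jnp.zeros((), jnp.float32), x=jnp.zeros((), jnp.float32))

    def step(carry, i_and_u):
        key, s = carry
        i, u = i_and_u
        t_end = (i + 1).astype(jnp.float32) * dt
        key, s = run_interval(key, s, u, t_end)
        return (key, s), s.x  # record x at the end of each interval

    _, xs = lax.scan(step, (key, s0), (jnp.arange(inputs.shape[0]), inputs))
    return xs


# Per-replicate schedules: input off for 50 intervals, then on at a replicate-specific level.
N, T = 8, 100
u_on = jnp.where(jnp.arange(T) >= 50, 1.0, 0.0)
levels = jnp.linspace(0.0, 2.0, N)           # replicate 0 never receives any input
schedules = levels[:, None] * u_on[None, :]  # shape (N, T)

run = jax.jit(jax.vmap(rollout))
xs = run(jax.random.split(jax.random.PRNGKey(0), N), schedules)  # shape (N, T)
```

In an RL loop you would instead call a jitted single-interval step from Python, choosing u from the latest observation; scanning over a precomputed schedule, as here, keeps the whole rollout on the device.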

See Also

  • GillesPy2: C++-optimized Gillespie simulations on CPU.
  • myriad-jax: RL-style decision making fully in JAX, powered by crn-jax at its core.



Download files


Source Distribution

crn_jax-0.1.tar.gz (11.9 kB)

Built Distribution

crn_jax-0.1-py3-none-any.whl (11.5 kB)

File details

Details for the file crn_jax-0.1.tar.gz.

File metadata

  • Download URL: crn_jax-0.1.tar.gz
  • Size: 11.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for crn_jax-0.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 73396b875cf90a576385d0fb934bf346f9ba8801347fe9dc059d9871f7b8d9c5 |
| MD5 | e935cd30fed0716a664218287cd9795f |
| BLAKE2b-256 | 12abcd1835d5fdfe4331c3be16a3a0b77ab2c8818d3070085d5640a3ed565c73 |


Provenance

The following attestation bundles were made for crn_jax-0.1.tar.gz:

Publisher: release.yml on robinhenry/crn-jax


File details

Details for the file crn_jax-0.1-py3-none-any.whl.

File metadata

  • Download URL: crn_jax-0.1-py3-none-any.whl
  • Size: 11.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for crn_jax-0.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e5dfc6ff38a5a93ce982a6805c3f8fc26b4b9c69b972d9e1a5ccfb0289612e1a |
| MD5 | 6ee8e4e0a428f4d7ce6d328a5fa61d42 |
| BLAKE2b-256 | e7572d709838cf567b158e206918d08ec12f405ac72687f6153fcd6844dc6ef6 |


Provenance

The following attestation bundles were made for crn_jax-0.1-py3-none-any.whl:

Publisher: release.yml on robinhenry/crn-jax

