JAX-based integrators for ODEs, DDEs, SDEs, and SDDEs
vbjax_dynamics
A JAX-based library for numerical integration of dynamical systems.
Note: This package contains code adapted from vbjax by INS-AMU. The core integration functions in loops.py are derived from vbjax's implementation.
Features
- ODE Integration: Ordinary Differential Equations
  - Efficient loop-based integrators with JIT compilation
  - Full jax.vmap support for computing trajectories in parallel
- SDE Integration: Stochastic Differential Equations
  - make_sde(): integration with pre-generated noise arrays
  - make_sde_auto(): automatic noise generation from random keys
  - Euler-Maruyama scheme
  - Fully reproducible with random seeds
- DDE Integration: Delay Differential Equations
  - Support for fixed delays
  - History function interpolation
- SDDE Integration: Stochastic Delay Differential Equations
  - Combined stochastic and delay dynamics
- Continuation Methods: Parameter continuation for bifurcation analysis
- Configuration Utilities: Easy control over JAX settings
  - configure_jax(): global configuration
  - precision_context(): temporary precision changes
  - print_jax_config(): diagnostic information
- JAX-Native:
  - JIT compilation for speed
  - Automatic differentiation ready
  - GPU/TPU compatible
  - Pure functional approach
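The Euler-Maruyama scheme listed above advances dx = f(x, p) dt + g(x, p) dW one step at a time as x[n+1] = x[n] + f(x[n], p)·dt + g(x[n], p)·sqrt(dt)·z[n], where each z[n] is a standard normal sample. The sketch below is a minimal plain-JAX version driven by a pre-generated noise array, in the spirit of make_sde(); it is not the package's implementation, and the helper name `euler_maruyama` and its argument order are hypothetical. Drawing the noise from a fixed PRNG key makes the run exactly reproducible.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(drift, diffusion, x0, dt, noise, params):
    """Hypothetical helper (not the vbjax_dynamics API): integrate
    dx = drift(x, p) dt + diffusion(x, p) dW with Euler-Maruyama.

    `noise` holds pre-generated standard normal samples, one per step,
    with shape (n_steps,) + x0.shape.
    """
    sqrt_dt = jnp.sqrt(dt)
    x0 = jnp.asarray(x0, dtype=noise.dtype)  # fix the carry dtype for lax.scan

    def step(x, z):
        # x[n+1] = x[n] + f(x[n]) dt + g(x[n]) sqrt(dt) z[n]
        x_next = x + drift(x, params) * dt + diffusion(x, params) * sqrt_dt * z
        return x_next, x_next  # (new carry, value recorded for this step)

    _, xs = jax.lax.scan(step, x0, noise)
    return xs  # state after each step


# Ornstein-Uhlenbeck process, same model as the Quick Start
def drift(x, p):
    return -p[0] * x  # -theta * x


def diffusion(x, p):
    return p[1]  # sigma


dt = 0.01
params = (1.0, 0.5)  # (theta, sigma)
noise = jax.random.normal(jax.random.PRNGKey(42), (1000,))

run = jax.jit(lambda x0, z: euler_maruyama(drift, diffusion, x0, dt, z, params))
xs = run(2.0, noise)
```

Using `jax.lax.scan` keeps the whole time loop inside one compiled program instead of unrolling it in Python, and separating noise generation from integration lets you replay the exact same noise realization (for example, to compare parameter values under identical forcing).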
Installation
```bash
pip install vbjax_dynamics
```

For development (editable install with the dev extras, run from the repository root):

```bash
pip install -e ".[dev]"
```
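JAX computes in 32-bit floats by default. The package's configure_jax() and precision_context() helpers control this, but their arguments are not documented here, so the sketch below uses JAX's own `jax_enable_x64` flag directly to switch to double precision. Set it at startup, before creating any arrays.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats globally; do this before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64
```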
Quick Start
```python
import jax.numpy as jnp
from jax import random, vmap
from vbjax_dynamics.loops import make_sde_auto

# Define an Ornstein-Uhlenbeck process: dx = -theta * x dt + sigma dW
def drift(x, p):
    return -p[0] * x  # -theta * x

def diffusion(x, p):
    return p[1]  # sigma

# Create integrator
dt = 0.01
step, loop = make_sde_auto(dt, drift, diffusion)

# Single trajectory
x0 = 2.0
params = (1.0, 0.5)  # (theta, sigma)
n_steps = 1000
key = random.PRNGKey(42)
trajectory = loop(x0, n_steps, params, key)
print(f"Final value: {trajectory[-1]:.4f}")

# Multiple trajectories in parallel with vmap (one PRNG key per trajectory)
n_traj = 100
keys = random.split(key, n_traj)
trajectories = vmap(lambda k: loop(x0, n_steps, params, k))(keys)
print(f"Mean: {jnp.mean(trajectories[:, -1]):.4f}")
```
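The Quick Start output can be sanity-checked against the closed form of the Ornstein-Uhlenbeck process. For dx = -θx dt + σ dW started at x0, E[x(t)] = x0·exp(-θt) and Var[x(t)] = σ²/(2θ)·(1 - exp(-2θt)). The snippet below assumes the simulated horizon is T = n_steps·dt = 10 (exactly how `loop` indexes its output is not specified here) and uses only the standard library.

```python
import math


def ou_moments(x0, theta, sigma, t):
    """Exact mean and variance at time t of the Ornstein-Uhlenbeck process
    dx = -theta * x dt + sigma dW with deterministic start x(0) = x0."""
    mean = x0 * math.exp(-theta * t)
    var = sigma**2 / (2.0 * theta) * (1.0 - math.exp(-2.0 * theta * t))
    return mean, var


# Quick Start values: x0 = 2.0, (theta, sigma) = (1.0, 0.5), T = n_steps * dt
mean, var = ou_moments(2.0, 1.0, 0.5, 1000 * 0.01)
std_err = math.sqrt(var / 100)  # standard error of a mean over n_traj = 100 paths
print(f"expected mean {mean:.2e}, variance {var:.4f}, std. error {std_err:.4f}")
# prints: expected mean 9.08e-05, variance 0.1250, std. error 0.0354
```

By T = 10 the mean has decayed to essentially zero and the variance has reached its stationary value σ²/(2θ) = 0.125, so the printed sample mean over 100 trajectories should usually land within about ±0.07 (two standard errors) of zero. If the integrator is plain Euler-Maruyama with dt = 0.01, its stationary variance is σ²/(θ(2 - θ·dt)) ≈ 0.1256, slightly above the exact 0.125.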
For more examples, see the examples/ directory.
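For the fixed-delay DDE support listed under Features, the core idea is to keep past states in a buffer and read the state one delay back at each step. Below is a minimal plain-JAX explicit-Euler sketch, not the package's DDE API: the helper `euler_dde`, its arguments, and the restriction that the delay τ is a whole number of steps (so no history interpolation is needed) are all assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def euler_dde(f, history, tau, dt, n_steps, params):
    """Hypothetical helper (not the vbjax_dynamics API): explicit Euler for
    x'(t) = f(x(t), x(t - tau), params) with one fixed delay tau.

    Assumes tau is a whole number of steps, so x(t - tau) is read straight
    from the stored grid and no history interpolation is needed.
    """
    n_delay = int(round(tau / dt))
    # Slots for t = -tau, ..., -dt, 0 hold the history; later slots are filled in.
    t_hist = jnp.arange(-n_delay, 1) * dt
    buf = jnp.zeros(n_delay + 1 + n_steps).at[: n_delay + 1].set(history(t_hist))

    def body(i, buf):
        x_now = buf[i]                # x(t)
        x_delayed = buf[i - n_delay]  # x(t - tau)
        return buf.at[i + 1].set(x_now + dt * f(x_now, x_delayed, params))

    buf = jax.lax.fori_loop(n_delay, n_delay + n_steps, body, buf)
    return buf[n_delay:]  # x on t = 0, dt, ..., n_steps * dt


def delayed_decay(x, x_delayed, k):
    # x'(t) = -k * x(t - tau)
    return -k * x_delayed


# Constant history x(t) = 1 for t <= 0
xs = euler_dde(delayed_decay, jnp.ones_like, tau=1.0, dt=0.01, n_steps=500, params=1.0)
```

With constant history x = 1, the right-hand side equals -1 throughout 0 ≤ t ≤ τ, so x(t) = 1 - t on that interval and `xs[100]` (t = 1) is close to 0; after that the delayed feedback drives the solution.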
Documentation
- Examples and Tutorials: Complete guide with detailed examples for ODE, SDE, DDE, and SDDE integration
- Testing Guide: Information about running and writing tests
- Acknowledgments: Credits and attribution
Acknowledgments
This package includes code adapted from vbjax, developed by the Institut de Neurosciences de la Timone (INS-AMU). We are grateful for their work on efficient JAX-based numerical integrators.
License
MIT License
Download files
File details
Details for the file vbjax_dynamics-0.1.0.tar.gz.
File metadata
- Download URL: vbjax_dynamics-0.1.0.tar.gz
- Upload date:
- Size: 16.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fcabb033b2fb21424c0142179e42e0977da46dc5bf7599c2f2cba92434eb25cf |
| MD5 | 2ed42ab20fb26963a1f57e6c790aef87 |
| BLAKE2b-256 | 8f603eef31862736abcd9b1f56ee2a44be60c5c26653d15c7b1c3186c4282054 |
File details
Details for the file vbjax_dynamics-0.1.0-py3-none-any.whl.
File metadata
- Download URL: vbjax_dynamics-0.1.0-py3-none-any.whl
- Upload date:
- Size: 10.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 628731262ea5b74e18bda1af686796380e81ce35e85aee37e6c7df26a14ebe91 |
| MD5 | dbc882d5a3e5187c20a0bc97b5b7549e |
| BLAKE2b-256 | e3845ae40dd1b01d519725136cd810e1072695663e1af591403f73b64bb3f158 |