JAX-based integrators for ODEs, DDEs, SDEs, and SDDEs

vbjax_dynamics

A JAX-based library for numerical integration of dynamical systems.

Note: This package contains code adapted from vbjax by INS-AMU. The core integration functions in loops.py are derived from vbjax's implementation.

Features

  • ODE Integration: Ordinary Differential Equations

    • Efficient loop-based integrators with JIT compilation
    • Full support for jax.vmap for parallel trajectory computation
  • SDE Integration: Stochastic Differential Equations

    • make_sde(): Integration with pre-generated noise arrays
    • make_sde_auto(): Automatic noise generation from random keys
    • Euler-Maruyama scheme
    • Fully reproducible with random seeds
  • DDE Integration: Delay Differential Equations

    • Support for fixed delays
    • History function interpolation
  • SDDE Integration: Stochastic Delay Differential Equations

    • Combined stochastic and delay dynamics
  • Continuation Methods: Parameter continuation for bifurcation analysis

  • Configuration Utilities: Easy control over JAX settings

    • configure_jax(): Global configuration
    • precision_context(): Temporary precision changes
    • print_jax_config(): Diagnostic information
  • JAX-Native:

    • JIT compilation for speed
    • Automatic differentiation ready
    • GPU/TPU compatible
    • Pure functional approach
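To make the SDE bullets above concrete, here is a minimal sketch of an Euler-Maruyama loop written in plain JAX. It is not the code in loops.py: the helper names `em_loop` and `em_loop_from_key` are hypothetical, and the split between them only mirrors the idea of integrating with a pre-generated noise array versus drawing the noise from a random key. It also assumes the noise is standard normal and scaled by sqrt(dt).

```python
import jax
import jax.numpy as jnp


def em_loop(x0, noise, params, dt, drift, diffusion):
    """Euler-Maruyama over a pre-generated standard-normal noise array (hypothetical helper)."""
    x0 = jnp.asarray(x0, dtype=noise.dtype)  # keep the scan carry dtype fixed
    sqrt_dt = jnp.sqrt(dt)

    def step(x, z):
        # x_{n+1} = x_n + f(x_n) * dt + g(x_n) * sqrt(dt) * z_n
        x_next = x + drift(x, params) * dt + diffusion(x, params) * sqrt_dt * z
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, noise)
    return xs


def em_loop_from_key(x0, n_steps, params, key, dt, drift, diffusion):
    """Same scheme, but the noise is drawn from a PRNG key (hypothetical helper)."""
    noise = jax.random.normal(key, (n_steps,) + jnp.shape(x0))
    return em_loop(x0, noise, params, dt, drift, diffusion)


def drift(x, p):
    return -p[0] * x  # -theta * x


def diffusion(x, p):
    return p[1]  # sigma


key = jax.random.PRNGKey(0)
params = (1.0, 0.5)  # (theta, sigma)

# The same key reproduces the same trajectory.
xs_a = em_loop_from_key(2.0, 1000, params, key, 0.01, drift, diffusion)
xs_b = em_loop_from_key(2.0, 1000, params, key, 0.01, drift, diffusion)

# Independent trajectories in parallel: one key per trajectory.
keys = jax.random.split(key, 4)
batch = jax.vmap(
    lambda k: em_loop_from_key(2.0, 1000, params, k, 0.01, drift, diffusion)
)(keys)
```

Because the loop is a pure function of its inputs, reproducibility follows directly from reusing the key, and jax.vmap can batch over keys without any change to the loop itself.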

Installation

pip install vbjax_dynamics

For development, install in editable mode from a local clone of the repository:

pip install -e ".[dev]"

Quick Start

import jax.numpy as jnp
from jax import random, vmap
from vbjax_dynamics.loops import make_sde_auto

# Define Ornstein-Uhlenbeck process
def drift(x, p):
    return -p[0] * x  # -theta * x

def diffusion(x, p):
    return p[1]  # sigma

# Create integrator
dt = 0.01
step, loop = make_sde_auto(dt, drift, diffusion)

# Single trajectory
x0 = 2.0
params = (1.0, 0.5)  # (theta, sigma)
n_steps = 1000
key = random.PRNGKey(42)

trajectory = loop(x0, n_steps, params, key)
print(f"Final value: {trajectory[-1]:.4f}")

# Multiple trajectories in parallel with vmap
n_traj = 100
keys = random.split(key, n_traj)
trajectories = vmap(lambda k: loop(x0, n_steps, params, k))(keys)
print(f"Mean: {jnp.mean(trajectories[:, -1]):.4f}")
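As a rough check on the numbers printed above, the Ornstein-Uhlenbeck process has closed-form moments. The block below is a worked sketch under two assumptions not confirmed by this README: that the integrator applies the standard Euler-Maruyama update with unit-variance Gaussian noise scaled by the square root of dt, and that the final sample corresponds to t = n_steps * dt.

```latex
% Euler-Maruyama update for dx = -\theta x\,dt + \sigma\,dW, with \xi_n \sim \mathcal{N}(0, 1)
x_{n+1} = x_n - \theta x_n\,\Delta t + \sigma \sqrt{\Delta t}\,\xi_n

% Exact moments of the continuous-time process started at x_0
\mathbb{E}[x_t] = x_0 e^{-\theta t}, \qquad
\operatorname{Var}[x_t] = \frac{\sigma^2}{2\theta}\left(1 - e^{-2\theta t}\right)

% Quick Start values: \theta = 1, \sigma = 0.5, x_0 = 2, t = 1000 \times 0.01 = 10
\mathbb{E}[x_{10}] = 2 e^{-10} \approx 9.1 \times 10^{-5}, \qquad
\operatorname{Var}[x_{10}] \approx \frac{0.25}{2} = 0.125
```

Under those assumptions, the mean over 100 trajectories should land close to zero, with a standard error of about sqrt(0.125 / 100) ≈ 0.035. The discrete scheme has a small step-size bias, so exact agreement is not expected.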

For more examples, see the examples/ directory.
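The Quick Start covers only the SDE case. For the fixed-delay DDE feature, the underlying idea can be sketched in plain JAX as a sliding history buffer carried through jax.lax.scan. This is an illustration rather than the package's API: `dde_loop` is a hypothetical helper, it uses explicit Euler, and it sidesteps history interpolation by assuming the delay is an integer multiple of dt.

```python
import jax
import jax.numpy as jnp


def dde_loop(history, n_steps, dt, rhs):
    """Explicit Euler for x'(t) = rhs(x(t), x(t - tau)) with one fixed delay (hypothetical helper).

    `history` holds x on the grid t = -tau, ..., -dt, 0, so tau = (len(history) - 1) * dt.
    """

    def step(buf, _):
        x_now, x_delayed = buf[-1], buf[0]
        x_next = x_now + dt * rhs(x_now, x_delayed)
        # Slide the window: drop the oldest sample, append the newest.
        buf = jnp.concatenate([buf[1:], x_next[None]])
        return buf, x_next

    _, xs = jax.lax.scan(step, history, None, length=n_steps)
    return xs


dt, tau = 0.01, 1.0
n_delay = int(round(tau / dt))
history = jnp.ones(n_delay + 1)  # constant history: x(t) = 1 for t <= 0

# x'(t) = -x(t - tau); xs[k - 1] approximates x(k * dt)
xs = dde_loop(history, 300, dt, lambda x, x_tau: -x_tau)
```

With a constant history of 1, the delayed term stays at 1 on the interval from 0 to tau, so the solution there is the straight line x(t) = 1 - t, which explicit Euler follows up to floating-point rounding. That makes the first delay interval a convenient correctness check.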

Documentation

Acknowledgments

This package includes code adapted from vbjax, developed by the Institut de Neurosciences des Systèmes (INS, Aix-Marseille Université). We are grateful for their work on efficient JAX-based numerical integrators.

License

MIT License
