GPU+autodiff-capable ODE/SDE/CDE solvers written in JAX.

Diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax is a JAX-based library providing numerical differential equation solvers.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
  • vmappable everything (including the region of integration);
  • using a PyTree as the state;
  • dense solutions;
  • multiple adjoint methods for backpropagation;
  • support for neural differential equations.

From a technical point of view, the internal structure of the library is pretty cool: all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way, rather than being treated separately, which produces a small, tightly written library.

Installation

pip install diffrax

Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.10.11+.

Documentation

Available at https://docs.kidger.site/diffrax.

Quick example

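from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

# Vector field of the ODE dy/dt = -y.
def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
# Integrate from t0=0 to t1=1 with a step size of 0.1 (constant by default).
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
# By default, solution.ys holds only the final state y(t1), with shape (1, 2).
print(solution.ys)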

Here, Dopri5 refers to the Dormand–Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.

Citation

If you found this library useful in academic research, please cite: (arXiv link)

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

(Also consider starring the project on GitHub.)

See also: other libraries in the JAX ecosystem

jaxtyping: type annotations for shape/dtype of arrays.

Equinox: neural networks.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Optimistix: root finding, minimisation, fixed points, and least squares.

Lineax: linear solvers.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Download files

Source Distribution

diffrax-0.5.0.tar.gz (108.4 kB)

Uploaded Source

Built Distribution

diffrax-0.5.0-py3-none-any.whl (141.7 kB)

Uploaded Python 3

File details

Details for the file diffrax-0.5.0.tar.gz.

File metadata

  • Download URL: diffrax-0.5.0.tar.gz
  • Size: 108.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for diffrax-0.5.0.tar.gz
Algorithm Hash digest
SHA256 2e66701b545798818af34a9132887b354c93b5b47e133c9986e23a7cf600a937
MD5 cda08c71c1f74c7240271ee6901632f0
BLAKE2b-256 ac5029928c72f2e82b7580249e65a619b41e1f7495f7503254d4ce6b4897effb

File details

Details for the file diffrax-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: diffrax-0.5.0-py3-none-any.whl
  • Size: 141.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for diffrax-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 58b526490fae982fdc9ca2a1138d5dc20300060618fe0060b017d95ce9c84036
MD5 920fd4b94a98b3eb99abcf2c542b543c
BLAKE2b-256 d88094c7a22832246f4e8b8ad2dba701b16d7aa97881cef8f6c05ffd97bc0c63
