Efficient differentiable PDE solvers in JAX.
Efficient Differentiable n-d PDE solvers built on top of JAX & Equinox.
Installation • Quickstart • Equations • Features • Documentation • Background • Citation
Exponax solves partial differential equations in 1D, 2D, and 3D on periodic
domains efficiently, using Fourier spectral methods and exponential time
differencing. It ships more than 46 PDE solvers covering linear, nonlinear, and
reaction-diffusion dynamics. Built entirely on JAX and Equinox, every solver is
automatically differentiable, JIT-compilable, and GPU/TPU-ready, which makes it
well suited for physics-based deep learning workflows.
Installation
pip install exponax
Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
Quickstart
Simulate the chaotic Kuramoto-Sivashinsky equation in 1D — a single stepper object, one line to roll out 500 time steps:
import jax
import exponax as ex
import matplotlib.pyplot as plt
ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
num_spatial_dims=1, domain_extent=100.0,
num_points=200, dt=0.1,
)
u_0 = ex.ic.RandomTruncatedFourierSeries(
num_spatial_dims=1, cutoff=5
)(num_points=200, key=jax.random.PRNGKey(0))
trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0)
plt.imshow(trajectory[:, 0, :].T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2, origin="lower")
plt.xlabel("Time"); plt.ylabel("Space"); plt.show()
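Conceptually, rolling out a stepper means applying it autoregressively: each output becomes the input of the next step, and the states are stacked along a leading time axis. The sketch below expresses that pattern with `jax.lax.scan`. `rollout_sketch` and the toy `decay_step` are hypothetical stand-ins for illustration, not Exponax's implementation of `ex.rollout`.

```python
import jax
import jax.numpy as jnp


def rollout_sketch(stepper, n_steps, include_init=False):
    """Hypothetical helper: returns a function mapping u_0 to a stacked trajectory."""

    def scan_fn(u, _):
        u_next = stepper(u)
        return u_next, u_next  # carry the new state forward and also record it

    def run(u_0):
        _, traj = jax.lax.scan(scan_fn, u_0, None, length=n_steps)
        if include_init:
            traj = jnp.concatenate([u_0[None], traj], axis=0)
        return traj

    return run


# Toy stand-in stepper (not a PDE solver): damps the state by 1% per step.
def decay_step(u):
    return 0.99 * u


u_0 = jnp.ones((1, 200))  # (channels, points), the same layout as in the quickstart
trajectory = rollout_sketch(decay_step, 500, include_init=True)(u_0)
print(trajectory.shape)  # (501, 1, 200)
```

Using `jax.lax.scan` rather than a Python loop keeps the whole rollout a single traceable JAX computation, so it can itself be jitted or differentiated.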
Because every stepper is a differentiable JAX function, you can freely compose
it with JAX transformations such as jax.grad, jax.jacfwd, jax.vmap, and jax.jit:
# Jacobian of the stepper function
jacobian = jax.jacfwd(ks_stepper)(u_0)
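The same transformations apply to any pure JAX function of an array. To show this without relying on Exponax internals, the sketch below uses a hypothetical toy stepper written in plain JAX (one exact Fourier-space step of 1D diffusion, $u_t = \nu u_{xx}$) and composes it with `jax.jit`, `jax.vmap`, and `jax.grad`, including a gradient with respect to the diffusivity $\nu$. The names `diffusion_step` and `loss`, and the grid parameters, are assumptions chosen to mirror the quickstart.

```python
import jax
import jax.numpy as jnp

DOMAIN_EXTENT = 100.0  # assumed grid setup, mirroring the quickstart
NUM_POINTS = 200
DT = 0.1


def diffusion_step(u, nu):
    """Hypothetical toy stepper: one exact step of u_t = nu * u_xx on a periodic 1D grid."""
    k = 2 * jnp.pi * jnp.fft.rfftfreq(NUM_POINTS, d=DOMAIN_EXTENT / NUM_POINTS)
    u_hat = jnp.fft.rfft(u)
    # The diffusion symbol is -nu * k^2, so each mode decays by exp(-nu * k^2 * dt).
    return jnp.fft.irfft(jnp.exp(-nu * k**2 * DT) * u_hat, n=NUM_POINTS)


step = jax.jit(diffusion_step)

x = jnp.linspace(0.0, DOMAIN_EXTENT, NUM_POINTS, endpoint=False)
u_0 = jnp.sin(2 * jnp.pi * x / DOMAIN_EXTENT)

# jax.vmap: advance a batch of 8 scaled copies of u_0 in one call, sharing nu
batch = jnp.linspace(0.5, 2.0, 8)[:, None] * u_0[None, :]
batch_next = jax.vmap(step, in_axes=(0, None))(batch, 0.5)


def loss(u, nu):
    # Scalar objective evaluated on the state after one step
    return jnp.sum(step(u, nu) ** 2)


# jax.grad: sensitivity w.r.t. the initial condition and w.r.t. the PDE parameter
grad_u0 = jax.grad(loss, argnums=0)(u_0, 0.5)
grad_nu = jax.grad(loss, argnums=1)(u_0, 0.5)
```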
For a next step, check out the 1D Advection tutorial, which explains the basics of Exponax.
Built-in Equations
Linear
| Equation | Stepper | Dimensions |
|---|---|---|
| Advection: $u_t + c \cdot \nabla u = 0$ | Advection | 1D, 2D, 3D |
| Diffusion: $u_t = \nu \Delta u$ | Diffusion | 1D, 2D, 3D |
| Advection-Diffusion: $u_t + c \cdot \nabla u = \nu \Delta u$ | AdvectionDiffusion | 1D, 2D, 3D |
| Dispersion: $u_t = \xi \nabla^3 u$ | Dispersion | 1D, 2D, 3D |
| Hyper-Diffusion: $u_t = -\zeta \Delta^2 u$ | HyperDiffusion | 1D, 2D, 3D |
| Wave: $u_{tt} = c^2 \Delta u$ | Wave | 1D, 2D, 3D |
Nonlinear
| Equation | Stepper | Dimensions |
|---|---|---|
| Burgers: $u_t + \frac{1}{2} \nabla \cdot (u \otimes u) = \nu \Delta u$ | Burgers | 1D, 2D, 3D |
| Korteweg-de Vries: $u_t + \frac{1}{2} \nabla \cdot (u \otimes u) - \nabla^3 u = \mu \Delta u$ | KortewegDeVries | 1D, 2D, 3D |
| Kuramoto-Sivashinsky: $u_t + \frac{1}{2} \lVert \nabla u \rVert^2 + \Delta u + \Delta^2 u = 0$ | KuramotoSivashinsky | 1D, 2D, 3D |
| KS (conservative): $u_t + \frac{1}{2} \nabla \cdot (u \otimes u) + \Delta u + \Delta^2 u = 0$ | KuramotoSivashinskyConservative | 1D, 2D, 3D |
| Navier-Stokes (vorticity): $\omega_t + (u \cdot \nabla)\omega = \nu \Delta \omega$ | NavierStokesVorticity | 2D |
| Kolmogorov Flow (vorticity): $\omega_t + (u \cdot \nabla)\omega = \nu \Delta \omega + f$ | KolmogorovFlowVorticity | 2D |
| Navier-Stokes (velocity): $u_t = \nu \Delta u + \mathcal{P}(u \times \omega)$ | NavierStokesVelocity | 3D |
| Kolmogorov Flow (velocity): $u_t = \nu \Delta u + \mathcal{P}(u \times \omega) + f$ | KolmogorovFlowVelocity | 3D |
Reaction-Diffusion
| Equation | Stepper | Dimensions |
|---|---|---|
| Fisher-KPP: $u_t = \nu \Delta u + r u (1 - u)$ | reaction.FisherKPP | 1D, 2D, 3D |
| Allen-Cahn: $u_t = \nu \Delta u + c_1 u + c_3 u^3$ | reaction.AllenCahn | 1D, 2D, 3D |
| Cahn-Hilliard: $u_t = \nu \Delta(u^3 + c_1 u - \gamma \Delta u)$ | reaction.CahnHilliard | 1D, 2D, 3D |
| Gray-Scott: $u_t = \nu_1 \Delta u + f(1-u) - uv^2, \quad v_t = \nu_2 \Delta v - (f+k)v + uv^2$ | reaction.GrayScott | 1D, 2D, 3D |
| Swift-Hohenberg: $u_t = ru - (k + \Delta)^2 u + g(u)$ | reaction.SwiftHohenberg | 1D, 2D, 3D |
Generic stepper families (for advanced / custom dynamics)
These parametric families generalize the concrete steppers above. Each comes in three flavors: physical coefficients, normalized, and difficulty-based.
| Family | Nonlinearity | Generalizes |
|---|---|---|
| GeneralLinearStepper | None | Advection, Diffusion, Dispersion, etc. |
| GeneralConvectionStepper | Quadratic convection | Burgers, KdV, KS Conservative |
| GeneralGradientNormStepper | Gradient norm | Kuramoto-Sivashinsky |
| GeneralVorticityConvectionStepper | Vorticity convection (2D only) | Navier-Stokes, Kolmogorov Flow |
| GeneralPolynomialStepper | Arbitrary polynomial | Fisher-KPP, Allen-Cahn, etc. |
| GeneralNonlinearStepper | Convection + gradient norm + polynomial | Most of the above |
See the normalized & difficulty interface docs for details.
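To make the idea of a generic linear family concrete: in 1D, a linear PDE $u_t = \sum_j a_j \partial_x^j u$ has the Fourier symbol $\hat{L}(k) = \sum_j a_j (ik)^j$, so one exact time step multiplies every Fourier mode by $\exp(\hat{L}(k)\,\Delta t)$. Advection ($a_1 = -c$), diffusion ($a_2 = \nu$), dispersion ($a_3 = \xi$), and hyper-diffusion ($a_4 = -\zeta$) then differ only in their coefficient lists. The helper `general_linear_step` below is a hypothetical 1D sketch of this principle; it does not reproduce the signature or coefficient conventions of `GeneralLinearStepper`.

```python
import jax.numpy as jnp


def general_linear_step(u, coefficients, domain_extent, dt):
    """Hypothetical sketch: one exact step of u_t = sum_j a_j * d^j u / dx^j (1D, periodic).

    coefficients[j] is a_j, the coefficient of the j-th spatial derivative.
    """
    n = u.shape[-1]
    k = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=domain_extent / n)
    # Fourier symbol of the linear operator, using d/dx -> i*k
    symbol = sum(a * (1j * k) ** j for j, a in enumerate(coefficients))
    return jnp.fft.irfft(jnp.exp(symbol * dt) * jnp.fft.rfft(u), n=n)


L, N = 2 * jnp.pi, 64
x = jnp.linspace(0.0, L, N, endpoint=False)
u = jnp.exp(jnp.sin(x))  # smooth periodic test profile

# Advection u_t + c u_x = 0  ->  a_1 = -c; c * dt is chosen as 4 grid spacings
c = 1.0
u_adv = general_linear_step(u, [0.0, -c], L, dt=4 * L / N)

# Diffusion u_t = nu u_xx  ->  a_2 = nu
u_diff = general_linear_step(u, [0.0, 0.0, 0.1], L, dt=0.5)

# Hyper-diffusion u_t = -zeta u_xxxx  ->  a_4 = -zeta
u_hyp = general_linear_step(u, [0.0, 0.0, 0.0, 0.0, -0.01], L, dt=0.5)
```

With $c\,\Delta t$ equal to four grid spacings, the advected profile matches `jnp.roll(u, 4)` up to floating-point error, a quick sanity check that the exponential of the symbol transports the signal exactly.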
Features
- Hardware-agnostic — run on CPU, GPU, or TPU in single or double precision.
- Fully differentiable — compute gradients of solutions w.r.t. initial conditions, PDE parameters, or neural network weights when composed with PDE solvers via jax.grad.
- Vectorized batching — advance multiple states or sweep over parameter grids in parallel using jax.vmap (and eqx.filter_vmap).
- Deep-learning native — every stepper is an Equinox Module, composable with neural networks out of the box.
- Lightweight design — no custom grid or state objects; everything is plain jax.numpy arrays and callable PyTrees.
- Initial conditions — library of random IC distributions (truncated Fourier series, Gaussian random fields, etc.).
- Utilities — spectral derivatives, grid creation, autoregressive rollout, interpolation, and more.
- Extensible — add new PDEs by subclassing BaseStepper.
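A random truncated Fourier series, as named in the initial-conditions feature, can be sketched as random complex amplitudes on the lowest few wavenumbers followed by an inverse FFT. `random_truncated_fourier_series` is a hypothetical helper; `ex.ic.RandomTruncatedFourierSeries` may sample amplitudes, phases, offsets, and normalization differently. Reusing the same PRNG key reproduces the same initial condition on the same setup, which is what allows datasets to be re-created from a seed.

```python
import jax
import jax.numpy as jnp


def random_truncated_fourier_series(key, num_points, cutoff):
    """Hypothetical sketch: random real periodic 1D signal using wavenumbers 1..cutoff only."""
    key_re, key_im = jax.random.split(key)
    real = jax.random.normal(key_re, (cutoff,))
    imag = jax.random.normal(key_im, (cutoff,))
    coeffs = jnp.zeros(num_points // 2 + 1, dtype=jnp.complex64)
    coeffs = coeffs.at[1 : cutoff + 1].set(real + 1j * imag)  # mode 0 (the mean) stays zero
    u = jnp.fft.irfft(coeffs, n=num_points)
    return u / jnp.max(jnp.abs(u))  # normalize to a maximum absolute value of 1


key = jax.random.PRNGKey(0)
# Add a leading channel axis to match the (channels, points) state layout
u_0 = random_truncated_fourier_series(key, num_points=200, cutoff=5)[None, :]
u_0_again = random_truncated_fourier_series(key, num_points=200, cutoff=5)[None, :]
```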
Documentation
Documentation is available at fkoehler.site/exponax. Key pages:
- 1D Advection Tutorial — learn the basics
- Solver Showcase 1D / 2D / 3D — visual gallery of all dynamics
- Creating Your Own Solvers — extend Exponax with custom PDEs
- Training a Neural Operator — use Exponax for synthetic data generation and training of a neural emulator
- Stepper Overview — API reference for all steppers
- Performance Hints — tips for fast simulations
Background
Exponax solves semi-linear PDEs of the form
$$ \partial u / \partial t = Lu + N(u), $$
where $L$ is a linear differential operator and $N$ is a nonlinear differential operator. Because $L$ is diagonal in Fourier space, the linear part is solved exactly by exponentiating its Fourier symbol (the matrix exponential reduces to an elementwise exponential), while the nonlinear part is integrated using exponential time differencing Runge-Kutta (ETDRK) schemes of order 1 through 4. The complex contour integral method of Kassam & Trefethen is used to evaluate the ETDRK coefficients in a numerically stable way.
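For the first-order member of this family (ETDRK1, also called exponential Euler), one step in Fourier space reads $\hat{u}_{n+1} = e^{\hat{L}\Delta t}\hat{u}_n + \frac{e^{\hat{L}\Delta t} - 1}{\hat{L}}\hat{N}(u_n)$. Evaluating the second factor naively loses accuracy when $\hat{L}\Delta t$ is close to zero, which is the problem the contour integral method addresses. The sketch below applies ETDRK1 to the 1D conservative Kuramoto-Sivashinsky equation and evaluates that factor as a mean over points on a circle of radius 1 in the complex plane. It is a minimal sketch under stated assumptions (32 contour points, no dealiasing, hypothetical helper names `etd1_coefficients` and `make_ks_etd1_step`), not Exponax's implementation.

```python
import jax
import jax.numpy as jnp


def etd1_coefficients(lin_symbol, dt, num_circle_points=32, radius=1.0):
    """Return exp(L dt) and (exp(L dt) - 1) / L for a *real* Fourier symbol L.

    The second factor equals dt * phi1(L dt) with phi1(z) = (e^z - 1) / z. Instead of
    evaluating phi1 directly (0/0 at z = 0), average it over a circle around each z.
    """
    z = lin_symbol * dt
    roots = jnp.exp(2j * jnp.pi * (jnp.arange(num_circle_points) + 0.5) / num_circle_points)
    z_circle = z[:, None] + radius * roots[None, :]
    phi1 = jnp.mean((jnp.exp(z_circle) - 1.0) / z_circle, axis=-1)
    return jnp.exp(z), dt * phi1.real  # imaginary parts cancel for real z


def make_ks_etd1_step(domain_extent=100.0, num_points=200, dt=0.1):
    """Hypothetical ETDRK1 stepper for 1D conservative KS: u_t = -u_xx - u_xxxx - 0.5 (u^2)_x."""
    k = 2 * jnp.pi * jnp.fft.rfftfreq(num_points, d=domain_extent / num_points)
    lin = k**2 - k**4  # Fourier symbol of -d^2/dx^2 - d^4/dx^4
    exp_lin, coef = etd1_coefficients(lin, dt)

    @jax.jit
    def step(u):
        u_hat = jnp.fft.rfft(u)
        # Nonlinear term -1/2 (u^2)_x, evaluated pseudo-spectrally (no dealiasing here)
        nonlin_hat = -0.5j * k * jnp.fft.rfft(u**2)
        return jnp.fft.irfft(exp_lin * u_hat + coef * nonlin_hat, n=num_points)

    return step


ks_step = make_ks_etd1_step()
x = jnp.linspace(0.0, 100.0, 200, endpoint=False)
u = 0.1 * jnp.cos(2 * jnp.pi * 3 * x / 100.0)
for _ in range(200):
    u = ks_step(u)
```

Because the $k = 0$ mode has a zero symbol and a zero nonlinear contribution, this scheme preserves the spatial mean, matching the conservative form of the equation.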
By restricting to periodic domains on scaled hypercubes with uniform Cartesian grids, all transforms reduce to FFTs, which makes simulations very fast. For example, 50 trajectories of the 2D Kuramoto-Sivashinsky equation (200 time steps, 128x128 grid) are generated in under a second on a modern GPU.
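On a periodic uniform grid, differentiation is diagonal in Fourier space: $\partial_x$ becomes multiplication by $ik$. Every linear operator therefore reduces to an FFT, an elementwise product, and an inverse FFT. The helper `spectral_derivative` below is a hypothetical 1D sketch of this idea; Exponax's own derivative utilities may have a different interface.

```python
import jax.numpy as jnp


def spectral_derivative(u, domain_extent, order=1):
    """Hypothetical helper: d^order u / dx^order of a real periodic 1D signal via FFT."""
    n = u.shape[-1]
    k = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=domain_extent / n)
    multiplier = (1j * k) ** order
    if order % 2 == 1 and n % 2 == 0:
        # Zero the Nyquist mode for odd derivatives; it has no real-valued derivative.
        multiplier = multiplier.at[-1].set(0.0)
    return jnp.fft.irfft(multiplier * jnp.fft.rfft(u), n=n)


L, N = 2 * jnp.pi, 64
x = jnp.linspace(0.0, L, N, endpoint=False)
u = jnp.sin(3 * x)
du = spectral_derivative(u, L)            # should match 3 cos(3x)
d2u = spectral_derivative(u, L, order=2)  # should match -9 sin(3x)
```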
References
- Cox, S.M. and Matthews, P.C. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455. doi:10.1006/jcph.2002.6995
- Kassam, A.K. and Trefethen, L.N. "Fourth-order time-stepping for stiff PDEs." SIAM Journal on Scientific Computing 26.4 (2005): 1214-1233. doi:10.1137/S1064827502410633
- Montanelli, H. and Bootland, N. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327. doi:10.1016/j.matcom.2020.06.008
Related & Motivation
This package is greatly inspired by the
spinX module of the
ChebFun package in MATLAB. spinX served as a
reliable data generator for early works in physics-based deep learning, e.g.,
DeepHiddenPhysics
and Fourier Neural
Operators.
However, due to the two-language barrier, dynamically calling MATLAB solvers
from Python-based deep learning workflows is difficult, if not impossible. This also
excludes the option to differentiate through them — ruling out
differentiable-physics approaches like solver-in-the-loop correction or
diverted-chain training.
We view Exponax as a spiritual successor of spinX. JAX, as the
computational backend, elevates the power of this solver type with automatic
vectorization (jax.vmap), backend-agnostic execution (CPU/GPU/TPU), and tight
integration for deep learning via its versatile automatic differentiation
engine. With reproducible randomness in JAX, datasets can be re-created in
seconds — no need to ever write them to disk.
Beyond ChebFun, other popular pseudo-spectral implementations include Dedalus in the Python world and FourierFlows.jl in the Julia ecosystem (the latter was especially helpful for verifying our implementation of the contour integral method and dealiasing).
Citation
Exponax was developed as part of the
APEBench benchmark suite for
autoregressive neural emulators of PDEs. The accompanying paper was accepted at
NeurIPS 2024. If you find this package useful for your research, please
consider citing it:
@article{koehler2024apebench,
title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
journal={Advances in Neural Information Processing Systems (NeurIPS)},
volume={37},
year={2024}
}
If you enjoy the project, feel free to give it a star on GitHub!
Funding
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.
License
MIT License.
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler