Efficient differentiable PDE solvers in JAX.
Efficient Differentiable n-d PDE solvers built on top of JAX & Equinox.
Exponax is a suite for building Fourier spectral ETDRK (exponential time differencing Runge-Kutta) time-steppers for
semi-linear PDEs in 1d, 2d, and 3d. There are many pre-built dynamics and plenty
of helpful utilities. It is extremely efficient, differentiable (since it is
written entirely in JAX), and embeds seamlessly into deep learning pipelines.
Installation
pip install exponax
Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
Documentation
Documentation is available at fkoehler.site/exponax.
Quickstart
1d Kuramoto-Sivashinsky Equation.
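import jax
import exponax as ex
import matplotlib.pyplot as plt
# Kuramoto-Sivashinsky equation (conservative form) on a periodic 1d domain of
# length 100, discretized with 200 points and advanced with time step dt=0.1
ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
    num_spatial_dims=1, domain_extent=100.0,
    num_points=200, dt=0.1,
)
# Random initial condition: Fourier series truncated at cutoff=5 (low wavenumbers only)
u_0 = ex.ic.RandomTruncatedFourierSeries(
    num_spatial_dims=1, cutoff=5
)(num_points=200, key=jax.random.PRNGKey(0))
# 500 autoregressive steps; include_init=True prepends u_0, so the
# trajectory has shape (501, 1, 200) = (time, channel, space)
trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0)
# Space-time plot of the single channel
plt.imshow(
    trajectory[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-2, vmax=2, origin="lower"
)
plt.xlabel("Time")
plt.ylabel("Space")
plt.show()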
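For readers new to JAX, the rollout step above can be pictured as an autoregressive loop: apply the stepper, feed the result back in, and stack every state along a new leading time axis (which is why the plot indexes trajectory[:, 0, :]). The following is a minimal sketch of that idea with jax.lax.scan; rollout_sketch and the toy halve stepper are hypothetical and only mimic how ex.rollout(stepper, num_steps, include_init=True) is called, not how Exponax implements it.

```python
import jax
import jax.numpy as jnp


def rollout_sketch(stepper, num_steps, include_init=False):
    # Hypothetical re-implementation of an autoregressive rollout (not Exponax's code):
    # apply `stepper` num_steps times and collect every state via jax.lax.scan
    def scan_fn(u, _):
        u_next = stepper(u)
        return u_next, u_next  # (carry for the next step, output to stack)

    def run(u_0):
        _, trajectory = jax.lax.scan(scan_fn, u_0, None, length=num_steps)
        if include_init:
            # Prepend the initial state along the new leading time axis
            trajectory = jnp.concatenate([u_0[None], trajectory], axis=0)
        return trajectory

    return run


def halve(u):
    # Toy "stepper" for illustration: halves the state every step
    return 0.5 * u


traj = rollout_sketch(halve, 3, include_init=True)(jnp.ones((1, 4)))
# traj has shape (4, 1, 4); traj[:, 0, 0] is [1.0, 0.5, 0.25, 0.125]
```

Using scan rather than a Python loop keeps the whole rollout a single traced computation, which is what makes it cheap to jit and to differentiate through.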
As a next step, check out the tutorial on 1D Advection, which explains the basics of Exponax.
Features
- JAX as the computational backend:
  - Backend-agnostic code: run on CPU, GPU, or TPU, in both single and double precision.
  - Automatic differentiation over the timesteppers: compute gradients of solutions with respect to initial conditions, parameters, etc.
  - Tight integration with deep learning, since each timestepper is just an Equinox Module.
  - Automatic vectorization using jax.vmap (or equinox.filter_vmap), allowing you to advance multiple states in time or to instantiate multiple solvers that operate efficiently as a batch.
- Lightweight design without custom types. There is no grid or state object; everything is based on JAX arrays, and timesteppers are callable PyTrees.
- More than 46 pre-built dynamics across 1d, 2d, and 3d:
  - Linear PDEs (advection, diffusion, dispersion, etc.)
  - Nonlinear PDEs (Burgers, Kuramoto-Sivashinsky, Korteweg-de Vries, Navier-Stokes, etc.)
  - Reaction-diffusion (Gray-Scott, Swift-Hohenberg, etc.)
- A collection of initial condition distributions (truncated Fourier series, Gaussian random fields, etc.)
- Utilities for spectral derivatives, grid creation, autoregressive rollout, interpolation, etc.
- Easily extendable to new PDEs by subclassing from the BaseStepper module.
- An alternative, reduced interface that lets you define PDE dynamics using normalized or difficulty-based identifiers.
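The following minimal sketch illustrates the JAX-related features above without using Exponax itself. diffusion_step is a hypothetical stand-in for a timestepper: a plain function on a JAX array of shape (channel, space) that advances the 1d heat equation exactly in Fourier space. jax.vmap batches it over several states, and jax.grad differentiates through it with respect to a PDE parameter; the grid size, time step, and diffusivity are arbitrary illustrative values.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an Exponax stepper (NOT the library's API):
# exact Fourier-space step of the 1d heat equation u_t = nu * u_xx on a periodic domain.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 64
DT = 0.01


def diffusion_step(u, nu):
    # u has shape (1, NUM_POINTS), i.e. (channel, space)
    k = 2 * jnp.pi * jnp.fft.rfftfreq(NUM_POINTS, d=DOMAIN_EXTENT / NUM_POINTS)
    u_hat = jnp.fft.rfft(u, axis=-1)
    u_hat = u_hat * jnp.exp(-nu * k**2 * DT)  # exact decay of each Fourier mode
    return jnp.fft.irfft(u_hat, n=NUM_POINTS, axis=-1)


x = jnp.arange(NUM_POINTS) * DOMAIN_EXTENT / NUM_POINTS
# Batch of 8 states sin(2*pi*m*x), m = 1..8, with shape (8, 1, NUM_POINTS)
u_batch = jnp.stack([jnp.sin(2 * jnp.pi * m * x)[None, :] for m in range(1, 9)])

# Vectorization: advance all 8 states at once; nu (in_axes=None) is shared
u_next = jax.vmap(diffusion_step, in_axes=(0, None))(u_batch, 0.1)


# Differentiation: gradient of the energy after one step w.r.t. the diffusivity nu
def energy_after_step(nu):
    return jnp.sum(diffusion_step(u_batch[0], nu) ** 2)


dE_dnu = jax.grad(energy_after_step)(0.1)
```

Because the state is a plain array and the stepper a pure function, no special batching or adjoint code is needed. Exponax steppers are Equinox Modules rather than plain functions, so in practice equinox.filter_vmap may be the more convenient choice, as noted in the list above.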
Background
Exponax supports the efficient solution of (semi-linear) partial differential equations on periodic domains in arbitrary dimensions. Those are PDEs of the form
$$ \partial u/ \partial t = Lu + N(u), $$
where $L$ is a linear differential operator and $N$ is a nonlinear differential operator. The linear part can be solved exactly using a (matrix) exponential, and the nonlinear part is approximated using Runge-Kutta methods of various orders. These methods have long been known in various scientific disciplines and were unified for the first time by Cox & Matthews [1]. In particular, this package uses the complex contour integral method of Kassam & Trefethen [2] for numerical stability. The package is restricted to the original first-, second-, third-, and fourth-order methods. A recent study by Montanelli & Bootland [3] showed that the original ETDRK4 method is still one of the most efficient methods for these types of PDEs.
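To make the two ingredients above concrete, here is a sketch of the simplest member of the family, ETDRK1, under the assumption (ours, for illustration) that $L$ is diagonal in Fourier space with entries $\hat{L}_k$, as is the case for constant-coefficient operators on periodic domains:

$$ \hat{u}^{n+1}_k = e^{\hat{L}_k \Delta t} \hat{u}^n_k + \Delta t \, \varphi_1(\hat{L}_k \Delta t) \, \hat{N}_k(u^n), \qquad \varphi_1(z) = \frac{e^z - 1}{z}. $$

Evaluating $\varphi_1$ directly loses accuracy for small $|z|$ because of cancellation; the contour integral idea of Kassam & Trefethen [2] replaces it by an average of the same expression over points on a circle around $z$. The function names, the 32 quadrature points, and the unit radius below are illustrative choices, not necessarily what Exponax uses.

```python
import jax.numpy as jnp


def phi1_contour(z, num_points=32, radius=1.0):
    # phi_1(z) = (e^z - 1) / z evaluated as the mean of the same expression over
    # num_points points on a circle of the given radius centered at z
    # (contour-integral idea of Kassam & Trefethen). phi_1 is entire, so by the mean
    # value property this equals phi_1(z) up to a rapidly decaying quadrature error,
    # while the sample points stay away from z and avoid the cancellation near z = 0.
    z = jnp.asarray(z)
    angles = 2 * jnp.pi * (jnp.arange(num_points) + 0.5) / num_points
    s = z[..., None] + radius * jnp.exp(1j * angles)
    return jnp.mean((jnp.exp(s) - 1.0) / s, axis=-1)


def etdrk1_step(u_hat, lin, nonlin_hat, dt):
    # ETDRK1 update for a diagonal (Fourier-space) linear operator `lin`:
    # u_hat_new = exp(lin * dt) * u_hat + dt * phi_1(lin * dt) * nonlin_hat
    z = lin * dt
    return jnp.exp(z) * u_hat + dt * phi1_contour(z) * nonlin_hat
```

A quick sanity check is that ETDRK1 is exact for $u' = L u + N$ with constant $N$, since the exponential integrates the linear part exactly and the nonlinear part is held fixed over the step.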
We focus on periodic domains on scaled hypercubes with a uniform Cartesian discretization. This allows using the Fast Fourier Transform (FFT), resulting in very fast simulations. For example, a dataset of trajectories for the 2d Kuramoto-Sivashinsky equation with 50 initial conditions over 200 time steps with a 128x128 discretization is created in less than a second on a modern GPU.
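To see why a uniform periodic grid pays off: the FFT diagonalizes differentiation, so a derivative becomes a multiplication of each Fourier mode by $ik$. The sketch below shows this for a first derivative in 1d on the quickstart's domain length; spectral_derivative is a hypothetical helper written with jax.numpy.fft, not one of Exponax's own spectral-derivative utilities.

```python
import jax.numpy as jnp


def spectral_derivative(u, domain_extent):
    # First derivative of a real, periodic signal sampled on a uniform grid over
    # [0, domain_extent) (right endpoint excluded), along the last axis.
    n = u.shape[-1]
    k = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=domain_extent / n)  # angular wavenumbers
    du_hat = 1j * k * jnp.fft.rfft(u, axis=-1)  # d/dx becomes multiplication by i*k
    if n % 2 == 0:
        # The Nyquist mode's derivative is not a real-valued grid function; drop it
        du_hat = du_hat.at[..., -1].set(0.0)
    return jnp.fft.irfft(du_hat, n=n, axis=-1)


# Example on a domain of length 100 with 200 points: u = sin(2*pi*3*x/L)
L = 100.0
x = jnp.arange(200) * L / 200
u = jnp.sin(2 * jnp.pi * 3 * x / L)
du = spectral_derivative(u, L)  # approximates (2*pi*3/L) * cos(2*pi*3*x/L)
```

Each derivative costs one forward and one inverse FFT, i.e. O(N log N) work, and is exact (up to round-off) for any band-limited periodic signal the grid can represent.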
[1] Cox, Steven M., and Paul C. Matthews. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455.
[2] Kassam, Aly-Khan, and Lloyd N. Trefethen. "Fourth-order time-stepping for stiff PDEs." SIAM Journal on Scientific Computing 26.4 (2005): 1214-1233.
[3] Montanelli, Hadrien, and Niall Bootland. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327.
Related & Motivation
This package is greatly inspired by the chebfun library in MATLAB, in particular
the spinX (Stiff Pde INtegrator in X dimensions) module within it. These MATLAB
utilities have been used extensively as data generators in early works on
supervised physics-informed ML, e.g., DeepHiddenPhysics and Fourier Neural
Operators (the links show where in their public repos they use the spinX module).
The approach of pre-sampling the solvers, writing out the trajectories, and then
using them for supervised training worked for these problems, but it limits the
scope to purely supervised problems. Modern research ideas like correcting coarse
solvers (see, for instance, the Solver-in-the-Loop paper or the ML-accelerated CFD
paper) require the coarse solver to be differentiable. Some ideas of diverted
chain training also require the fine solver to be differentiable.
Even for applications that do not need differentiable solvers, legacy solvers
(like the MATLAB ones) pose an interface problem. We cannot easily query them
"on-the-fly" for tasks such as active learning, nor do they run efficiently on
hardware accelerators (GPUs, TPUs, etc.). Additionally, they were not designed
with batch execution (in the sense of vectorized application) in mind, which we
get more or less for free with jax.vmap. With the reproducible randomness of JAX,
we might never even have to write out a dataset, since it can be re-created in
seconds!
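A minimal sketch of the reproducibility point: with JAX's explicit, splittable PRNG keys, a dataset of random initial conditions is a deterministic function of a seed, so it can be regenerated instead of stored. random_truncated_fourier_ic and make_dataset are hypothetical helpers loosely mimicking a truncated-Fourier-series distribution; they are not Exponax's ex.ic API, and the grid size and cutoff are illustrative.

```python
import jax
import jax.numpy as jnp

NUM_POINTS = 64
CUTOFF = 5


def random_truncated_fourier_ic(key):
    # Hypothetical stand-in for an Exponax initial-condition generator (not its API):
    # random complex amplitudes for wavenumbers 1..CUTOFF, all other modes zero
    key_re, key_im = jax.random.split(key)
    amplitudes = jax.random.normal(key_re, (CUTOFF,)) + 1j * jax.random.normal(key_im, (CUTOFF,))
    coeffs = jnp.zeros(NUM_POINTS // 2 + 1, dtype=jnp.complex64).at[1 : CUTOFF + 1].set(amplitudes)
    u = jnp.fft.irfft(coeffs, n=NUM_POINTS)
    return u / jnp.max(jnp.abs(u))  # scale to a maximum absolute value of 1


def make_dataset(seed, num_samples):
    # One PRNG subkey per sample; jax.vmap draws all samples in one batched call
    keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)
    return jax.vmap(random_truncated_fourier_ic)(keys)


data_a = make_dataset(seed=0, num_samples=50)
data_b = make_dataset(seed=0, num_samples=50)  # same seed: identical data
data_c = make_dataset(seed=1, num_samples=50)  # different seed: different data
```

Storing the seed (and the generator's parameters) is then enough to reproduce the dataset, and the same pattern extends to rolling out each sample with a batched stepper.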
This package also took much inspiration from FourierFlows.jl in the Julia ecosystem, especially for checking the implementation of the contour integral method of [2] and for how to handle (de)aliasing.
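For context on (de)aliasing: a common remedy for quadratic nonlinearities such as the Burgers term is Orszag's 2/3 rule, which zeroes the upper third of the resolved wavenumbers so that products of the retained modes cannot alias back onto them. The sketch below applies it to $N(u) = -\frac{1}{2}\partial_x (u^2)$ in 1d; it illustrates the generic technique and is not a description of how Exponax or FourierFlows.jl implement dealiasing. The function name and the strict cutoff (mode index below num_points / 3) are our choices.

```python
import jax.numpy as jnp


def dealiased_burgers_nonlinearity_hat(u_hat, num_points, domain_extent):
    # Fourier coefficients of N(u) = -(1/2) d(u^2)/dx using the 2/3 rule:
    # only modes with index below num_points / 3 are kept before and after the product
    k = 2 * jnp.pi * jnp.fft.rfftfreq(num_points, d=domain_extent / num_points)
    mask = jnp.arange(num_points // 2 + 1) < num_points / 3
    u_filtered = jnp.fft.irfft(u_hat * mask, n=num_points)
    u_sq_hat = jnp.fft.rfft(u_filtered**2)  # the product is formed in physical space
    return -0.5j * k * u_sq_hat * mask


n, L = 48, 2 * jnp.pi
x = jnp.arange(n) * L / n
# Low modes interact correctly: u = cos(x) gives -(1/2)(u^2)_x = (1/2) sin(2x)
n_low = jnp.fft.irfft(dealiased_burgers_nonlinearity_hat(jnp.fft.rfft(jnp.cos(x)), n, L), n=n)
# u = cos(15x) is kept (15 < 16), but u^2 contains mode 30 > 24 (Nyquist), which the
# grid aliases onto mode 18; the mask removes it instead of letting it pollute the result
u_hat_high = jnp.zeros(n // 2 + 1, dtype=jnp.complex64).at[15].set(n / 2)
n_high = dealiased_burgers_nonlinearity_hat(u_hat_high, n, L)
```

The cutoff follows from a short argument: if both factors contain only modes $|k| \le K$, their product reaches $2K$, and a mode above the Nyquist limit reappears at $|k - N| \ge N - 2K$, which stays outside the retained band exactly when $K < N/3$.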
Citation
This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180), accepted at NeurIPS 2024. If you find it useful for your research, please consider citing it:
@article{koehler2024apebench,
title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
journal={Advances in Neural Information Processing Systems (NeurIPS)},
volume={37},
year={2024}
}
(Feel free to also give the project a star on GitHub if you like it.)
Here you can find the APEBench benchmark suite.
Funding
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.
License
MIT, see here
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler
Download files
File details
Details for the file exponax-0.1.1.tar.gz.
File metadata
- Download URL: exponax-0.1.1.tar.gz
- Upload date:
- Size: 102.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e00ca3e9ffe431999e6813e7441d3129b46fb8521f5c984b0686a0bf882a629e |
| MD5 | fbf7411b53229a902136f53b0debec76 |
| BLAKE2b-256 | 6a488ea522d53372b72f197aceaaf3b05feefc3586d562e37e6b9181b10c9230 |
Provenance
The following attestation bundles were made for exponax-0.1.1.tar.gz:
Publisher: publish.yml on Ceyron/exponax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: exponax-0.1.1.tar.gz
- Subject digest: e00ca3e9ffe431999e6813e7441d3129b46fb8521f5c984b0686a0bf882a629e
- Sigstore transparency entry: 939856597
- Sigstore integration time:
- Permalink: Ceyron/exponax@b9397f2575e67520590b70d6f3e538cd8599e648
- Branch / Tag: refs/tags/v0.1.1
- Owner: https://github.com/Ceyron
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@b9397f2575e67520590b70d6f3e538cd8599e648
- Trigger Event: release
File details
Details for the file exponax-0.1.1-py3-none-any.whl.
File metadata
- Download URL: exponax-0.1.1-py3-none-any.whl
- Upload date:
- Size: 147.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac37649fd7b041168ee0817a80bf36bed77fffed98f0fa7ca764d6c060646f5a |
| MD5 | b1a4f89c28472682769f6f0f144aa8fb |
| BLAKE2b-256 | a0a4c5e760e5af6cee9d1718c048b29bce64609f80bdf054b27a5604b023cb93 |
Provenance
The following attestation bundles were made for exponax-0.1.1-py3-none-any.whl:
Publisher: publish.yml on Ceyron/exponax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: exponax-0.1.1-py3-none-any.whl
- Subject digest: ac37649fd7b041168ee0817a80bf36bed77fffed98f0fa7ca764d6c060646f5a
- Sigstore transparency entry: 939856611
- Sigstore integration time:
- Permalink: Ceyron/exponax@b9397f2575e67520590b70d6f3e538cd8599e648
- Branch / Tag: refs/tags/v0.1.1
- Owner: https://github.com/Ceyron
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@b9397f2575e67520590b70d6f3e538cd8599e648
- Trigger Event: release