
Neural PDE Emulator Architectures in JAX built on top of Equinox.


PDE Emulator Architectures for Equinox.



A collection of neural architectures for emulating Partial Differential Equations (PDEs) in JAX, agnostic to the spatial dimension (1D, 2D, 3D) and the boundary conditions (Dirichlet, Neumann, periodic). This package is built on top of Equinox.

Installation

pip install pdequinox

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.

Documentation

The documentation is available at fkoehler.site/pdequinox.

Quickstart

Train a UNet to become an emulator for the 1D Poisson equation.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # `pip install optax`
import pdequinox as pdeqx
from tqdm import tqdm  # `pip install tqdm`

# Sample pairs of forcing terms and solutions of the 1D Poisson equation
# with Dirichlet boundary conditions
force_fields, displacement_fields = pdeqx.sample_data.poisson_1d_dirichlet(
    key=jax.random.PRNGKey(0)
)

# Use the first 800 samples for training and the remaining ones for testing
force_fields_train = force_fields[:800]
force_fields_test = force_fields[800:]
displacement_fields_train = displacement_fields[:800]
displacement_fields_test = displacement_fields[800:]

# UNet with 1 spatial dimension, 1 input channel, and 1 output channel
unet = pdeqx.arch.ClassicUNet(1, 1, 1, key=jax.random.PRNGKey(1))

def loss_fn(model, x, y):
    # The model acts on a single sample; vmap maps it over the batch axis
    y_pred = jax.vmap(model)(x)
    return jnp.mean((y_pred - y) ** 2)

opt = optax.adam(3e-4)
# Only the array leaves of the model (weights and biases) are optimized
opt_state = opt.init(eqx.filter(unet, eqx.is_array))

@eqx.filter_jit
def update_fn(model, state, x, y):
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, new_state = opt.update(grad, state, eqx.filter(model, eqx.is_array))
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss

loss_history = []
shuffle_key = jax.random.PRNGKey(151)
for epoch in tqdm(range(100)):
    # A fresh subkey each epoch reshuffles the minibatches
    shuffle_key, subkey = jax.random.split(shuffle_key)

    for batch in pdeqx.dataloader(
        (force_fields_train, displacement_fields_train),
        batch_size=32,
        key=subkey
    ):
        unet, opt_state, loss = update_fn(
            unet,
            opt_state,
            *batch,
        )
        loss_history.append(loss)

Background

Neural emulators are networks trained to efficiently predict physical phenomena, often ones governed by PDEs. These range from simple cases such as the linear advection equation to more complicated ones such as the Navier-Stokes equations. If we work on uniform Cartesian grids* (which this package assumes), we can borrow plenty of architectures from image-to-image tasks in computer vision (e.g., segmentation). This includes:

Most of these architectures resemble classical numerical methods, or at least share similarities with them. For example, ConvNets (or convolutions in general) are related to finite differences, U-Nets resemble multigrid methods, and Fourier Neural Operators are related to spectral methods. The difference is that the emulators' free parameters are found by (data-driven) numerical optimization rather than by symbolic manipulation of the differential equations.

(*) This means that space is discretized on a uniform pixel or voxel grid. Hence, the domain can only be a scaled unit cube $\Omega = (0, L)^D$.
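To make the analogy with classical methods concrete, here is a minimal sketch in plain JAX (not pdequinox code). A fixed convolution kernel [1, -2, 1] / dx² reproduces the classical three-point finite-difference stencil for u'', and multiplying Fourier coefficients by -k² gives the spectral second derivative on a periodic domain. The helper names are hypothetical; a ConvNet or Fourier Neural Operator would learn the kernel entries or the mode weights from data instead of fixing them.

```python
import jax
import jax.numpy as jnp

def second_derivative_fd(u, dx):
    # Classical 3-point central finite difference for u'' at interior points
    return (u[2:] - 2.0 * u[1:-1] + u[:-2]) / dx**2

def second_derivative_conv(u, dx):
    # The same stencil expressed as a convolution with a fixed kernel.
    # A ConvNet would instead learn these kernel entries.
    kernel = jnp.array([1.0, -2.0, 1.0]) / dx**2
    # "valid" keeps only outputs where the kernel fully overlaps u (interior points)
    return jnp.convolve(u, kernel, mode="valid", precision=jax.lax.Precision.HIGHEST)

def second_derivative_spectral(u, length):
    # Spectral method on a periodic domain: multiply Fourier modes by (i k)^2 = -k^2.
    # An FNO instead multiplies (a truncated set of) modes by learned weights.
    n = u.shape[-1]
    k = 2.0 * jnp.pi * jnp.fft.rfftfreq(n, d=length / n)
    u_hat = jnp.fft.rfft(u)
    return jnp.fft.irfft(-(k**2) * u_hat, n=n)

# Finite differences vs. convolution on a non-periodic grid
x = jnp.linspace(0.0, jnp.pi, 101)
dx = x[1] - x[0]
u = jnp.sin(x)
d2_fd = second_derivative_fd(u, dx)
d2_conv = second_derivative_conv(u, dx)

# Spectral derivative on a periodic grid (left end included, right end excluded)
length = 1.0
xp = jnp.linspace(0.0, length, 64, endpoint=False)
up = jnp.sin(2.0 * jnp.pi * xp / length)
d2_spec = second_derivative_spectral(up, length)
```

Both the finite-difference and convolution versions approximate -sin(x) at the interior points, while the spectral version is exact (up to round-off) for this band-limited periodic signal.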

Features

  • Based on JAX:
    • One of the best automatic differentiation engines (forward and reverse mode)
    • Automatic vectorization
    • Backend-agnostic code (run on CPU, GPU, and TPU)
  • Based on Equinox:
    • Single-batch by design (models act on one sample; batch with jax.vmap)
    • Integration into the Equinox SciML ecosystem
  • Agnostic to the spatial dimension (works for 1D, 2D, and 3D)
  • Agnostic to the boundary condition (works for Dirichlet, Neumann, and periodic BCs)
  • Composability
  • Tools to count parameters and assess receptive fields

Boundary Conditions

This package assumes that the boundary condition is baked into the neural emulator. Hence, most components accept a boundary_mode argument, which can be "dirichlet", "neumann", or "periodic". This setting determines which grid points count as degrees of freedom.

[Figure: the three boundary conditions]

Dirichlet boundaries fully eliminate the degrees of freedom on the boundary. Periodic boundaries keep only one end of the domain as a degree of freedom (by convention, this package keeps the left end). Neumann boundaries keep both ends as degrees of freedom.
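A minimal sketch of what these conventions could mean for a 1D grid with N degrees of freedom on (0, L), assuming homogeneous (zero-value) Dirichlet and zero-gradient Neumann conditions. grid_1d and pad_for_boundary are hypothetical helpers, not pdequinox functions, and realizing the boundary mode through the padding of a convolution's ghost cells is an assumption about one common implementation strategy.

```python
import jax.numpy as jnp

def grid_1d(num_points, length, boundary_mode):
    # Coordinates of the degrees of freedom on (0, length) for each boundary mode
    if boundary_mode == "dirichlet":
        # Both boundary nodes are eliminated: keep only interior points
        return jnp.linspace(0.0, length, num_points + 2)[1:-1]
    if boundary_mode == "periodic":
        # Keep the left end and drop the right end (it duplicates the left one)
        return jnp.linspace(0.0, length, num_points, endpoint=False)
    if boundary_mode == "neumann":
        # Keep both ends
        return jnp.linspace(0.0, length, num_points)
    raise ValueError(f"unknown boundary_mode: {boundary_mode}")

def pad_for_boundary(u, boundary_mode, width=1):
    # Fill `width` ghost cells on each side of a 1D field u (assumed mapping):
    #   dirichlet: eliminated boundary nodes hold zero       -> zero padding
    #   periodic:  ghost cells wrap around the domain        -> wrap padding
    #   neumann:   zero slope mirrors about the stored ends  -> reflect padding
    modes = {"dirichlet": "constant", "periodic": "wrap", "neumann": "reflect"}
    return jnp.pad(u, width, mode=modes[boundary_mode])

x_dirichlet = grid_1d(3, 1.0, "dirichlet")  # [0.25, 0.5, 0.75]
x_periodic = grid_1d(4, 1.0, "periodic")    # [0.0, 0.25, 0.5, 0.75]
x_neumann = grid_1d(5, 1.0, "neumann")      # [0.0, 0.25, 0.5, 0.75, 1.0]

u = jnp.array([1.0, 2.0, 3.0])
padded = {m: pad_for_boundary(u, m) for m in ("dirichlet", "periodic", "neumann")}
# dirichlet -> [0, 1, 2, 3, 0], periodic -> [3, 1, 2, 3, 1], neumann -> [2, 1, 2, 3, 2]
```

Under these assumptions the grid spacing is L/(N+1) for Dirichlet, L/N for periodic, and L/(N-1) for Neumann boundaries.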

Acknowledgements

Related Work

Similar packages that provide a collection of emulator architectures are PDEBench and PDEArena. With a focus on physics-informed neural networks and neural operators, there are also DeepXDE and NVIDIA Modulus.

Citation

This package was developed as part of the APEBench paper (accepted at NeurIPS 2024); the citation will be added here soon.

Funding

The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.

License

MIT


fkoehler.site  ·  GitHub @ceyron  ·  X @felix_m_koehler  ·  LinkedIn Felix Köhler

Download files


Source Distribution

pdequinox-0.1.2.tar.gz (37.8 kB)

Built Distribution


pdequinox-0.1.2-py3-none-any.whl (49.6 kB)

File details

Details for the file pdequinox-0.1.2.tar.gz.

File metadata

  • Download URL: pdequinox-0.1.2.tar.gz
  • Upload date:
  • Size: 37.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for pdequinox-0.1.2.tar.gz
Algorithm Hash digest
SHA256 7ee9dcbf277cbb94cda508034c0955600a03bc4c664bede5eb61b4a4b99b54c5
MD5 09c0aaee9b4c5834414a7c55b4c378c6
BLAKE2b-256 249cff9718b174e1b98d113d6d62fceeeb15fc01b6cdbace1c9939b9f2bc4464


File details

Details for the file pdequinox-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: pdequinox-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 49.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for pdequinox-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 b022b38eb03fa7ce4d20622cb9dd2b36f1b80730f082cd2b757279bb7d1111fa
MD5 6f9e2636b8c8a1fe0f409f44c05525fc
BLAKE2b-256 497ee0d36d65a5495eb71e1a3ce2e5f2027f9b9c86d0ab2f41a43c2735f83ebb

