
Neural PDE Emulator Architectures in JAX built on top of Equinox.


⚠️ ⚠️ ⚠️ I am currently setting up this public repository. Please be patient. ⚠️ ⚠️ ⚠️


PDEquinox

PDE Emulator Architectures in Equinox.

Installation · Quickstart · Background · Features · Boundary Conditions · Constructors · Related · License

Installation

Clone the repository, navigate to the folder and install the package with pip:

```bash
pip install .
```

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.

Quickstart

Train a UNet to become an emulator for the 1D Poisson equation.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # `pip install optax`
import pdequinox as pdeqx
from tqdm import tqdm  # `pip install tqdm`

# Sample pairs of force fields and displacement fields for the 1D Poisson equation
force_fields, displacement_fields = pdeqx.sample_data.poisson_1d_dirichlet(
    key=jax.random.PRNGKey(0)
)

# First 800 samples for training, the rest for testing
force_fields_train = force_fields[:800]
force_fields_test = force_fields[800:]
displacement_fields_train = displacement_fields[:800]
displacement_fields_test = displacement_fields[800:]

# One spatial dimension, one input channel, one output channel
unet = pdeqx.arch.ClassicUNet(1, 1, 1, key=jax.random.PRNGKey(1))

def loss_fn(model, x, y):
    # The model acts on a single sample; vmap maps it over the batch axis
    y_pred = jax.vmap(model)(x)
    return jnp.mean((y_pred - y) ** 2)

opt = optax.adam(3e-4)
# Only the array leaves of the model are trainable parameters
opt_state = opt.init(eqx.filter(unet, eqx.is_array))

@eqx.filter_jit
def update_fn(model, state, x, y):
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, new_state = opt.update(grad, state, eqx.filter(model, eqx.is_array))
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss

loss_history = []
shuffle_key = jax.random.PRNGKey(151)
for epoch in tqdm(range(100)):
    # Fresh shuffling key for every epoch
    shuffle_key, subkey = jax.random.split(shuffle_key)

    for batch in pdeqx.dataloader(
        (force_fields_train, displacement_fields_train),
        batch_size=32,
        key=subkey
    ):
        unet, opt_state, loss = update_fn(
            unet,
            opt_state,
            *batch,
        )
        loss_history.append(loss)
```
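The training loop relies on `pdeqx.dataloader` to shuffle the training pairs and cut them into minibatches. The following is a minimal, hypothetical stand-in written in plain JAX, assuming the loader shuffles along the leading (sample) axis once per call and yields tuples of arrays; whether the real function keeps or drops the final partial batch is an assumption here (this sketch keeps it).

```python
import jax
import jax.numpy as jnp

def simple_dataloader(data, *, batch_size, key):
    """Hypothetical stand-in for pdeqx.dataloader (not its actual implementation).

    `data` is a tuple of arrays that share the same leading (sample) axis.
    The samples are shuffled once per call, and the last, possibly smaller,
    batch is kept.
    """
    n_samples = data[0].shape[0]
    perm = jax.random.permutation(key, n_samples)  # one shuffle per epoch
    for start in range(0, n_samples, batch_size):
        idx = perm[start:start + batch_size]
        # index every array with the same permutation so input/target pairs stay aligned
        yield tuple(arr[idx] for arr in data)

# Toy data: 10 samples with one channel and one grid point each
x = jnp.arange(10.0).reshape(10, 1, 1)
y = 2.0 * x
batches = list(simple_dataloader((x, y), batch_size=4, key=jax.random.PRNGKey(0)))
```

With 10 samples and `batch_size=4`, this yields batches of 4, 4, and 2 samples. A smaller final batch makes the `eqx.filter_jit`-compiled `update_fn` trace once more for the new shape.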

Background

Neural emulators are networks trained to efficiently predict physical phenomena, often those described by PDEs. These range from simple cases like the linear advection equation to more complicated ones like the Navier-Stokes equations. If we work on uniform Cartesian grids* (which this package assumes), we can borrow plenty of architectures from image-to-image tasks in computer vision (e.g., segmentation). Suitable architectures include, among others, ConvNets, U-Nets, and Fourier Neural Operators.

It is interesting to note that most of these architectures resemble classical numerical methods or at least share similarities with them. For example, ConvNets (or convolutions in general) are related to finite differences, U-Nets resemble multigrid methods, and Fourier Neural Operators are related to spectral methods. The difference is that the emulators' free parameters are found by (data-driven) numerical optimization rather than by symbolic manipulation of the differential equations.
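To make the analogy concrete, here is a small sketch in plain JAX (it does not use PDEquinox): a fixed 3-point convolution kernel reproduces the central finite-difference approximation of $u''$, and multiplying by $ik$ in Fourier space yields a spectral derivative. A learned emulator would instead fit such kernels or Fourier-space weights from data. The grids and test functions are illustrative choices.

```python
import jax.numpy as jnp

# --- Convolution as finite differences ---
n = 21
x = jnp.linspace(0.0, 1.0, n)
dx = x[1] - x[0]
u = x ** 2  # exact second derivative is 2 everywhere

# The central-difference stencil for u'' written as a convolution kernel.
# The kernel is symmetric, so convolution and cross-correlation coincide.
laplace_kernel = jnp.array([1.0, -2.0, 1.0]) / dx ** 2
u_xx_fd = jnp.convolve(u, laplace_kernel, mode="valid")  # interior points only, length n - 2

# --- Fourier multiplication as a spectral method ---
L = 2 * jnp.pi
m = 64
xp = jnp.arange(m) * L / m  # periodic grid, left boundary included
v = jnp.sin(xp)
k = jnp.fft.rfftfreq(m, d=L / m) * 2 * jnp.pi  # angular wavenumbers 0, 1, ..., m/2
v_x_spectral = jnp.fft.irfft(1j * k * jnp.fft.rfft(v), n=m)  # approximates cos(xp)
```

The stencil is exact for quadratics (up to round-off), and the spectral derivative of $\sin$ matches $\cos$ to round-off, since $\sin$ is fully resolved on this grid.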

(*) This means that space is discretized on a pixel or voxel grid. Hence, the domain can only be the scaled unit cube $\Omega = (0, L)^D$.

Features

  • Based on JAX:
    • One of the best Automatic Differentiation engines (forward & reverse)
    • Automatic vectorization
    • Backend-agnostic code (run on CPU, GPU, and TPU)
  • Based on Equinox:
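    • Single-batch by design: each model acts on a single sample, and batches are processed with jax.vmap (as in the Quickstart)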
    • Integration into the Equinox SciML ecosystem
  • Agnostic to the spatial dimension (works for 1D, 2D, and 3D)
  • Agnostic to the boundary condition (works for Dirichlet, Neumann, and periodic BCs)
  • Composability
  • Tools to count parameters and assess receptive fields
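The single-sample design, parameter counting, and receptive fields can be illustrated with a tiny convolutional model in plain JAX. This is a sketch under assumptions: `conv_model`, `count_parameters`, and `receptive_field` are hypothetical helpers, not PDEquinox's API, and the receptive field is measured empirically by pushing a unit impulse through the (linearized) stack of convolutions.

```python
import jax
import jax.numpy as jnp

def conv_model(params, u):
    """Single-sample model: u has shape (1, N), i.e. one channel and no batch axis."""
    h = u[0]
    for kernel in params["kernels"]:
        h = jnp.tanh(jnp.convolve(h, kernel, mode="same"))
    return h[None, :]

def count_parameters(params):
    # Hypothetical helper: total number of scalars over all array leaves of a pytree
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

def receptive_field(kernels, n=64):
    # Empirical receptive field of the linear part of the stack: propagate a unit
    # impulse and count the touched grid points (absolute values avoid cancellation)
    h = jnp.zeros(n).at[n // 2].set(1.0)
    for kernel in kernels:
        h = jnp.convolve(h, jnp.abs(kernel), mode="same")
    return int(jnp.sum(h > 0))

key = jax.random.PRNGKey(0)
kernels = [jax.random.normal(k, (3,)) for k in jax.random.split(key, 3)]
params = {"kernels": kernels}

batch = jnp.ones((32, 1, 64))  # 32 samples, 1 channel, 64 grid points
# vmap adds the batch axis; the parameters are shared across samples
out = jax.vmap(conv_model, in_axes=(None, 0))(params, batch)
```

Three stacked convolutions with kernel size 3 have 9 weights and a receptive field of $1 + 3 \cdot (3 - 1) = 7$ grid points.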

Boundary Conditions

This package assumes that the boundary condition is baked into the neural emulator. Hence, most components accept a boundary_mode argument, which can be "dirichlet", "neumann", or "periodic". This affects what is considered a degree of freedom in the grid.
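One common way to bake a boundary condition into a convolution is through padding, so that a "valid" convolution returns a field of the original size. The sketch below assumes the mapping dirichlet → zero padding (homogeneous Dirichlet), periodic → wrap-around, and neumann → reflection about the boundary point (zero gradient). It illustrates the idea and is not necessarily how PDEquinox implements `boundary_mode` internally.

```python
import jax.numpy as jnp

# Assumed mapping from boundary mode to jnp.pad mode (illustrative sketch)
PAD_MODE = {
    "dirichlet": "constant",  # homogeneous Dirichlet: ghost cells are zero
    "periodic": "wrap",       # ghost cells copy values from the opposite end
    "neumann": "reflect",     # mirror about the boundary point: u[-1] = u[1]
}

def pad_for_boundary(u, boundary_mode, width=1):
    # Pad only the last (spatial) axis; leading axes (e.g. channels) are untouched
    pad_width = [(0, 0)] * (u.ndim - 1) + [(width, width)]
    return jnp.pad(u, pad_width, mode=PAD_MODE[boundary_mode])

u = jnp.array([1.0, 2.0, 3.0, 4.0])
stencil = jnp.array([1.0, -2.0, 1.0])  # unscaled 3-point Laplacian
# With one ghost cell on each side, a "valid" convolution keeps the length N = 4
laplacians = {
    mode: jnp.convolve(pad_for_boundary(u, mode), stencil, mode="valid")
    for mode in PAD_MODE
}
```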

Dirichlet boundaries fully eliminate degrees of freedom on the boundary. Periodic boundaries only keep one end of the domain as a degree of freedom (this package follows the convention that the left boundary is the degree of freedom). Neumann boundaries keep both ends as degrees of freedom.
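These conventions determine where the $N$ degrees of freedom sit on $(0, L)$ and thus the grid spacing. A sketch, assuming uniform spacing and the conventions just described (`grid_points` is a hypothetical helper):

```python
import jax.numpy as jnp

def grid_points(num_dof, boundary_mode, L=1.0):
    """Coordinates of the degrees of freedom on (0, L) for a boundary mode (sketch)."""
    if boundary_mode == "dirichlet":
        # both boundaries excluded: num_dof interior points, num_dof + 1 intervals
        dx = L / (num_dof + 1)
        return dx * jnp.arange(1, num_dof + 1)
    if boundary_mode == "periodic":
        # left boundary kept, right boundary is its periodic image: num_dof intervals
        dx = L / num_dof
        return dx * jnp.arange(num_dof)
    if boundary_mode == "neumann":
        # both boundaries kept: num_dof - 1 intervals
        dx = L / (num_dof - 1)
        return dx * jnp.arange(num_dof)
    raise ValueError(f"unknown boundary_mode: {boundary_mode}")
```

For example, with four degrees of freedom on $(0, 1)$, Dirichlet gives the points 0.2, 0.4, 0.6, 0.8; periodic gives 0, 0.25, 0.5, 0.75; and Neumann gives 0, 1/3, 2/3, 1.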

Constructors

There are two primary architectural constructors for Sequential and Hierarchical Networks that allow for composability with the PDEquinox blocks.

Sequential Constructor

The sequential network constructor is defined by:

  • a lifting block $\mathcal{L}$
  • $N$ blocks $\left\{ \mathcal{B}_i \right\}_{i=1}^N$
  • a projection block $\mathcal{P}$
  • the hidden channels within the sequential processing
  • the number of blocks $N$ (one can also supply a list of hidden channels if they should differ between blocks)
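The composition $\mathcal{P} \circ \mathcal{B}_N \circ \dots \circ \mathcal{B}_1 \circ \mathcal{L}$ can be sketched in plain JAX as follows. All names are hypothetical and this is not PDEquinox's constructor; for brevity, lifting and projection are pointwise (1x1) linear maps over the channel axis, and each block is a residual pointwise layer, whereas real blocks would typically also mix information spatially (e.g., via convolutions).

```python
import jax
import jax.numpy as jnp

def pointwise(w, b, u):
    # 1x1 "convolution": linear map over the channel axis at every grid point
    # w: (c_out, c_in), b: (c_out,), u: (c_in, N) -> (c_out, N)
    return w @ u + b[:, None]

def init_sequential(key, in_channels, out_channels, hidden_channels, num_blocks):
    keys = jax.random.split(key, num_blocks + 2)
    def layer(k, c_out, c_in):
        return (jax.random.normal(k, (c_out, c_in)) / jnp.sqrt(c_in), jnp.zeros(c_out))
    return {
        "lift": layer(keys[0], hidden_channels, in_channels),
        "blocks": [layer(k, hidden_channels, hidden_channels) for k in keys[1:-1]],
        "project": layer(keys[-1], out_channels, hidden_channels),
    }

def sequential_forward(params, u):
    h = pointwise(*params["lift"], u)  # lifting block L: in_channels -> hidden
    for w, b in params["blocks"]:      # blocks B_1, ..., B_N at constant width
        h = h + jax.nn.relu(pointwise(w, b, h))
    return pointwise(*params["project"], h)  # projection block P: hidden -> out_channels

params = init_sequential(jax.random.PRNGKey(0), in_channels=1, out_channels=2,
                         hidden_channels=8, num_blocks=4)
out = sequential_forward(params, jnp.ones((1, 64)))  # single sample: (channels, N)
batched = jax.vmap(sequential_forward, in_axes=(None, 0))(params, jnp.ones((5, 1, 64)))
```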

Hierarchical Constructor

The hierarchical network constructor is defined by:

  • a lifting block $\mathcal{L}$
  • the number of levels $D$ (i.e., the number of additional hierarchies). Setting $D = 0$ recovers the sequential processing.
  • a list of $D$ blocks $\left\{ \mathcal{D}_i \right\}_{i=1}^D$ for downsampling, i.e., mapping downwards to the lower hierarchy (often, these halve the spatial axes while keeping the number of channels)
  • a list of $D$ blocks $\left\{ \mathcal{B}_i^l \right\}_{i=1}^D$ for processing in the left arc (often, these change the number of channels, e.g., double it, such that the combination of downsampling and left processing halves the spatial resolution and doubles the feature count)
  • a list of $D$ blocks $\left\{ \mathcal{U}_i \right\}_{i=1}^D$ for upsampling, i.e., mapping upwards to the higher hierarchy (often, these double the spatial resolution while halving the feature count so that a skip connection can be concatenated)
  • a list of $D$ blocks $\left\{ \mathcal{B}_i^r \right\}_{i=1}^D$ for processing in the right arc (often, these change the number of channels, e.g., halve it, such that the combination of upsampling and right processing doubles the spatial resolution and halves the feature count)
  • a projection block $\mathcal{P}$
  • the hidden channels within the hierarchical processing (if just an integer is provided, it is taken as the number of hidden channels in the highest hierarchy)
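A minimal sketch of this recursion in plain JAX, assuming average pooling for $\mathcal{D}_i$, nearest-neighbor repetition plus a channel-halving pointwise map for $\mathcal{U}_i$, and pointwise (1x1) channel mixing with tanh for $\mathcal{B}_i^l$ and $\mathcal{B}_i^r$. All names are hypothetical and this is not PDEquinox's implementation; real blocks would typically use spatial convolutions. The spatial resolution must be divisible by $2^D$.

```python
import jax
import jax.numpy as jnp

def downsample(u):
    # D_i: halve the spatial axis by averaging neighboring pairs; channels unchanged
    c, n = u.shape
    return u.reshape(c, n // 2, 2).mean(axis=-1)

def upsample(u):
    # spatial part of U_i: double the resolution by nearest-neighbor repetition
    return jnp.repeat(u, 2, axis=-1)

def mix(w, u):
    # pointwise channel mixing plus nonlinearity: (c_out, c_in) @ (c_in, N)
    return jnp.tanh(w @ u)

def init_hierarchical(key, in_channels, out_channels, hidden_channels, num_levels):
    def weight(k, c_out, c_in):
        return jax.random.normal(k, (c_out, c_in)) / jnp.sqrt(c_in)
    keys = iter(jax.random.split(key, 3 * num_levels + 2))
    params = {"lift": weight(next(keys), hidden_channels, in_channels),
              "left": [], "up": [], "right": []}
    for level in range(num_levels):
        c = hidden_channels * 2 ** level  # channels entering this level
        params["left"].append(weight(next(keys), 2 * c, c))   # B_i^l: double channels
        params["up"].append(weight(next(keys), c, 2 * c))     # U_i: halve channels
        params["right"].append(weight(next(keys), c, 2 * c))  # B_i^r: (skip + up) -> c
    params["project"] = weight(next(keys), out_channels, hidden_channels)
    return params

def _level(params, h, level):
    if level == len(params["left"]):
        return h  # lowest hierarchy reached
    skip = h
    h = mix(params["left"][level], downsample(h))  # D_i, then B_i^l
    h = _level(params, h, level + 1)               # recurse into the lower hierarchies
    h = mix(params["up"][level], upsample(h))      # U_i
    h = jnp.concatenate([skip, h], axis=0)         # skip connection
    return mix(params["right"][level], h)          # B_i^r

def hierarchical_forward(params, u):
    h = params["lift"] @ u          # lifting block L
    h = _level(params, h, 0)
    return params["project"] @ h    # projection block P

params = init_hierarchical(jax.random.PRNGKey(0), 1, 1, hidden_channels=4, num_levels=2)
out = hierarchical_forward(params, jnp.ones((1, 64)))
```

Here the hidden channels grow as 4 → 8 → 16 while the resolution shrinks as 64 → 32 → 16, and the right arc reverses both.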

Beyond Architectural Constructors

For completeness, pdequinox.arch also provides a ConvNet, which is a simple feed-forward convolutional network, and an MLP, which is a dense network that additionally requires the number of resolution points to be fixed in advance.
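Why a dense network needs a fixed resolution while a convolutional one does not can be seen from the weight shapes. A small sketch in plain JAX with hypothetical single-layer examples (not PDEquinox's ConvNet or MLP):

```python
import jax
import jax.numpy as jnp

key_conv, key_dense = jax.random.split(jax.random.PRNGKey(0))
kernel = jax.random.normal(key_conv, (3,))        # 3 weights, independent of N
dense_w = jax.random.normal(key_dense, (32, 32))  # N * N weights, tied to N = 32

def conv_layer(u):
    # the same 3-point kernel slides over a field of any length N
    return jnp.convolve(u, kernel, mode="same")

def dense_layer(u):
    # couples every grid point to every other one, so u must have exactly 32 entries
    return dense_w @ u

for n in (32, 64):
    print(n, conv_layer(jnp.ones(n)).shape)  # works for both resolutions

print(dense_layer(jnp.ones(32)).shape)  # fine at the resolution it was built for
# dense_layer(jnp.ones(64)) fails with a shape mismatch
```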

Related

Similar packages that provide a collection of emulator architectures are PDEBench and PDEArena. With a focus on Physics-informed Neural Networks and Neural Operators, there are also DeepXDE and NVIDIA Modulus.

License

MIT, see here


fkoehler.site  ·  GitHub @ceyron  ·  X @felix_m_koehler



Download files


Source Distribution

pdequinox-0.1.0.tar.gz (36.8 kB)

Uploaded Source

Built Distribution


pdequinox-0.1.0-py3-none-any.whl (49.0 kB)

Uploaded Python 3

File details

Details for the file pdequinox-0.1.0.tar.gz.

File metadata

  • Download URL: pdequinox-0.1.0.tar.gz
  • Upload date:
  • Size: 36.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for pdequinox-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 07f7516fe26823e6c3b71f1ed5a170e97cc34ff1d1349435d4b7469adc540d3a |
| MD5 | 449ca8ce4724b3349fd9ec3895847549 |
| BLAKE2b-256 | 8977ffd566a0c78c6d1dfac2b84ef07cd71d090ab2fb9b7a50af4c4db0e2bdd8 |


File details

Details for the file pdequinox-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: pdequinox-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 49.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for pdequinox-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | be618d9629751ea274eb7f3ba533de8efa7c3435e2159c8bb29dea3d67811104 |
| MD5 | 38fb9d198ef16f3b51f380690040c8fc |
| BLAKE2b-256 | 34fa34d61900dd7bfc5555b71f9386e41884b0eae6b060fe4896c13ff4f9f88f |

