
Training methodologies for Autoregressive Neural Emulators in JAX built on top of Equinox & Optax.



Trainax

Learning Methodologies for Autoregressive Neural Emulators.

Installation · Quickstart · Background · Features · Taxonomy · License

Installation

Clone the repository, navigate to the folder and install the package with pip:

pip install .

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.

Quickstart

Train a bias-free linear convolution with kernel size 2 to become an emulator for the 1D advection problem.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # pip install optax
import trainax as tx

CFL = -0.75

ref_data = tx.sample_data.advection_1d_periodic(
    cfl = CFL,
    key = jax.random.PRNGKey(0),
)

linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73)
)

# One supervised trainer per number of unrolled training steps (1, 5, 20)
sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)

sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)

# First-order upwind (FOU) stencil that the learned convolution should approximate
FOU_STENCIL = jnp.array([1+CFL, -CFL])

print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))   # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))   # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017

Increasing the number of supervised unrolling steps during training brings the learned stencil closer to the numerical first-order upwind (FOU) stencil.
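For reference, the stencil the convolution is compared against can be checked in isolation. This sketch assumes the convention $u_t + a u_x = 0$ with $\text{CFL} = a \Delta t / \Delta x$, so that a negative CFL transports to the left and the upwind neighbor of cell $i$ is cell $i+1$. `fou_step` is a hypothetical helper, and how `eqx.nn.Conv1d` with `padding="SAME"` aligns its two kernel weights with these neighbors is not verified here.

```python
import jax.numpy as jnp

def fou_step(u, cfl):
    """One explicit first-order upwind step of periodic 1D advection,
    assuming cfl < 0 so that the upwind neighbor of cell i is cell i + 1."""
    # jnp.roll(u, -1)[i] == u[(i + 1) % N], i.e. periodic boundaries
    return u - cfl * (jnp.roll(u, -1) - u)

CFL = -0.75
FOU_STENCIL = jnp.array([1 + CFL, -CFL])   # weights on (u[i], u[i + 1])

x = jnp.arange(64) / 64
u0 = 1.0 + jnp.sin(2 * jnp.pi * x)
u1 = fou_step(u0, CFL)

# The same update, written as a two-point stencil
u1_stencil = FOU_STENCIL[0] * u0 + FOU_STENCIL[1] * jnp.roll(u0, -1)
```

With `cfl=-1.0` the update reduces to an exact shift by one cell (`jnp.roll(u0, -1)`), and for any CFL the periodic scheme conserves the sum of `u`.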

Background

After the discretization of space and time, the simulation of a time-dependent partial differential equation amounts to the repeated application of a simulation operator $\mathcal{P}_h$. Here, we are interested in imitating (emulating) this physical/numerical operator with a neural network $f_\theta$. This repository provides an abstract implementation of the many ways a learning problem can be framed to inject "knowledge" from $\mathcal{P}_h$ into $f_\theta$.

Assume we have a distribution of initial conditions $\mathcal{Q}$ from which we sample $S$ initial conditions, $u^{[0]} \sim \mathcal{Q}$. We can store them in an array of shape $(S, C, *N)$ (with $C$ channels and an arbitrary number of spatial axes, denoted $*N$) and repeatedly apply $\mathcal{P}_h$ to obtain training trajectories of shape $(S, T+1, C, *N)$.
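A minimal sketch of this data-generation step in plain JAX, assuming a first-order upwind advection step as a stand-in for $\mathcal{P}_h$ and white-noise initial conditions as a hypothetical choice of $\mathcal{Q}$. `fou_step` and `rollout` are illustrative names, not Trainax functions (Trainax ships its own sample data, e.g. `tx.sample_data.advection_1d_periodic`).

```python
import jax
import jax.numpy as jnp

def fou_step(u, cfl=-0.75):
    # Stand-in for P_h (assumption): first-order upwind step for periodic
    # 1D advection with cfl < 0, acting on a state of shape (C, N)
    return (1.0 + cfl) * u - cfl * jnp.roll(u, -1, axis=-1)

def rollout(stepper, num_steps):
    """Return a function mapping u0 of shape (C, *N) to a trajectory of
    shape (num_steps + 1, C, *N) that includes the initial condition."""
    def run(u0):
        def body(u, _):
            u_next = stepper(u)
            return u_next, u_next
        _, future = jax.lax.scan(body, u0, None, length=num_steps)
        return jnp.concatenate([u0[None], future], axis=0)
    return run

S, T, C, N = 8, 50, 1, 64
# Hypothetical choice of Q: white-noise initial conditions
u0 = jax.random.normal(jax.random.PRNGKey(0), (S, C, N))
trajectories = jax.vmap(rollout(fou_step, T))(u0)   # shape (S, T + 1, C, N)
```

Using `jax.lax.scan` keeps the time loop inside one traced computation, and `jax.vmap` adds the sample axis $S$ without changing the single-trajectory code.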

For a one-step supervised learning task, we substack the training trajectories into windows of size $2$ and merge the two leftover batch axes (samples and windows) to get a data array of shape $(S \cdot T, 2, C, *N)$ that can be used in a supervised learning scenario:

$$ L(\theta) = \mathbb{E}_{(u^{[0]}, u^{[1]}) \sim \mathcal{Q}} \left[ l\left( f_\theta(u^{[0]}), u^{[1]} \right) \right] $$

where $l$ is a time-level loss. In the simplest case, $l = \text{MSE}$.
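The windowing and the empirical version of $L(\theta)$ can be sketched as follows. `substack`, `mse`, and `one_step_supervised_loss` are hypothetical helpers written for illustration (Trainax has its own trajectory substacker, whose API is not reproduced here). The expectation is replaced by a mean over the batch of windows, and the stand-in $\mathcal{P}_h$ is again a first-order upwind step.

```python
import jax
import jax.numpy as jnp

def fou_step(u, cfl=-0.75):
    # Stand-in reference stepper P_h (first-order upwind, periodic, cfl < 0)
    return (1.0 + cfl) * u - cfl * jnp.roll(u, -1, axis=-1)

def substack(trajectories, window_size):
    """Hypothetical helper: cut trajectories of shape (S, T+1, C, *N) into all
    windows of `window_size` consecutive time levels and merge the sample and
    window axes into (S * (T + 2 - window_size), window_size, C, *N)."""
    num_samples, num_levels = trajectories.shape[:2]
    num_windows = num_levels - window_size + 1
    windows = jnp.stack(
        [trajectories[:, i : i + window_size] for i in range(num_windows)], axis=1
    )  # (S, num_windows, window_size, C, *N)
    return windows.reshape(
        num_samples * num_windows, window_size, *trajectories.shape[2:]
    )

def mse(pred, ref):
    # Time-level loss l on a single state of shape (C, *N)
    return jnp.mean((pred - ref) ** 2)

def one_step_supervised_loss(stepper, data):
    """Empirical estimate of L(theta) on windows of shape (B, 2, C, *N)."""
    pred = jax.vmap(stepper)(data[:, 0])
    return jnp.mean(jax.vmap(mse)(pred, data[:, 1]))

# Toy trajectories: S = 4 samples, T = 10 steps, C = 1 channel, N = 32 points
S, T = 4, 10
levels = [jax.random.normal(jax.random.PRNGKey(0), (S, 1, 32))]
for _ in range(T):
    levels.append(fou_step(levels[-1]))
trajectories = jnp.stack(levels, axis=1)   # (4, 11, 1, 32)

data = substack(trajectories, 2)           # (40, 2, 1, 32)
exact_loss = one_step_supervised_loss(fou_step, data)
identity_loss = one_step_supervised_loss(lambda v: v, data)
```

Plugging in the reference stepper itself gives a loss of (numerically) zero, while a trivial identity "emulator" does not.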

Trainax supports much more than one-step supervised learning, e.g., training with unrolled steps, including the reference simulator $\mathcal{P}_h$ in training, training on residuum conditions instead of resolved reference states, and cutting or modifying the gradient flow.
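As one example of the unrolled variants, here is a sketch of an $R$-step supervised rollout loss, assuming the per-step time-level losses are summed over the rollout and averaged over the batch (other reductions are equally plausible). All names are hypothetical, not Trainax API.

```python
import jax
import jax.numpy as jnp

def mse(pred, ref):
    return jnp.mean((pred - ref) ** 2)

def unrolled_supervised_loss(stepper, window, time_level_loss=mse):
    """window has shape (R + 1, C, *N). The stepper is rolled out R times
    autoregressively from window[0]; prediction j is compared with window[j]
    and the per-step losses are summed."""
    def body(u, ref_next):
        pred = stepper(u)                     # next state from the network's own output
        return pred, time_level_loss(pred, ref_next)
    _, step_losses = jax.lax.scan(body, window[0], window[1:])
    return jnp.sum(step_losses)

def batched_loss(stepper, windows):
    # windows: (B, R + 1, C, *N); mean over the batch of per-window losses
    return jnp.mean(jax.vmap(lambda w: unrolled_supervised_loss(stepper, w))(windows))

def fou_step(u, cfl=-0.75):
    # Stand-in reference stepper P_h
    return (1.0 + cfl) * u - cfl * jnp.roll(u, -1, axis=-1)

# Reference windows with R = 5 steps
levels = [jax.random.normal(jax.random.PRNGKey(1), (16, 1, 32))]
for _ in range(5):
    levels.append(fou_step(levels[-1]))
windows = jnp.stack(levels, axis=1)   # (16, 6, 1, 32)

exact = batched_loss(fou_step, windows)
identity = batched_loss(lambda v: v, windows)
```

Setting $R = 1$ recovers the one-step loss above; larger $R$ corresponds to `num_rollout_steps` in the Quickstart.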

Features

  • Wide collection of unrolled training methodologies:
    • Supervised
    • Diverted Chain
    • Mix Chain
    • Residuum
  • Based on JAX:
    • Automatic differentiation in forward and reverse mode
    • Automatic vectorization
    • Backend-agnostic code (run on CPU, GPU, and TPU)
  • Built on top of, and compatible with, Equinox
  • Batch-Parallel Training
  • Collection of Callbacks
  • Composability

A Taxonomy of Training Methodologies

The major axes that need to be chosen are:

  • The unrolled length (how often the network is applied autoregressively to the input)
  • The branch length (how long the reference runs alongside the network; if it equals the rollout length, we get fully supervised training)
  • Whether the physics is resolved (diverted-chain and supervised) or only given as a condition (residuum-based loss)
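To contrast resolved and condition-only physics, the following sketch implements two of these setups under our own reading of the taxonomy: a diverted chain with branch length one (the reference $\mathcal{P}_h$ branches off every network state for a single step) and a residuum-based loss. The residuum of an implicit upwind scheme is a hypothetical example of a "condition": it is cheap to evaluate but would require a linear solve to resolve. None of these functions are Trainax API.

```python
import jax
import jax.numpy as jnp

CFL = -0.75

def mse(pred, ref):
    return jnp.mean((pred - ref) ** 2)

def fou_step(u, cfl=CFL):
    # Stand-in reference stepper P_h (explicit first-order upwind, cfl < 0)
    return (1.0 + cfl) * u - cfl * jnp.roll(u, -1, axis=-1)

def diverted_chain_branch_one_loss(stepper, ref_stepper, u0, num_rollout_steps,
                                   cut_div_chain=False):
    """Physics resolved: along the network's own rollout, P_h branches off
    each network state for one step and provides the target."""
    def body(u, _):
        target = ref_stepper(u)
        if cut_div_chain:  # optionally stop gradients through the diverted physics
            target = jax.lax.stop_gradient(target)
        pred = stepper(u)
        return pred, mse(pred, target)
    _, losses = jax.lax.scan(body, u0, None, length=num_rollout_steps)
    return jnp.sum(losses)

def implicit_upwind_residuum(u_next, u_prev, cfl=CFL):
    # Hypothetical residuum of an implicit (backward Euler) upwind scheme;
    # it vanishes when u_next solves the scheme exactly from u_prev
    return u_next - u_prev + cfl * (jnp.roll(u_next, -1, axis=-1) - u_next)

def residuum_loss(stepper, residuum_fn, u0, num_rollout_steps):
    """Physics only as a condition: there is no resolved target; the time-level
    loss reduces the residuum of consecutive network states against zero."""
    def body(u, _):
        pred = stepper(u)
        r = residuum_fn(pred, u)
        return pred, mse(r, jnp.zeros_like(r))
    _, losses = jax.lax.scan(body, u0, None, length=num_rollout_steps)
    return jnp.sum(losses)

u0 = jax.random.normal(jax.random.PRNGKey(2), (1, 32))
div_exact = diverted_chain_branch_one_loss(fou_step, fou_step, u0, 5)
div_identity = diverted_chain_branch_one_loss(lambda v: v, fou_step, u0, 5)
res_identity = residuum_loss(lambda v: v, implicit_upwind_residuum, u0, 5)
```

Note that neither loss needs a precomputed reference trajectory: both only require the initial conditions, since targets (or conditions) are evaluated on the network's own states.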

Additional axes are:

  • The time-level loss (how two states are compared, or how a residuum state is reduced)
  • The time-level weights (if the network is rolled out, whether states further away from the initial condition should be weighted differently, similar to exponential discounting in reinforcement learning)
  • Whether the main chain of network rollout is interleaved with a physics solver (-> mix chain)
  • Modifications to the gradient flow:
    • Cutting the backpropagation through time in the main chain (after each step, or sparsely)
    • Cutting the diverted physics
    • Cutting one or both time levels of the inputs to a residuum function
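The time-level weights and the cutting of backpropagation through time can be illustrated on a supervised rollout. This is a sketch assuming exponential weights $\gamma^{j-1}$ for the $j$-th predicted state and a `jax.lax.stop_gradient` after every `cut_every`-th step; the function name and its arguments are hypothetical. The toy check compares the gradient of a scalar linear stepper $u \mapsto a u$ with and without cutting.

```python
import jax
import jax.numpy as jnp

def mse(pred, ref):
    return jnp.mean((pred - ref) ** 2)

def weighted_cut_unrolled_loss(stepper, window, gamma=1.0, cut_every=None):
    """Supervised rollout on window (R + 1, C, *N) with exponentially discounted
    time-level weights gamma**(j - 1) and, if cut_every = k, the backpropagation
    through time of the main chain cut after every k-th step."""
    num_steps = window.shape[0] - 1
    u = window[0]
    total = 0.0
    for j in range(1, num_steps + 1):   # short rollouts: a plain Python loop
        u = stepper(u)
        total = total + gamma ** (j - 1) * mse(u, window[j])
        if cut_every is not None and j % cut_every == 0:
            u = jax.lax.stop_gradient(u)   # forward value unchanged, gradient cut
    return total

# Toy problem: stepper u -> a * u, reference trajectory generated with b = 0.5
u0 = jax.random.normal(jax.random.PRNGKey(3), (1, 16))
window = jnp.stack([0.5 ** j * u0 for j in range(4)])   # (4, 1, 16)

def loss_for(a, **kwargs):
    return weighted_cut_unrolled_loss(lambda v: a * v, window, **kwargs)

full_bptt_grad = jax.grad(lambda a: loss_for(a, gamma=0.9))(0.9)
cut_bptt_grad = jax.grad(lambda a: loss_for(a, gamma=0.9, cut_every=1))(0.9)
```

Because `stop_gradient` leaves the forward pass unchanged, both variants report the same loss value, while the cut gradient is smaller here since it drops the terms that flow back through earlier network states.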

Implementation details

There are three levels of hierarchy:

  1. The loss submodule defines time-level comparisons between two states. A state is either a tensor of shape (num_channels, ...) (with the ellipsis indicating an arbitrary number of spatial dimensions) or a tensor of shape (num_batches, num_channels, ...). The time-level loss is implemented for the former, but it is additionally vectorized and (mean-)aggregated over the batch axis for the latter. (In the schematic above, the time-level loss is the green circle.)
  2. The configuration submodule defines how the neural time stepper $f_\theta$ (denoted NN in the schematic) interacts with the numerical simulator $\mathcal{P}_h$. Like the time-level loss, a configuration is a callable PyTree that is called with the neural stepper and some data. What this data contains depends on the concrete configuration. For supervised rollout training, it is the batch of (sub-)trajectories to be considered. Other configurations might also require the reference stepper or a residuum function of two consecutive time levels. Each configuration is essentially an abstract implementation of one of the major methodologies (supervised, diverted-chain, mix-chain, residuum). The most general diverted-chain implementation contains supervised training and branch-one diverted chain as special cases. All configurations accept additional constructor arguments to, e.g., cut the backpropagation through time (sparsely) or supply time-level weightings (for example, to exponentially discount contributions over long rollouts).
  3. The training submodule combines a configuration with stochastic minibatching on a set of reference trajectories. For each configuration, there is a corresponding trainer that is essentially syntactic sugar for combining the relevant configuration with the GeneralTrainer and a trajectory substacker.
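A compressed sketch of the three-level hierarchy, written in plain JAX rather than with Equinox modules: `MSELoss`, `SupervisedConfig`, `stencil_stepper`, and `train` are hypothetical stand-ins for the loss, configuration, and training submodules, and plain gradient descent replaces an Optax optimizer to keep the example minimal. The trainable "network" is the same two-point stencil as in the Quickstart, applied with `jnp.roll` instead of `eqx.nn.Conv1d`.

```python
import jax
import jax.numpy as jnp

class MSELoss:
    """Level 1 (hypothetical): time-level loss defined on one state (C, *N)."""
    def __call__(self, pred, ref):
        return jnp.mean((pred - ref) ** 2)

    def batched(self, pred, ref):
        # Vectorized over a leading batch axis and mean-aggregated
        return jnp.mean(jax.vmap(self.__call__)(pred, ref))

class SupervisedConfig:
    """Level 2 (hypothetical): supervised rollout of num_rollout_steps on
    windows of shape (B, num_rollout_steps + 1, C, *N)."""
    def __init__(self, num_rollout_steps, loss_fn):
        self.num_rollout_steps = num_rollout_steps
        self.loss_fn = loss_fn

    def __call__(self, stepper, data):
        u = data[:, 0]
        total = 0.0
        for j in range(1, self.num_rollout_steps + 1):
            u = jax.vmap(stepper)(u)
            total = total + self.loss_fn.batched(u, data[:, j])
        return total

def stencil_stepper(weights):
    # Two-point periodic stencil on (C, N): w0 * u[i] + w1 * u[i + 1]
    return lambda u: weights[0] * u + weights[1] * jnp.roll(u, -1, axis=-1)

def train(config, windows, weights, *, num_steps, batch_size, lr, key):
    """Level 3 (hypothetical): stochastic minibatching over the windows plus
    plain gradient descent (instead of an Optax optimizer, to stay minimal)."""
    value_and_grad = jax.jit(
        jax.value_and_grad(lambda w, batch: config(stencil_stepper(w), batch))
    )
    history = []
    for _ in range(num_steps):
        key, subkey = jax.random.split(key)
        idx = jax.random.choice(subkey, windows.shape[0], (batch_size,), replace=False)
        loss, grad = value_and_grad(weights, windows[idx])
        weights = weights - lr * grad
        history.append(float(loss))
    return weights, history

# Reference windows from the FOU stencil (CFL = -0.75) with one rollout step
CFL = -0.75
fou = stencil_stepper(jnp.array([1 + CFL, -CFL]))
u0 = jax.random.normal(jax.random.PRNGKey(0), (256, 1, 64))
windows = jnp.stack([u0, jax.vmap(fou)(u0)], axis=1)   # (256, 2, 1, 64)

config = SupervisedConfig(num_rollout_steps=1, loss_fn=MSELoss())
learned, history = train(config, windows, jnp.zeros(2), num_steps=300,
                         batch_size=32, lr=0.1, key=jax.random.PRNGKey(42))
```

Because the reference data in this toy setting come from an exact, noise-free two-point stencil, one-step supervised training should recover the FOU weights up to floating-point error; swapping `SupervisedConfig` for another configuration would leave both the loss and the training loop untouched.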

You can find an overview of predictor learning setups here.

License

MIT, see here


fkoehler.site  ·  GitHub @ceyron  ·  X @felix_m_koehler



Download files


Source Distribution

trainax-0.0.1.tar.gz (28.0 kB)

Uploaded Source

Built Distribution


trainax-0.0.1-py3-none-any.whl (41.2 kB)

Uploaded Python 3

File details

Details for the file trainax-0.0.1.tar.gz.

File metadata

  • Download URL: trainax-0.0.1.tar.gz
  • Upload date:
  • Size: 28.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for trainax-0.0.1.tar.gz
Algorithm Hash digest
SHA256 19552dfca2d6f9d7e69963e978628adb19dc2ba9cb9563b510c19e136116c23a
MD5 4438b39e3bec757e440936f8c90a1530
BLAKE2b-256 84a20cabdb3717358951b97594af08cc66386301d403ae98180afe355f8eaad2


File details

Details for the file trainax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: trainax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 41.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for trainax-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 92fde741245555ed063d094c49870d15ab64c95c15655ef85f25de084a301c87
MD5 7f667b37f72a47222b63d5d50345b77f
BLAKE2b-256 13b11a8e47bf940babbe06c010f6bffb55c8d7f5d158a6fd2f6fd54190b5967a

