
One-stop JAX foundation model repository

foundax

Unified JAX model zoo for operator learning, PDE surrogates, and Equinox foundation-model wrappers.

pip install foundax

Overview

foundax provides two main model groups:

  • Core Equinox architectures in foundax/architectures/ (FNO, UNet, DeepONet, GNOT family, and others)
  • Equinox wrappers for larger vendored model families (Poseidon, MORPH, MPP, Walrus, BCAT, PDEformer-2, DPOT, PROSE)

Quick Start

import foundax as fx

# Core models
model = fx.mlp(in_features=2, output_dim=1, hidden_dims=64, num_layers=3)
model = fx.fno2d(in_features=1, hidden_channels=32, n_modes=16)
model = fx.unet2d(in_channels=1, out_channels=1)
model = fx.deeponet(branch_type="mlp", trunk_type="mlp")

# Foundation wrappers (namespace style)
model = fx.poseidon.T()   # T/B/L
model = fx.morph.S()      # Ti/S/M/L
model = fx.mpp.B(n_states=12)  # Ti/S/B/L
model = fx.walrus.base()
model = fx.bcat.base()
model = fx.pdeformer2.small()  # small/base/fast
model = fx.dpot.Ti()      # Ti/S/M/L/H
model, variables = fx.prose.fd_1to1()

Composable Pipe API

Wrap any model or layer with fx.block() and chain blocks with the | operator. Channel mismatches between consecutive blocks are caught when the pipe is constructed, with a clear error message.
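The construction-time channel check can be pictured with a minimal, pure-Python sketch. The `Block` and `Pipe` classes below and their `in_channels`/`out_channels` attributes are hypothetical stand-ins, not foundax's implementation; they only show how overloading `|` lets a mismatch surface while the pipe is being built rather than when data first flows through it.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class Block:
    """Hypothetical block: a callable plus its declared channel counts."""
    fn: Callable[[List[float]], List[float]]
    in_channels: int
    out_channels: int
    name: str = "block"

    def __call__(self, x: List[float]) -> List[float]:
        return self.fn(x)

    def __or__(self, other: "Block") -> "Pipe":
        # `a | b` starts a new pipe; the channel check runs immediately.
        return Pipe([self]) | other


class Pipe:
    """Hypothetical sequential container grown one block at a time with |."""

    def __init__(self, blocks: Sequence[Block]):
        self.blocks = list(blocks)

    def __or__(self, other: Block) -> "Pipe":
        last = self.blocks[-1]
        # Construction-time check: the previous block's outputs must match
        # the next block's inputs, otherwise fail before any data is seen.
        if last.out_channels != other.in_channels:
            raise ValueError(
                f"channel mismatch: '{last.name}' outputs {last.out_channels} "
                f"channels but '{other.name}' expects {other.in_channels}"
            )
        return Pipe(self.blocks + [other])

    def __call__(self, x: List[float]) -> List[float]:
        for block in self.blocks:
            x = block(x)
        return x


# Toy "layers" acting on a list of channel values.
lift = Block(lambda x: [x[0]] * 32, in_channels=1, out_channels=32, name="lift")
mix = Block(lambda x: [2.0 * v for v in x], in_channels=32, out_channels=32, name="mix")
project = Block(lambda x: [sum(x)], in_channels=32, out_channels=1, name="project")

model = lift | mix | project
result = model([1.0])

try:
    lift | lift  # 32 output channels feeding a 1-channel input
    mismatch_caught = False
except ValueError as err:
    mismatch_caught = True
    print(err)
```

Because validation lives in `__or__`, a wrong pipe never exists as an object, which is the behaviour the paragraph above describes for fx.block().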

import jax
import foundax as fx

ks = jax.random.split(jax.random.PRNGKey(0), 8)

# ── Build a 2-D FNO-style pipeline from individual spectral layers ──────────
lift    = fx.block(fx.layers.SpectralBlock2d(1,  32, n_modes=16, key=ks[0]), name="lift")
s1      = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[1]))
s2      = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[2]))
s3      = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[3]))
project = fx.block(fx.layers.SpectralBlock2d(32,  1, n_modes=16, key=ks[4]), name="project")

model = lift | s1 | s2 | s3 | project   # Pipe of 5 blocks

# ── Existing full models work as blocks too ──────────────────────────────────
encoder = fx.block(fx.fno2d(in_features=3, hidden_channels=32, n_modes=16, key=ks[5]))
decoder = fx.block(fx.layers.SpectralBlock2d(32, 1, n_modes=16, key=ks[6]))

model = encoder | decoder

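# ── Multi-input combinators (DeepONet-style) ─────────────────────────────────
# Fresh keys for the new layers: ks[0]..ks[6] were already consumed above,
# and reusing them would give correlated initialisations.
kb1, kb2, kt = jax.random.split(ks[7], 3)
branch = (
    fx.block(fx.layers.SpectralBlock1d(1, 32, n_modes=16, key=kb1))
    | fx.block(fx.mlp(in_features=32, output_dim=64, hidden_dims=64, key=kb2))
)
trunk = fx.block(fx.mlp(in_features=2, output_dim=64, hidden_dims=64, key=kt))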

model = fx.dot(branch, trunk)   # branch(u) · trunk(y)  →  (N_pts,)

# Also available: fx.add(a, b)  — elementwise sum of two branches
#                 fx.cat(a, b)  — concatenate outputs along the channel axis

# ── All pipe models are plain Equinox modules ────────────────────────────────
import equinox as eqx
import jax.numpy as jnp
import optax

opt   = optax.adam(1e-3)
state = opt.init(eqx.filter(model, eqx.is_array))   # optimizer state for array leaves only

# `model` is the DeepONet-style pipe built above, so it is called as model(u, y).
@eqx.filter_jit
def step(model, state, u, y, target):
    loss, grads = eqx.filter_value_and_grad(
        lambda m: jnp.mean((m(u, y) - target) ** 2)   # mean-squared error
    )(model)
    updates, state = opt.update(grads, state, eqx.filter(model, eqx.is_array))
    return eqx.apply_updates(model, updates), state, loss
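The shape contract of the combinators can be checked with plain Python lists. The helpers below are hypothetical and only mirror the comments in the example above: `dot` contracts a branch coefficient vector of length p against trunk features of shape (N_pts, p), giving one value per query point, while `add` and `cat` combine two branch outputs. They are not foundax's implementations, and a flat channel vector is assumed for `cat`.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # rows = query points, columns = features


def dot(branch_out: Vector, trunk_out: Matrix) -> Vector:
    """DeepONet-style combination: G(u)(y_j) ~= sum_k b_k(u) * t_k(y_j)."""
    p = len(branch_out)
    if any(len(row) != p for row in trunk_out):
        raise ValueError("branch and trunk must produce the same number of features")
    return [sum(b * t for b, t in zip(branch_out, row)) for row in trunk_out]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise sum of two branch outputs of equal length."""
    if len(a) != len(b):
        raise ValueError("add requires outputs of equal length")
    return [x + y for x, y in zip(a, b)]


def cat(a: Vector, b: Vector) -> Vector:
    """Concatenate two feature vectors along the channel axis."""
    return a + b


# Toy data: p = 3 coefficients from the branch, 2 query points for the trunk.
b = [1.0, 2.0, 3.0]
t = [[1.0, 0.0, 0.0],
     [0.5, 0.5, 1.0]]
out = dot(b, t)  # one value per query point, i.e. shape (N_pts,)
```

The key design point is that the branch output length and the trunk feature width must agree (64 in the example above); that shared width p is the number of learned basis functions.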

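Integration With jNO

The example below wraps a Poseidon-T model with jno.nn.wrap, attaches an optax optimizer (AdamW with global-norm gradient clipping and a warmup-cosine learning-rate schedule), and then calls initialize, mask, and lora on the wrapped network. The name param_mask is not defined in the snippet and stands for a user-supplied parameter mask.
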

import foundax as fx
import jno
import optax

net = jno.nn.wrap(fx.poseidon.T(num_channels=5, num_out_channels=1))
net.optimizer(
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(
            learning_rate=optax.schedules.warmup_cosine_decay_schedule(
                init_value=1e-7,
                peak_value=1e-3,
                warmup_steps=500,
                decay_steps=10000,
                end_value=1e-6,
            ),
            weight_decay=1e-4,
        ),
    )
)
net.initialize('./poseidonT.eqx')
net.mask(param_mask).lora(rank=4)
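For reference, the learning rate produced by `optax.schedules.warmup_cosine_decay_schedule` with the arguments above can be reproduced in plain Python. This sketch follows the schedule as documented by optax (linear warmup from `init_value` to `peak_value` over `warmup_steps`, then cosine decay to `end_value`, where `decay_steps` includes the warmup; default `exponent=1`). It also shows the rescaling that `optax.clip_by_global_norm(1.0)` applies before AdamW sees the gradients. It is an illustration on nested lists, not a replacement for optax.

```python
import math
from typing import List, Tuple


def warmup_cosine_lr(step: int, init_value: float = 1e-7, peak_value: float = 1e-3,
                     warmup_steps: int = 500, decay_steps: int = 10_000,
                     end_value: float = 1e-6) -> float:
    """Linear warmup, then cosine decay; decay_steps counts the warmup too."""
    if step < warmup_steps:
        frac = step / warmup_steps
        return init_value + (peak_value - init_value) * frac
    cosine_len = decay_steps - warmup_steps          # steps spent annealing
    t = min(step - warmup_steps, cosine_len)         # hold end_value afterwards
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / cosine_len))
    return end_value + (peak_value - end_value) * cosine


def clip_by_global_norm(grads: List[List[float]],
                        max_norm: float = 1.0) -> Tuple[List[List[float]], float]:
    """Rescale all gradient leaves together if their joint L2 norm exceeds max_norm."""
    g_norm = math.sqrt(sum(g * g for leaf in grads for g in leaf))
    if g_norm < max_norm:
        return grads, g_norm
    scale = max_norm / g_norm
    return [[g * scale for g in leaf] for leaf in grads], g_norm


for s in (0, 250, 500, 5_250, 10_000, 20_000):
    print(f"step {s:>6}: lr = {warmup_cosine_lr(s):.3e}")

clipped, norm = clip_by_global_norm([[3.0], [4.0]])  # joint norm is 5.0, so scale by 1/5
```

Clipping by the global norm keeps the direction of the full gradient and only shrinks its length, which is why it is chained before the optimizer rather than applied per parameter.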

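The `.lora(rank=4)` call suggests low-rank adaptation (LoRA), in which a frozen weight matrix W (d_out x d_in) is replaced by W + (alpha / r) B A and only the factors B (d_out x r) and A (r x d_in) are trained. The pure-Python sketch below shows that generic construction only. How jNO selects parameters through `param_mask`, initializes the factors, or sets the scaling is not documented here, so zero-initialized B and `alpha = r` are assumptions, and `matmul`/`lora_weight` are hypothetical helpers.

```python
import random
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Naive (n x k) @ (k x m) product on nested lists."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def lora_weight(w: Matrix, a: Matrix, b: Matrix, alpha: float) -> Matrix:
    """Effective weight W + (alpha / r) * B @ A, where r = number of rows of A."""
    r = len(a)
    delta = matmul(b, a)  # (d_out x r) @ (r x d_in) -> d_out x d_in, rank <= r
    return [[w_ij + (alpha / r) * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


d_out, d_in, rank = 64, 64, 4
rng = random.Random(0)
W = [[rng.gauss(0.0, 1.0) for _ in range(d_in)] for _ in range(d_out)]  # frozen
A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]  # trainable
B = [[0.0] * rank for _ in range(d_out)]  # trainable; zero init means no change at start

W_eff = lora_weight(W, A, B, alpha=float(rank))
full_params = d_out * d_in           # entries updated by full fine-tuning
lora_params = rank * (d_out + d_in)  # entries in A and B
```

With rank 4 on a 64 x 64 matrix, the adapter trains 512 numbers instead of 4096, and the zero-initialized B guarantees the adapted model starts out identical to the pretrained one.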
Notes

  • Top-level convenience aliases are still available (for example fx.poseidonT()), but namespace-style access is recommended for readability.
  • Foundation-model wrappers are documented in detail in docs/equinox-architectures.md.

License

This project is licensed under the MIT License.

Foundation models remain subject to their original licenses. See THIRD_PARTY_LICENSES for details. Some pretrained weights (for example Poseidon) are released under non-commercial terms.



Download files

Download the file for your platform.

Source Distribution

foundax-0.2.0.tar.gz (163.1 kB)

Uploaded Source

Built Distribution


foundax-0.2.0-py3-none-any.whl (160.3 kB)

Uploaded Python 3

File details

Details for the file foundax-0.2.0.tar.gz.

File metadata

  • Download URL: foundax-0.2.0.tar.gz
  • Upload date:
  • Size: 163.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for foundax-0.2.0.tar.gz
Algorithm Hash digest
SHA256 eacfb40ff75e3b27723d6d9a81a0bebaf08038b05ce105270521fd7f8453ca40
MD5 2b153cede4e2bfaf5d8e7fafe86161c7
BLAKE2b-256 e36f1ec8a98bc10776cd8ac3ea2cd8df27505522ce133e0ed02c1f0197b58384
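To check a downloaded archive against the SHA256 digest listed above, hashing the file in chunks with Python's standard `hashlib` is enough. The helper names `sha256_of` and `matches` and the local file path are illustrative. pip can enforce the same check automatically in hash-checking mode, using a requirements line of the form `foundax==0.2.0 --hash=sha256:<digest>`.

```python
import hashlib
import hmac

EXPECTED_SDIST_SHA256 = "eacfb40ff75e3b27723d6d9a81a0bebaf08038b05ce105270521fd7f8453ca40"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file through SHA-256 so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: str, expected_hex: str) -> bool:
    """Compare the file's digest with an expected hex digest, ignoring case."""
    return hmac.compare_digest(sha256_of(path), expected_hex.lower())


# Usage (path is illustrative):
# if not matches("foundax-0.2.0.tar.gz", EXPECTED_SDIST_SHA256):
#     raise SystemExit("hash mismatch: do not install this file")
```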


File details

Details for the file foundax-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: foundax-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 160.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for foundax-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 9240526f8bcf9860033807404c6e2402dfb10a73ed75088409c37aa58a6fc9b2
MD5 234c7a60cdb6550cb9475181614d2e5f
BLAKE2b-256 1afd0f518e6880b401ad8a415701ffab0205240233dc5ab9522e93430e0bc1cf

