Jax Neural Operators
jNO (JAX Neural Operators) is a JAX-native library for training neural operators and physics-informed networks. It unifies data-driven operator regression, residual-based PINN training, mesh-aware FEM/variational PINNs, and foundation-model fine-tuning under one symbolic tracing language. You write the PDE once and compile it once, then train, evaluate, and checkpoint without rewriting the surrounding code.
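To make "residual-based PINN training" concrete, here is a plain-JAX sketch of the quantity such a loss minimises, written without jNO. It assumes the residual from the example below, k(u_xx + u_yy) + 1, and assumes that jNO's `u.dd(x)` denotes the second derivative with respect to `x`. The helpers `laplacian`, `poisson_residual`, and `residual_mse` are illustrative and not part of jNO's API.

```python
import jax
import jax.numpy as jnp


def laplacian(u):
    """Return p -> u_xx + u_yy for a scalar field u(p), with p = (x, y)."""
    def lap(p):
        # The trace of the 2x2 Hessian is the sum of the pure second derivatives.
        return jnp.trace(jax.hessian(u)(p))
    return lap


def poisson_residual(u, k):
    """Pointwise residual r(p) = k * (u_xx + u_yy) + 1 for a constant coefficient k."""
    lap = laplacian(u)
    return lambda p: k * lap(p) + 1.0


def residual_mse(u, k, points):
    """Mean squared residual over collocation points of shape (N, 2)."""
    r = jax.vmap(poisson_residual(u, k))(points)
    return jnp.mean(r ** 2)


pts = jnp.array([[0.1, 0.2], [0.5, 0.5], [1.5, 0.9]])
u_exact = lambda p: -(p[0] ** 2 + p[1] ** 2) / 4.0  # Laplacian is -1 everywhere
print(residual_mse(u_exact, 1.0, pts))
```

With k = 1, u(x, y) = -(x^2 + y^2)/4 satisfies the equation exactly, so its residual loss is zero. In a real PINN, `u` would be a network and the loss would be minimised over its parameters.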
Status: research-level repository under active development. The public API is stabilising but may change between minor versions.
Install
```bash
pip install "jax-neural-operators[fem]"
```
GPU support is included by default (jNO depends on `jax[cuda]`). See docs/Installation.md for Pixi, Docker, and CUDA-version-specific setups.
Foundation models and other neural operator architectures live in a separate repository (foundax). It is installed automatically as a dependency of jNO and can also be used on its own.
Example
```python
import jno
import jax
import optax
import foundax

run_dir = jno.setup("./runs/test")  # avoids shadowing the built-in dir()

# Domain
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
xb, yb, _ = dom.variable("boundary")  # not used below: BCs are enforced by the output transformation
random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", random_k)

# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))

# Forward pass and hard enforcement of BCs via output transformation
# (the factor x(2 - x)y(1 - y) vanishes on every edge of [0, 2] x [0, 1])
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0  # PDE residual

# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callbacks.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])

# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{run_dir}/training.png")
jno.save(crux, f"{run_dir}/model.pkl")

# Inference via test domain on a finer mesh, with k drawn from a wider range than in training
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))
pred, x, y, k = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x.shape, y.shape, k.shape)
```
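The forward pass in the example multiplies the network output by x(2 - x)y(1 - y), so every prediction is zero on the boundary of [0, 2] x [0, 1]. Below is a plain-JAX sketch of that construction. It assumes the standard DeepONet form: a branch network encodes k, a trunk network encodes (x, y), and the output is their dot product. This is not foundax's implementation. The MLP depth, the initialisation, and the absence of an output bias are assumptions; only the layer sizes mirror the `foundax.deeponet` arguments above. The helpers `init_mlp`, `mlp`, and `deeponet` are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Initialise (W, b) pairs for an MLP with the given layer sizes (hypothetical helper)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for sub, (m, n) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(sub, (m, n)) / jnp.sqrt(m)
        params.append((w, jnp.zeros(n)))
    return params


def mlp(params, h):
    """tanh hidden layers followed by a linear output layer."""
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return h @ w + b


def deeponet(params, k, xy):
    """Sketch: u(k)(x, y) = <branch(k), trunk(x, y)> times a boundary-vanishing factor."""
    branch = mlp(params["branch"], k)  # (32,) coefficients from the sensor value
    trunk = mlp(params["trunk"], xy)   # (N, 32) basis functions at the query points
    raw = trunk @ branch               # (N,)
    x, y = xy[:, 0], xy[:, 1]
    # Hard Dirichlet BCs: the factor is zero on x = 0, x = 2, y = 0 and y = 1.
    return raw * x * (2 - x) * y * (1 - y)


key_b, key_t = jax.random.split(jax.random.PRNGKey(0))
params = {
    "branch": init_mlp(key_b, [1, 128, 32]),  # n_sensors=1 -> 32 coefficients
    "trunk": init_mlp(key_t, [2, 128, 32]),   # coord_dim=2 -> 32 basis functions
}
u_vals = deeponet(params, jnp.array([0.8]), jnp.array([[1.0, 0.5], [0.0, 0.5]]))
print(u_vals)  # the second point lies on the x = 0 edge, so its value is zero
```

Because the boundary condition is built into the ansatz, no boundary loss term is needed. This matches the example, which defines `xb, yb` but passes only `pde.mse` as a constraint.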
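In `optax.schedules.cosine_decay_schedule`, `alpha` is documented as the minimum value of the learning-rate multiplier, not an absolute floor. The configuration in the example therefore decays from 1e-3 to 1e-3 x 1e-5 = 1e-8. The sketch below re-implements the documented formula in plain Python (with the default exponent of 1) to show the values. `cosine_decay` is an illustrative name and not an optax function.

```python
import math


def cosine_decay(step, init_value=1e-3, decay_steps=20_000, alpha=1e-5):
    """Cosine decay as documented for optax's cosine_decay_schedule (exponent = 1)."""
    t = min(step, decay_steps) / decay_steps          # progress is clipped at decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * t))      # goes from 1 down to 0
    return init_value * ((1.0 - alpha) * cosine + alpha)


for step in (0, 10_000, 20_000):
    print(step, cosine_decay(step))
```

If a floor of 1e-5 was intended, `alpha=1e-2` would give it. Optax schedules are also indexed by optimizer step. The example does not say whether one jNO epoch corresponds to one optimizer step.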
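The checkpoint comment says the callback keeps the best 3 checkpoints, ranked by `best_fn`. This README does not show jNO's internal bookkeeping. The sketch below illustrates one common way to implement such a keep-best-k policy with a bounded heap. `BestKTracker` is a hypothetical helper, not part of `jno.callbacks`, and the losses in the usage lines are made up.

```python
import heapq


class BestKTracker:
    """Hypothetical keep-best-k policy (not jno.callbacks); lower score is better."""

    def __init__(self, k=3, best_fn=lambda m: m["total_loss"]):
        self.k = k
        self.best_fn = best_fn
        self._heap = []  # min-heap of (-score, epoch): the root is the worst kept checkpoint

    def update(self, epoch, metrics):
        """Offer a checkpoint; return the epoch whose checkpoint should be deleted, or None."""
        score = self.best_fn(metrics)
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (-score, epoch))
            return None
        if score < -self._heap[0][0]:
            # Better than the worst kept checkpoint: swap it in and evict the worst.
            _, evicted = heapq.heapreplace(self._heap, (-score, epoch))
            return evicted
        return epoch  # not among the best k, so it is not kept

    def best(self):
        """Epochs currently kept, lowest score first."""
        return [epoch for _, epoch in sorted(self._heap, reverse=True)]


tracker = BestKTracker(k=3)
for epoch, loss in [(5_000, 0.9), (10_000, 0.4), (15_000, 0.6), (20_000, 0.2)]:
    dropped = tracker.update(epoch, {"total_loss": loss})
print(tracker.best(), dropped)
```

Here the fourth offer evicts epoch 5000, leaving epochs 20000, 10000, and 15000 in order of increasing loss. A heap keeps each update at O(log k), although with k = 3 a plain sorted list would work just as well.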
Citation
If you use jNO in academic work, please cite:
```bibtex
@article{armbruster2026jno,
  title   = {jNO: A JAX Library for Neural Operator and Foundation Model Training},
  author  = {Armbruster, Leon and Ramesh, Rathan and Kruse, Georg and Straub, Christopher},
  journal = {arXiv preprint arXiv:2605.10159},
  year    = {2026},
  doi     = {10.48550/arXiv.2605.10159},
  url     = {https://arxiv.org/abs/2605.10159}
}
```
AI Disclosure
Parts of this codebase (including model ports, tests, and documentation) were developed with the assistance of AI coding tools. All contributions are reviewed and tested to the best of our ability, but mistakes may remain; please open an issue if you spot one.
Download files
- Source distribution: jax_neural_operators-0.2.3.tar.gz
- Built distribution: jax_neural_operators-0.2.3-py3-none-any.whl
File details
Details for the file jax_neural_operators-0.2.3.tar.gz.
File metadata
- Download URL: jax_neural_operators-0.2.3.tar.gz
- Upload date:
- Size: 422.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 823a4b43a99e7cc8113d008a84c18a50160b5f626cc3d6fc30acf833e3fd1df9 |
| MD5 | 62bf674f658f594dbe77c1a96e7af2c7 |
| BLAKE2b-256 | 904e579946052f4910be11aed059411b5db8130f9008bd981853e1c38b7d2ec3 |
File details
Details for the file jax_neural_operators-0.2.3-py3-none-any.whl.
File metadata
- Download URL: jax_neural_operators-0.2.3-py3-none-any.whl
- Upload date:
- Size: 321.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7edaace6164ed1d53ce4c5b5ba2f1e075e46a31952a434a5880482d72d219f80 |
| MD5 | 184fe08b0325f85db227e129336e643a |
| BLAKE2b-256 | 49f150868d37bff8ad533c60391e4b0a925e914f29886a341dce7f97bf8ebc6f |
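You can check a downloaded distribution against the SHA256 digests listed in the file-hash tables with a short standard-library script. The helpers `file_digest_hex` and `verify` are illustrative. The sketch assumes the file keeps its published name on disk, because `verify` looks up the expected digest by file name.

```python
import hashlib
import os

# Hypothetical lookup table; digests copied from the file-hash tables on this page.
EXPECTED_SHA256 = {
    "jax_neural_operators-0.2.3.tar.gz":
        "823a4b43a99e7cc8113d008a84c18a50160b5f626cc3d6fc30acf833e3fd1df9",
    "jax_neural_operators-0.2.3-py3-none-any.whl":
        "7edaace6164ed1d53ce4c5b5ba2f1e075e46a31952a434a5880482d72d219f80",
}


def file_digest_hex(path, algorithm="sha256", chunk_size=1 << 20):
    """Stream a file through hashlib in chunks and return its hex digest."""
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path):
    """True if the file matches the published SHA256 for its file name; False otherwise."""
    expected = EXPECTED_SHA256.get(os.path.basename(path))
    return expected is not None and file_digest_hex(path) == expected
```

For example, `verify("jax_neural_operators-0.2.3.tar.gz")` returns `True` only if the local archive is byte-identical to the published one. Reading in chunks keeps memory use constant regardless of file size.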