Jax Neural Operators

jNO (jax Neural Operators) is a JAX-native library for training neural operators and physics-informed networks. It unifies data-driven operator regression, residual-based PINN training, mesh-aware FEM/variational PINNs, and foundation-model fine-tuning under one symbolic tracing language: write the PDE once, compile once, then train, evaluate, and checkpoint without rewriting the surrounding code.

Status: research-level repository under active development. The public API is stabilising but may change between minor versions.

Install

pip install "jax-neural-operators[fem]"

GPU support is included by default (jNO depends on jax[cuda]). See docs/Installation.md for Pixi, Docker, and specific-CUDA-version setups.
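
After installing, it can help to confirm which backend JAX actually selected. The sketch below uses only standard JAX calls (`jax.default_backend`, `jax.devices`); it is a generic JAX check rather than a jNO feature, and on a machine without a working CUDA setup it would be expected to report `cpu`.

```python
import jax

# Name of the backend JAX selected at start-up, e.g. "cpu" or "gpu".
backend = jax.default_backend()

# Devices visible to JAX on that backend.
devices = jax.devices()

print(f"backend: {backend}")
for d in devices:
    print(f"  {d.platform}:{d.id} ({d.device_kind})")
```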

Foundation models and other neural operator architectures live in a separate repository (foundax) so that they can also be used on their own. foundax is installed automatically as a dependency of jNO.

Example

import jno
import jax
import optax
import foundax

# Run directory (named run_dir to avoid shadowing the built-in dir)
run_dir = jno.setup("./runs/test")

# Domain
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
xb, yb, _ = dom.variable("boundary")

random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", random_k)

# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))

# Forward pass and hard enforcement of BCs via output transformation:
# the factor x * (2 - x) * y * (1 - y) vanishes on the boundary of [0, 2] x [0, 1]
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0  # PDE residual; its MSE is the training loss

# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callbacks.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])

# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{run_dir}/training.png")
jno.save(crux, f"{run_dir}/model.pkl")

# Inference via test domain on a finer mesh
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))

# Separate names keep the symbolic variables x, y, k from being overwritten
pred, x_eval, y_eval, k_eval = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x_eval.shape, y_eval.shape, k_eval.shape)
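
The example above relies on two ideas that can be checked in plain JAX, independent of jNO: the output transform that forces u = 0 on the boundary, and the residual k * (u_xx + u_yy) + 1. The sketch below assumes that `u.dd(x)` denotes the second derivative with respect to x. `raw_net` and `strip_solution` are hypothetical stand-ins introduced only for illustration; they are not part of jNO or foundax.

```python
import jax
import jax.numpy as jnp


def raw_net(x, y, k):
    # Hypothetical stand-in for the network output; any smooth function works.
    return jnp.sin(x + 2.0 * y) + k


def u(x, y, k):
    # Output transform: the polynomial factor is zero on all four edges of
    # [0, 2] x [0, 1], so u vanishes there regardless of raw_net.
    return raw_net(x, y, k) * x * (2.0 - x) * y * (1.0 - y)


def strip_solution(x, y, k):
    # Satisfies k * u_xx + 1 = 0 and does not depend on y.
    return x * (2.0 - x) / (2.0 * k)


def residual(fn, x, y, k):
    # Second derivatives via nested jax.grad on a scalar function.
    u_xx = jax.grad(jax.grad(fn, argnums=0), argnums=0)(x, y, k)
    u_yy = jax.grad(jax.grad(fn, argnums=1), argnums=1)(x, y, k)
    return k * (u_xx + u_yy) + 1.0


k = 0.8

# One sample point on each of the four edges of the rectangle.
edge_values = jnp.array(
    [u(0.0, 0.3, k), u(2.0, 0.3, k), u(0.7, 0.0, k), u(0.7, 1.0, k)]
)

# Residual of the transformed stand-in network at an interior point
# (generally non-zero; training would drive it towards zero).
r = residual(u, 0.7, 0.3, k)

# Residual of a function built to satisfy the equation (close to zero).
strip_residual = residual(strip_solution, 0.7, 0.3, k)

print(edge_values, r, strip_residual)
```

Because the boundary condition is built into `u`, only the interior residual needs a loss term, which matches the single `pde.mse` constraint in the example.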

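The learning-rate schedule in the example decays from 1e-3 over 20,000 steps. The sketch below re-implements the cosine decay formula in plain Python, assuming optax's documented form with the default exponent of 1; `cosine_decay` is an illustrative helper, not the optax function itself. It shows that `alpha` scales `init_value`, so under this assumption the floor is 1e-3 * 1e-5 = 1e-8 rather than 1e-5.

```python
import math


def cosine_decay(step, init_value=1e-3, decay_steps=20_000, alpha=1e-5):
    # Clamp so the rate stays at its floor once decay_steps is reached.
    t = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * t))
    # alpha is a fraction of init_value, not an absolute learning rate.
    return init_value * ((1.0 - alpha) * cosine + alpha)


for step in (0, 10_000, 20_000):
    print(step, cosine_decay(step))
```
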
Citation

If you use jNO in academic work, please cite:

@article{armbruster2026jno,
  title   = {jNO: A JAX Library for Neural Operator and Foundation Model Training},
  author  = {Armbruster, Leon and Ramesh, Rathan and Kruse, Georg and Straub, Christopher},
  journal = {arXiv preprint arXiv:2605.10159},
  year    = {2026},
  doi     = {10.48550/arXiv.2605.10159},
  url     = {https://arxiv.org/abs/2605.10159}
}

AI Disclosure

Parts of this codebase — including model ports, tests, and documentation — were developed with the assistance of AI coding tools. All contributions are reviewed and tested to the best of our ability, but mistakes may remain; please open an issue if you spot one.

