Jax Neural Operators

Project description

Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.

Install

Quick install from PyPI:

pip install jax-neural-operators

If an NVIDIA GPU is available, also install the CUDA-enabled build of JAX:

pip install "jax[cuda]"
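To confirm that the CUDA build was picked up, you can ask JAX which backend it uses. This is a minimal sketch built only on JAX's public `jax.default_backend()` and `jax.devices()`; the helper name `jax_backend` is illustrative and not part of jNO:

```python
import importlib.util


def jax_backend():
    """Return JAX's default backend ("cpu", "gpu" or "tpu"), or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the check also runs without JAX installed

    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed")
    else:
        import jax

        print("default backend:", backend)
        print("devices:", jax.devices())
```

With a working CUDA install this should report `gpu`; `cpu` means JAX fell back to its CPU build.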

For local development from a clone of the repository (recommended on Linux aarch64, where gmsh wheels may be unavailable on PyPI), use micromamba:

micromamba create -n jno python=3.12 pip -y
micromamba activate jno
micromamba install -n jno -c conda-forge gmsh python-gmsh -y
pip install -e .
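After the editable install, a quick sanity check is to look up the packages involved without importing them. This is a minimal sketch using only the standard library; it assumes the import names `jno`, `jax` and `gmsh` and the distribution name `jax-neural-operators` (taken from the commands above), and the helper `check_environment` is hypothetical:

```python
import importlib.metadata
import importlib.util


def check_environment(modules=("jno", "jax", "gmsh"), dist="jax-neural-operators"):
    """Report which modules can be found and the installed version of `dist`."""
    # find_spec only locates each module; it does not execute it, so this
    # checks presence, not that the import itself succeeds.
    report = {name: importlib.util.find_spec(name) is not None for name in modules}
    try:
        report["version"] = importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        report["version"] = None  # distribution not installed in this environment
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```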

Minimal DeepONet Example

Create a Python file with the following content:

import jno
import jax
import optax
import foundax

dir = jno.setup("./runs/test")

# Domain
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
xb, yb, _ = dom.variable("boundary")  # unused below: the BCs are enforced exactly by the output transform

random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", random_k)

# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))

# Forward pass and hard enforcement of BCs via output transformation
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0  # PDE Loss

# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callback.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])

# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{dir}/training.png")
jno.save(crux, f"{dir}/model.pkl")

# Inference via test domain on a finer mesh
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))

pred, x, y, k = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x.shape, y.shape, k.shape)

Then run it with:

CUDA_VISIBLE_DEVICES=<gpu_id> JNO_SEED=<seed> python <filename>.py
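Written out, and assuming `u.dd(x)` denotes the second derivative of u with respect to x and `.mse` the mean squared residual over the collocation points (the example does not define either), the script trains the network on the following problem, with a constant coefficient k per training sample:

```latex
\begin{aligned}
k\,\bigl(\partial_{xx}u + \partial_{yy}u\bigr) + 1 &= 0 \quad \text{in } \Omega = (0,2)\times(0,1), \qquad k \sim \mathcal{U}(0.5,\,1.5),\\
u &= 0 \quad \text{on } \partial\Omega,\\
u(x,y) &= \mathcal{N}_\theta(k)(x,y)\; x\,(2-x)\; y\,(1-y),\\
\mathcal{L}(\theta) &= \operatorname{mean}\Bigl[\bigl(k\,(\partial_{xx}u + \partial_{yy}u) + 1\bigr)^2\Bigr].
\end{aligned}
```

The factor x(2 - x) y(1 - y) vanishes on all four edges of Ω, so the boundary condition holds for any network output and only the PDE residual enters the loss. Because k is constant on each sample, the exact solution is u = w/k, where w solves -Δw = 1 with w = 0 on ∂Ω. The test domain draws k from (0.1, 1.9), wider than the training range, so evaluation also covers values of k the network never saw.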

Foundation Models and Other Neural Networks

These models are maintained in a separate repository, foundax, so they can also be used independently of jNO. Install it with:

pip install foundax
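The example calls `foundax.deeponet(...)` with `n_sensors`, `coord_dim`, `basis_functions` and `hidden_dim`. For intuition, here is a minimal NumPy sketch of a generic (unstacked) DeepONet, not foundax's implementation: a branch MLP maps the `n_sensors` input values to `p` coefficients, a trunk MLP maps each `coord_dim`-dimensional point to `p` basis values, and the output is their dot product. Reading `basis_functions` as `p` and `hidden_dim` as the MLP width, and using two hidden layers, are assumptions:

```python
import numpy as np


def init_mlp(rng, sizes):
    """(weight, bias) pairs; Gaussian weights with std 1/sqrt(fan_in), zero biases."""
    return [
        (rng.normal(0.0, np.sqrt(1.0 / m), size=(m, n)), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def mlp(params, h):
    """tanh hidden layers followed by a linear output layer."""
    for w, b in params[:-1]:
        h = np.tanh(h @ w + b)
    w, b = params[-1]
    return h @ w + b


def deeponet(branch, trunk, sensors, coords):
    """G(k)(x) = sum_p b_p(k) * t_p(x).

    sensors: (batch, n_sensors)     input function values for the branch net
    coords:  (n_points, coord_dim)  query points for the trunk net
    returns: (batch, n_points)
    """
    b = mlp(branch, sensors)  # (batch, p) coefficients
    t = mlp(trunk, coords)    # (n_points, p) basis values
    return b @ t.T


rng = np.random.default_rng(0)
n_sensors, coord_dim, p, hidden = 1, 2, 32, 128
branch = init_mlp(rng, [n_sensors, hidden, hidden, p])
trunk = init_mlp(rng, [coord_dim, hidden, hidden, p])

k = rng.uniform(0.5, 1.5, size=(4, n_sensors))               # 4 coefficient samples
xy = rng.uniform([0.0, 0.0], [2.0, 1.0], size=(10, coord_dim))  # 10 points in (0,2)x(0,1)
out = deeponet(branch, trunk, k, xy)
print(out.shape)  # (4, 10)
```

The same structure carries over to `jax.numpy`, with parameters initialised from `jax.random` instead of a NumPy `Generator`.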

Citation

If you use jNO, we would appreciate a citation of the following paper:

@article{armbruster2026jNO,
  author  = {Armbruster, Leon, ....},
  title   = {{jNO}: A JAX Library for Neural Operator and PDE Foundation Model Training},
  journal = {},
  year    = {},
}

Download files

Source Distribution

jax_neural_operators-0.2.1.tar.gz (297.8 kB)

Uploaded Source

Built Distribution

jax_neural_operators-0.2.1-py3-none-any.whl (263.3 kB)

Uploaded Python 3

File details

Details for the file jax_neural_operators-0.2.1.tar.gz.

File metadata

  • Download URL: jax_neural_operators-0.2.1.tar.gz
  • Size: 297.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jax_neural_operators-0.2.1.tar.gz
Algorithm Hash digest
SHA256 ec915b55825b66881ea00f1c0b684b544258413e4603e9c104dc2f36ac8c0bb8
MD5 d270d511f02325fa9e5eabc5d2fd4661
BLAKE2b-256 5c5c55ac7174e1fa62d4c5c9a4f2da54c61a103e891b73d859199bd23b912a60

File details

Details for the file jax_neural_operators-0.2.1-py3-none-any.whl.

File hashes

Hashes for jax_neural_operators-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 6a6639133b09dac4b2e0149ad1a80da679c15b9b6cf10426b2f3dfcc55c77744
MD5 9543f277efeb3bfff2d63d413b1102fd
BLAKE2b-256 3721be2db88a4bba28a305b4c41c6d5db8ae707a22d7aaa153df0d1d565b3b4f
