
Jax Neural Operators




Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.

Install

Quick install from PyPI:

pip install jax-neural-operators

If an NVIDIA GPU is available, also install the CUDA-enabled build of JAX:

pip install "jax[cuda]"
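After installing, you can check whether JAX actually picked up the GPU. This sketch uses only standard JAX calls (`jax.default_backend`, `jax.devices`) and nothing jNO-specific; combined with `CUDA_VISIBLE_DEVICES` from the run command further down, it shows which devices JAX sees.

```python
import jax
import jax.numpy as jnp

# "gpu" if a CUDA-enabled jaxlib found a device, otherwise typically "cpu"
backend = jax.default_backend()
print("default backend:", backend)
print("devices:", jax.devices())

# A trivial computation executed on the default device
total = int(jnp.arange(4).sum())
print(total)  # 6
```

If this prints `cpu` on a machine with a GPU, the CUDA build of JAX was most likely not installed or cannot find the CUDA libraries.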

For local development (recommended on Linux aarch64 when gmsh wheels are unavailable on PyPI), run the following with micromamba from the root of a clone of the repository:

micromamba create -n jno python=3.12 pip -y
micromamba activate jno
micromamba install -n jno -c conda-forge gmsh python-gmsh -y
pip install -e .
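To confirm that the conda-forge gmsh bindings are usable from the `jno` environment, a small check like the following can help. It relies only on the standard library and gmsh's own `initialize`/`finalize` calls; the helper name `gmsh_available` is hypothetical and not part of jNO.

```python
import importlib.util


def gmsh_available() -> bool:
    """Return True if the gmsh Python module imports and its kernel starts."""
    if importlib.util.find_spec("gmsh") is None:
        return False
    try:
        import gmsh

        gmsh.initialize()  # start the gmsh kernel
        gmsh.finalize()    # and shut it down again
    except Exception:      # e.g. missing shared libraries behind the Python module
        return False
    return True


print("gmsh available:", gmsh_available())
```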

Minimal DeepONet Example
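Assuming `u.dd(v)` denotes the second partial derivative of `u` with respect to `v` (a reading of the code below, not something stated here), the example learns the map from a constant coefficient $k \sim \mathcal{U}(0.5, 1.5)$ to the solution $u$ of:

```latex
\begin{aligned}
k\,(\partial_{xx} u + \partial_{yy} u) + 1 &= 0 && \text{in } \Omega = (0,2)\times(0,1), \\
u &= 0 && \text{on } \partial\Omega, \\
u(x,y) &= \mathcal{N}_\theta(k)(x,y)\,x\,(2-x)\,y\,(1-y) && \text{(output transform, zero on } \partial\Omega\text{)},
\end{aligned}
```

where $\mathcal{N}_\theta$ is the DeepONet. Because $k$ is constant within each sample, $u = w/k$ where $-\Delta w = 1$ in $\Omega$ and $w = 0$ on $\partial\Omega$, so predictions for different $k$ should differ only by the factor $1/k$, which makes a cheap sanity check.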

Create a Python file with the following contents

import jno
import jax
import optax
import foundax

dir = jno.setup("./runs/test")

# Domain
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
xb, yb, _ = dom.variable("boundary")

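# Random PDE coefficient k ~ U(0.5, 1.5); the leading dimension 500 matches the 500 domain copies above
random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)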
k = dom.variable("k", random_k)

# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))

# Forward pass and hard enforcement of BCs via output transformation
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0  # PDE Loss

# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callback.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])

# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{dir}/training.png")
jno.save(crux, f"{dir}/model.pkl")

# Inference via test domain on a finer mesh
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
# Test coefficients come from U(0.1, 1.9), a wider range than the U(0.5, 1.5) used for training
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))

pred, x, y, k = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x.shape, y.shape, k.shape)

Then run it with:

CUDA_VISIBLE_DEVICES=<gpu_id> JNO_SEED=<seed> python <filename>.py

Foundation Models and other neural networks

These models are maintained in a separate repository (foundax) so that they can also be used independently:

pip install foundax

Citation

If you use jNO, we would appreciate a citation of the following paper:

@article{armbruster2026jNO,
  author  = {Armbruster, Leon, ....},
  title   = {{jNO}: A JAX Library for Neural Operator and PDE Foundation Model Training},
  journal = {},
  year    = {},
}



Download files


Source Distribution

jax_neural_operators-0.2.0.tar.gz (256.3 kB)

Built Distribution


jax_neural_operators-0.2.0-py3-none-any.whl (219.3 kB)

File details

Details for the file jax_neural_operators-0.2.0.tar.gz.

File metadata

  • Download URL: jax_neural_operators-0.2.0.tar.gz
  • Size: 256.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jax_neural_operators-0.2.0.tar.gz
Algorithm Hash digest
SHA256 8ec088b83aa83b403731f0307cd92ac8fff361336eb1d588bcdc1e9d0b421284
MD5 a3ad7ced1965b124fc35b2a48a8659af
BLAKE2b-256 60abf0ad04cd5bd0d616fe8f5d065ab384480b40e73beeb89a0e97f13c012770


File details

Details for the file jax_neural_operators-0.2.0-py3-none-any.whl.


File hashes

Hashes for jax_neural_operators-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e57d6fa4cff429bcecf38b690918d2c172a1edcd73b7597a212e0de40c73f6a3
MD5 c5eacf31f79b382b30203ee8459cc98b
BLAKE2b-256 4fc0e470d32ff9a99d9ae1315b26dc000aa311d2e8f2f2b78cd329badc7b739a
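To check a downloaded file against the SHA256 digests listed above, a minimal sketch with Python's `hashlib` could look like this. The helper names `sha256_of` and `verify` are hypothetical; `pip download jax-neural-operators==0.2.0 --no-deps` is one way to fetch the files first.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the tables above
EXPECTED_SHA256 = {
    "jax_neural_operators-0.2.0.tar.gz": "8ec088b83aa83b403731f0307cd92ac8fff361336eb1d588bcdc1e9d0b421284",
    "jax_neural_operators-0.2.0-py3-none-any.whl": "e57d6fa4cff429bcecf38b690918d2c172a1edcd73b7597a212e0de40c73f6a3",
}


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path):
    """True only if the file name is known and its digest matches the listed one."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

For example, `verify("jax_neural_operators-0.2.0-py3-none-any.whl")` returns `True` for an untampered download; it raises `FileNotFoundError` if a listed file is missing.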

