Jax Neural Operators

A Docker image is available.

Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.

Install

Quick install from PyPI:

pip install jax-neural-operators

If an NVIDIA GPU is available, also install the CUDA-enabled build of JAX:

pip install "jax[cuda]"

For local development, use micromamba and run the commands below from the root of a local clone of the repository. This route is recommended on Linux aarch64 when gmsh wheels are unavailable on PyPI:

micromamba create -n jno python=3.12 pip -y
micromamba activate jno
micromamba install -n jno -c conda-forge gmsh python-gmsh -y
pip install -e .

Minimal DeepONet Example

Create a Python file with the following content:

import jno
import jax
import optax
import foundax

dir = jno.setup("./runs/test")

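# Domain: 500 instances of the rectangle [0, 2] x [0, 1], one per sample of k
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
# Boundary points (not used below, since the BCs are enforced exactly via the output transformation)
xb, yb, _ = dom.variable("boundary")

# One diffusion coefficient k ~ U(0.5, 1.5) per domain instance
random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", random_k)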

# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))

# Forward pass and hard enforcement of BCs via output transformation
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0  # PDE residual of k * (u_xx + u_yy) + 1 = 0

# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callback.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])

# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{dir}/training.png")
jno.save(crux, f"{dir}/model.pkl")

# Inference on a test domain with a finer mesh; k is drawn from [0.1, 1.9], wider than the training range [0.5, 1.5]
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))

pred, x, y, k = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x.shape, y.shape, k.shape)

Then run it with:

CUDA_VISIBLE_DEVICES=<gpu_id> JNO_SEED=<seed> python <filename>.py
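Two ingredients of the script can be sanity-checked in plain Python, independently of jNO. First, the factor `x * (2 - x) * y * (1 - y)` vanishes on every edge of the rectangle [0, 2] x [0, 1], so the network output is forced to zero there (homogeneous Dirichlet conditions) while remaining free in the interior. Second, `optax.schedules.cosine_decay_schedule` treats `alpha` as the final multiplier of `init_value`, so with `init_value=1e-3` and `alpha=1e-5` the learning rate ends at 1e-8, not 1e-5. The sketch below re-implements both by hand; `boundary_factor` and `cosine_decay` are hypothetical helpers, and `cosine_decay` mirrors the formula documented for optax rather than calling optax itself.

```python
import math


def boundary_factor(x: float, y: float) -> float:
    """Multiplier used in the example to enforce u = 0 on the boundary of [0, 2] x [0, 1]."""
    return x * (2 - x) * y * (1 - y)


def cosine_decay(step: int, init_value: float, decay_steps: int, alpha: float = 0.0) -> float:
    """Cosine decay from init_value to alpha * init_value over decay_steps, constant afterwards."""
    step = min(step, decay_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
    return init_value * ((1 - alpha) * cosine + alpha)


if __name__ == "__main__":
    # The factor is zero at a point on each of the four edges ...
    edges = [(0.0, 0.5), (2.0, 0.5), (1.0, 0.0), (1.0, 1.0)]
    print([boundary_factor(x, y) for x, y in edges])
    # ... and positive inside the domain, e.g. at its centre
    print(boundary_factor(1.0, 0.5))
    # Learning rate at the start, midway and end of decay_steps=20_000 optimizer updates
    for step in (0, 10_000, 20_000):
        print(step, cosine_decay(step, init_value=1e-3, decay_steps=20_000, alpha=1e-5))
```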

Foundation Models and other neural networks

These models are maintained in a separate repository (foundax) so that they can also be used independently of jNO. Install it with:

pip install foundax
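The example builds its network with `foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, ...)`. Conceptually, a DeepONet has a branch network that maps the sensor values of the input function (here the single value of k) to p coefficients, and a trunk network that maps a query coordinate (x, y) to p basis values; the prediction is their dot product, with p presumably corresponding to `basis_functions`. The dependency-free sketch below illustrates only that structure. It is not foundax's implementation: the layer layout, initialisation, lack of an output bias, and all names (`init_mlp`, `mlp`, `deeponet`) are assumptions made for illustration.

```python
import math
import random


def init_mlp(sizes, rng):
    """Random (W, b) pairs for a fully connected network with the given layer sizes."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        scale = 1.0 / math.sqrt(n_in)
        W = [[rng.gauss(0.0, scale) for _ in range(n_in)] for _ in range(n_out)]
        b = [0.0] * n_out
        params.append((W, b))
    return params


def mlp(params, v):
    """Apply the network: tanh on hidden layers, linear output layer."""
    for i, (W, b) in enumerate(params):
        v = [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]
        if i < len(params) - 1:
            v = [math.tanh(z) for z in v]
    return v


def deeponet(branch, trunk, sensors, coords):
    """G(k)(x, y) approximated as sum_i branch_i(k) * trunk_i(x, y)."""
    coeffs = mlp(branch, sensors)  # p coefficients from the input function
    basis = mlp(trunk, coords)     # p basis values at the query point
    return sum(c * t for c, t in zip(coeffs, basis))


if __name__ == "__main__":
    rng = random.Random(0)
    n_sensors, coord_dim, p, hidden = 1, 2, 32, 128  # mirrors the example's arguments
    branch = init_mlp([n_sensors, hidden, p], rng)
    trunk = init_mlp([coord_dim, hidden, p], rng)
    # Untrained prediction for k = 1.0 at the point (x, y) = (1.0, 0.5)
    print(deeponet(branch, trunk, [1.0], [1.0, 0.5]))
```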

Citation

If you use jNO, we would appreciate it if you cited the following paper:

@article{armbruster2026jNO,
  author  = {Armbruster, Leon, ....},
  title   = {{jNO}: A JAX Library for Neural Operator and PDE Foundation Model Training},
  journal = {},
  year    = {},
}


Download files

Download the file for your platform.

Source Distribution

jax_neural_operators-0.1.1.tar.gz (260.2 kB)

Uploaded Source

Built Distribution

jax_neural_operators-0.1.1-py3-none-any.whl (222.4 kB)

Uploaded Python 3

File details

Details for the file jax_neural_operators-0.1.1.tar.gz.

File metadata

  • Download URL: jax_neural_operators-0.1.1.tar.gz
  • Size: 260.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jax_neural_operators-0.1.1.tar.gz
Algorithm Hash digest
SHA256 ba0c59b84d0abcb1cd7d2db85ae34ff47b187160be9c50f7c5b09c0ddb762958
MD5 a4291fa26c53083782b0f212e7b63b32
BLAKE2b-256 dded1da6e416f173fa835714beb98d492a7eacec2f6cd37447a28d2535728944

File details

Details for the file jax_neural_operators-0.1.1-py3-none-any.whl.

File hashes

Hashes for jax_neural_operators-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 1719a5a5f2c79ea7ea87aa57c78828f2130aba5b97a9369316fc02aa01b22169
MD5 e1aef37b253b6144d90527fc2ac091c8
BLAKE2b-256 dd39718ec170c3f2378050f343c6e13abc1aada6a5c2fc3beb965d6962d98d7f
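The SHA256 digests listed above can be used to check a downloaded file before installing it (`pip hash <file>` reports the same kind of digest). A minimal sketch using Python's standard `hashlib`; the helpers `sha256_of` and `verify` are hypothetical, and the usage assumes the files were downloaded into the current directory.

```python
import hashlib
from pathlib import Path

# SHA256 digests as published above for release 0.1.1
EXPECTED_SHA256 = {
    "jax_neural_operators-0.1.1.tar.gz": "ba0c59b84d0abcb1cd7d2db85ae34ff47b187160be9c50f7c5b09c0ddb762958",
    "jax_neural_operators-0.1.1-py3-none-any.whl": "1719a5a5f2c79ea7ea87aa57c78828f2130aba5b97a9369316fc02aa01b22169",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path) -> bool:
    """True if the file's digest matches the published one for its file name."""
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected


if __name__ == "__main__":
    for name in EXPECTED_SHA256:
        p = Path(name)
        if p.exists():
            print(name, "OK" if verify(p) else "MISMATCH")
        else:
            print(name, "not found in current directory")
```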

