Jax Neural Operators
Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.
Install
Quick install from PyPI:
pip install jax-neural-operators
If an NVIDIA GPU is available, additionally install the CUDA-enabled JAX build:
pip install "jax[cuda]"
For local development (recommended on Linux aarch64 when gmsh wheels are unavailable on PyPI), use micromamba:
micromamba create -n jno python=3.12 pip -y
micromamba activate jno
micromamba install -n jno -c conda-forge gmsh python-gmsh -y
pip install -e .
Minimal DeepONet Example
Create the following file:
import jno
import jax
import optax
import foundax
dir = jno.setup("./runs/test")
# Domain
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
xb, yb, _ = dom.variable("boundary")
random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", random_k)
# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))
# Forward pass and hard enforcement of BCs via output transformation
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0 # PDE Loss
# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callback.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])
# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{dir}/training.png")
jno.save(crux, f"{dir}/model.pkl")
# Inference via test domain on a finer mesh
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))
pred, x, y, k = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x.shape, y.shape, k.shape)
and then run it with:
CUDA_VISIBLE_DEVICES=<gpu_id> JNO_SEED=<seed> python <filename>.py
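In the example above, multiplying the network output by x·(2−x)·y·(1−y) enforces homogeneous Dirichlet conditions on the rectangle [0, 2] × [0, 1] by construction, since the multiplier vanishes on every edge. A quick pure-Python check of that multiplier (independent of jNO, for illustration only):

```python
def bc_factor(x, y):
    """Multiplier used in the example: zero on the boundary of [0, 2] x [0, 1]."""
    return x * (2 - x) * y * (1 - y)

# Vanishes on all four edges of the rectangle...
assert bc_factor(0.0, 0.3) == 0.0 and bc_factor(2.0, 0.3) == 0.0
assert bc_factor(1.2, 0.0) == 0.0 and bc_factor(1.2, 1.0) == 0.0
# ...and is strictly positive inside, so it does not kill the network output there.
assert bc_factor(1.0, 0.5) > 0.0
print("hard BC multiplier checks passed")
```

Because the boundary values are exact by construction, no boundary-loss term is needed in the constraints list.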
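The optax schedule in the example decays the learning rate from init_value down to roughly init_value * alpha over decay_steps updates. A minimal plain-Python sketch of the cosine-decay formula behind optax.schedules.cosine_decay_schedule, for intuition:

```python
import math

def cosine_decay(step, init_value=1e-3, decay_steps=20_000, alpha=1e-5):
    """Learning rate after `step` updates under cosine decay."""
    frac = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
    return init_value * ((1.0 - alpha) * cosine + alpha)

print(cosine_decay(0))       # starts at init_value
print(cosine_decay(20_000))  # ends at init_value * alpha
```

With the example's settings this anneals from 1e-3 to about 1e-8 over the 20,000 training epochs.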
Foundation Models and other neural networks
These models are maintained in a separate repository (foundax) so that they can also be used independently:
pip install foundax
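For context, a DeepONet maps a function sample (observed through sensor values) and a query coordinate to an output via the dot product of a branch network, which encodes the sensors, and a trunk network, which encodes the coordinate. A minimal NumPy sketch of that structure is below; the sizes mirror the example's n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, but the random weights are placeholders, not foundax's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp_params(sizes):
    """Random weights for a small MLP; stand-ins for trained parameters."""
    return [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp(params, x):
    for w, b in params[:-1]:
        x = np.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

# Branch encodes the input function through its sensor values;
# trunk encodes the query coordinate.
branch = mlp_params([1, 128, 32])   # n_sensors -> hidden_dim -> basis_functions
trunk = mlp_params([2, 128, 32])    # coord_dim -> hidden_dim -> basis_functions

def deeponet(sensors, coords):
    """u(coords) ~ sum_k branch_k(sensors) * trunk_k(coords)."""
    return np.sum(mlp(branch, sensors) * mlp(trunk, coords), axis=-1)

u = deeponet(np.ones((4, 1)), rng.uniform(size=(4, 2)))
print(u.shape)  # one scalar prediction per (sensors, coordinate) pair -> (4,)
```

The basis-function view explains the name: the trunk learns 32 coordinate-dependent basis functions, and the branch learns their input-dependent coefficients.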
Citation
If you use jNO, we would appreciate a citation of the following paper:
@article{armbruster2026jNO,
  author  = {Armbruster, Leon, ....},
  title   = {{jNO}: A JAX Library for Neural Operator and PDE Foundation Model Training},
  journal = {},
  year    = {},
}