Jax Neural Operators
Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.
Install
Quick install from PyPI:
pip install jax-neural-operators
Foundation models and other neural operators are maintained in a separate repository, foundax, so that they can also be used on their own. foundax is installed automatically as a dependency of this package.
Example
```python
import jno
import jax
import jax.numpy as jnp
import optax
import foundax
import equinox as eqx
from jno import LearningRateSchedule as lrs
from jno.numpy import tracker

run_dir = jno.setup("./runs/poisson2d")  # avoid shadowing the built-in dir()

# ── Domain: rect with named boundary sides ─────────────────────────────────────
dom = jno.domain(constructor=jno.domain.rect(mesh_size=0.05, x_range=(0, 1), y_range=(0, 1)))
x, y, _ = dom.variable("interior")
xl, yl, _ = dom.variable("left")  # x = 0 → soft Dirichlet

# ── Network: LoRA adapters on hidden layers, 10× LR on output layer ────────────
net = jno.nn.wrap(
    foundax.mlp(in_features=2, hidden_dims=64, num_layers=4,
                activation=jnp.tanh, key=jax.random.PRNGKey(0))
)
net.lora(rank=4, alpha=1.0)  # parameter-efficient training
net.optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 5_000, 1e-5))
all_false = jax.tree_util.tree_map(lambda _: False, net.module)
out_mask = eqx.tree_at(lambda m: m.output_layer.weight, all_false, True)
net.mask(out_mask).lr(lrs.exponential(1e-2, 0.5, 5_000, 1e-4))  # output layer at 10× LR

# ── Forward pass: hard BCs on right (x=1), bottom (y=0), top (y=1) ─────────────
π = jno.np.pi
u = net(jno.np.concat([x, y], axis=-1)) * (1 - x) * y * (1 - y)
ul = net(jno.np.concat([xl, yl], axis=-1)) * (1 - xl) * yl * (1 - yl)

# ── PDE: −∇²u = 2π²sin(πx)sin(πy), exact u* = sin(πx)sin(πy) ──────────────────
pde = -(u.dd(x) + u.dd(y)) - 2 * π**2 * jno.np.sin(π * x) * jno.np.sin(π * y)
bc_left = ul  # soft: u(0, y) = 0

# ── Integral: ∫u dΩ → 4/π² ≈ 0.405 for the exact solution ──────────────────────
vol_tracker = tracker(u.integrate(), interval=500)

# ── Gradient alignment: PDE vs. left-BC loss, output-layer params only ─────────
J_pde = pde.mse.grad(net.mask(out_mask))     # ∂L_pde / ∂θ_out
J_bc = bc_left.mse.grad(net.mask(out_mask))  # ∂L_bc / ∂θ_out
grad_align = tracker(jno.np.dot(J_pde, J_bc), interval=500)

# ── Solve ──────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc_left.mse, vol_tracker, grad_align], domain=dom)
crux.solve(20_000).plot(f"{run_dir}/training.png")
jno.save(crux, f"{run_dir}/model.pkl")
```
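The call `net.lora(rank=4, alpha=1.0)` in the example attaches low-rank adapters to the network. The source does not describe how jno implements them, so the following is only a minimal sketch of the common LoRA formulation in plain JAX, under these assumptions: a frozen weight `W` is augmented by a trainable product `A @ B` of rank `rank`, scaled by `alpha / rank`, and `B` starts at zero so the adapted layer initially reproduces the frozen one. The names `init_lora` and `lora_dense` are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_lora(key, in_dim, out_dim, rank):
    """Adapter factors: A small random (in_dim x rank), B zeros (rank x out_dim)."""
    a = 0.01 * jax.random.normal(key, (in_dim, rank))
    b = jnp.zeros((rank, out_dim))
    return {"a": a, "b": b}


def lora_dense(w, bias, lora, x, alpha, rank):
    """Frozen dense layer plus low-rank update: x @ (W + (alpha / rank) * A @ B) + bias.

    Only `lora` is meant to be trained; `w` and `bias` stay fixed.
    """
    delta = (alpha / rank) * (lora["a"] @ lora["b"])
    return x @ (w + delta) + bias


key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
in_dim = out_dim = 64  # hidden width used in the example
rank, alpha = 4, 1.0   # values passed to net.lora(...) in the example

w = jax.random.normal(key_w, (in_dim, out_dim)) / jnp.sqrt(in_dim)
bias = jnp.zeros(out_dim)
lora = init_lora(key_a, in_dim, out_dim, rank)
x = jax.random.normal(key_x, (8, in_dim))


def loss(lora_params):
    """Toy loss; differentiating w.r.t. lora_params alone keeps w and bias frozen."""
    return jnp.sum(lora_dense(w, bias, lora_params, x, alpha, rank) ** 2)


grads = jax.grad(loss)(lora)
n_trainable = sum(p.size for p in jax.tree_util.tree_leaves(lora))
print("trainable adapter parameters:", n_trainable, "vs full layer:", w.size)  # 512 vs 4096
```

Because `B` is zero at initialisation, only `B` receives a nonzero gradient on the first step; `A` starts moving once `B` is no longer zero. For a 64×64 layer at rank 4 the adapter holds 512 trainable numbers instead of 4,096.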
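The comments in the example make three checkable claims: the factor `(1 - x) * y * (1 - y)` forces u = 0 on the right (x = 1), bottom (y = 0), and top (y = 1) sides; u* = sin(πx)sin(πy) satisfies −∇²u = 2π²sin(πx)sin(πy); and ∫u* dΩ = (2/π)² = 4/π² ≈ 0.405. The sketch below checks all three in plain JAX, independently of jno, using nested `jax.grad` for the Laplacian and a midpoint rule for the integral. The helper names are illustrative and not part of the jno API.

```python
import jax
import jax.numpy as jnp

PI = jnp.pi


def u_exact(x, y):
    """Manufactured solution u*(x, y) = sin(pi x) sin(pi y)."""
    return jnp.sin(PI * x) * jnp.sin(PI * y)


def source(x, y):
    """Right-hand side f(x, y) = 2 pi^2 sin(pi x) sin(pi y)."""
    return 2.0 * PI**2 * jnp.sin(PI * x) * jnp.sin(PI * y)


def laplacian(fn):
    """Return a function (x, y) -> u_xx + u_yy built from nested jax.grad."""
    u_xx = jax.grad(jax.grad(fn, argnums=0), argnums=0)
    u_yy = jax.grad(jax.grad(fn, argnums=1), argnums=1)
    return lambda x, y: u_xx(x, y) + u_yy(x, y)


def pde_residual(x, y):
    """Pointwise residual -lap(u*) - f; zero wherever u* solves the PDE."""
    return -laplacian(u_exact)(x, y) - source(x, y)


def hard_bc_ansatz(raw, x, y):
    """Scale a raw network output so it vanishes on x = 1, y = 0 and y = 1.

    The left side x = 0 is left unconstrained; the example handles it with a soft loss.
    """
    return raw * (1.0 - x) * y * (1.0 - y)


def midpoint_integral(fn, n=200):
    """Midpoint-rule approximation of the integral of fn over [0, 1]^2."""
    h = 1.0 / n
    centers = (jnp.arange(n) + 0.5) * h
    xs, ys = jnp.meshgrid(centers, centers, indexing="ij")
    return jnp.sum(fn(xs, ys)) * h * h


pts = jax.random.uniform(jax.random.PRNGKey(0), (100, 2))
residual = jax.vmap(pde_residual)(pts[:, 0], pts[:, 1])
print("max |residual|:", float(jnp.max(jnp.abs(residual))))
print("integral:", float(midpoint_integral(u_exact)), "vs 4/pi^2 =", 4 / PI**2)
```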
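The `grad_align` tracker in the example records the inner product of the PDE-loss and left-BC-loss gradients with respect to the output-layer parameters selected by `out_mask`. To first order, a positive value means a gradient step on one loss also lowers the other, while a negative value means the two objectives pull those parameters in conflicting directions. The following is a plain-JAX sketch of the same quantity, assuming a small tanh MLP stored as a list of `{"w", "b"}` dicts in place of the foundax/equinox module, and a mask that keeps only the output layer's weight. It is an independent reimplementation of the idea, not jno's internal code; `init_mlp`, `gradient_alignment`, and the loss helpers are hypothetical names.

```python
import jax
import jax.numpy as jnp

PI = jnp.pi


def init_mlp(key, sizes):
    """Hypothetical helper: tanh MLP parameters as a list of {"w", "b"} dicts."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        {"w": jax.random.normal(k, (m, n)) / jnp.sqrt(m), "b": jnp.zeros(n)}
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, xy):
    """Scalar network output for a single point xy of shape (2,)."""
    h = xy
    for layer in params[:-1]:
        h = jnp.tanh(h @ layer["w"] + layer["b"])
    return (h @ params[-1]["w"] + params[-1]["b"])[0]


def u(params, xy):
    """Hard-BC ansatz from the example: zero on x = 1, y = 0 and y = 1."""
    x, y = xy[0], xy[1]
    return mlp(params, xy) * (1 - x) * y * (1 - y)


def pde_loss(params, pts):
    """Mean squared residual of -lap(u) = 2 pi^2 sin(pi x) sin(pi y)."""
    def residual(xy):
        lap = jnp.trace(jax.hessian(lambda z: u(params, z))(xy))
        f = 2 * PI**2 * jnp.sin(PI * xy[0]) * jnp.sin(PI * xy[1])
        return -lap - f
    return jnp.mean(jax.vmap(residual)(pts) ** 2)


def bc_loss(params, pts):
    """Soft Dirichlet loss for u(0, y) = 0 on the left side."""
    return jnp.mean(jax.vmap(lambda xy: u(params, xy))(pts) ** 2)


def gradient_alignment(loss_a, loss_b, params, mask):
    """Inner product <grad loss_a, grad loss_b> over the leaves where mask is True."""
    g_a = jax.grad(loss_a)(params)
    g_b = jax.grad(loss_b)(params)
    terms = jax.tree_util.tree_map(
        lambda ga, gb, keep: jnp.vdot(ga, gb) if keep else 0.0, g_a, g_b, mask
    )
    return sum(jax.tree_util.tree_leaves(terms))


key_net, key_pts = jax.random.split(jax.random.PRNGKey(0))
params = init_mlp(key_net, [2, 16, 16, 1])
interior = jax.random.uniform(key_pts, (128, 2))
left = jnp.stack([jnp.zeros(32), jnp.linspace(0.0, 1.0, 32)], axis=-1)

# Mask selecting only the output layer's weight, mirroring out_mask in the example.
out_mask = jax.tree_util.tree_map(lambda _: False, params)
out_mask[-1]["w"] = True

align = gradient_alignment(
    lambda p: pde_loss(p, interior), lambda p: bc_loss(p, left), params, out_mask
)
print("gradient alignment (output-layer weight):", float(align))
```

Dividing `align` by the product of the two masked gradient norms would give a cosine similarity instead of a raw dot product; the example above tracks the raw dot product.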
Download files
File details
Details for the file jax_neural_operators-0.2.2.tar.gz.
File metadata
- Download URL: jax_neural_operators-0.2.2.tar.gz
- Upload date:
- Size: 355.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8e7610861de4276a1662806e36bc7491f6af1af7e8e56bfab1a984127e6ca0b3 |
| MD5 | ae176647d8e089750e16746cf3f52632 |
| BLAKE2b-256 | 43b32eeb4b104ad3ae99b3f18dcbc8e69dadb41e95f6b49a7064da7dc8620e9e |
File details
Details for the file jax_neural_operators-0.2.2-py3-none-any.whl.
File metadata
- Download URL: jax_neural_operators-0.2.2-py3-none-any.whl
- Upload date:
- Size: 290.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cf09decfdfbbafc28a8e442f19cbeddc5695cc93f31a00711adf94bf222010c1 |
| MD5 | c37eea8552e590c3261a077ba17c8a1a |
| BLAKE2b-256 | 082d0ca85197e0e1ddf2166ddeb005d31b7e5bb888f53ac374670e52cfa34037 |
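The SHA256 digests listed for the two 0.2.2 files can be used to check a downloaded distribution before installing it. Below is a minimal sketch with Python's standard `hashlib` that reads the file in chunks so memory use stays bounded. `EXPECTED_SHA256` copies the two SHA256 values from this page, and `sha256_hex` and `verify` are hypothetical helper names.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the "File hashes" tables on this page (version 0.2.2).
EXPECTED_SHA256 = {
    "jax_neural_operators-0.2.2.tar.gz":
        "8e7610861de4276a1662806e36bc7491f6af1af7e8e56bfab1a984127e6ca0b3",
    "jax_neural_operators-0.2.2-py3-none-any.whl":
        "cf09decfdfbbafc28a8e442f19cbeddc5695cc93f31a00711adf94bf222010c1",
}


def sha256_hex(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in 1 MiB chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """Hypothetical helper: True if the file's SHA256 matches the listed digest."""
    name = Path(path).name
    expected = EXPECTED_SHA256.get(name)
    if expected is None:
        raise KeyError(f"no published digest for {name}")
    return sha256_hex(path) == expected
```

For example, `verify("jax_neural_operators-0.2.2-py3-none-any.whl")` returns `True` only for an unmodified wheel. pip can perform the same check at install time in hash-checking mode (`--require-hashes`, with the requirement pinned as `jax-neural-operators==0.2.2 --hash=sha256:<digest>` in a requirements file); in that mode every dependency, including foundax, must be pinned with a hash as well.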