Jax Neural Operators

Project description

Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.

Install

Quick install from PyPI:

pip install jax-neural-operators

Foundation models and other neural operators are maintained in a separate repository, foundax, so that they can also be used independently. foundax is installed automatically as a dependency of this package.
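To confirm that both packages ended up in the active environment, a small standard-library check can help. This is a minimal sketch: the helper name installed_versions is hypothetical, and it assumes that foundax is published under the distribution name "foundax" (the name "jax-neural-operators" comes from the install command above).

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# "foundax" as a distribution name is an assumption, not confirmed by the README.
report = installed_versions(["jax-neural-operators", "foundax"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

A value of None for either entry suggests the install did not complete in the environment that is running the script.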

Example

import jno
import jax
import jax.numpy as jnp
import optax
import foundax
import equinox as eqx
from jno import LearningRateSchedule as lrs
from jno.numpy import tracker

dir = jno.setup("./runs/poisson2d")

# ── Domain — rect with named boundary sides ────────────────────────────────────
dom        = jno.domain(constructor=jno.domain.rect(mesh_size=0.05, x_range=(0, 1), y_range=(0, 1)))
x,  y,  _  = dom.variable("interior")
xl, yl, _  = dom.variable("left")    # x = 0  →  soft Dirichlet

# ── Network — LoRA adapters on hidden layers, 10× LR on output layer ───────────
net = jno.nn.wrap(
    foundax.mlp(in_features=2, hidden_dims=64, num_layers=4,
                activation=jnp.tanh, key=jax.random.PRNGKey(0))
)
net.lora(rank=4, alpha=1.0)                                            # parameter-efficient training
net.optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 5_000, 1e-5))

all_false = jax.tree_util.tree_map(lambda _: False, net.module)
out_mask  = eqx.tree_at(lambda m: m.output_layer.weight, all_false, True)
net.mask(out_mask).lr(lrs.exponential(1e-2, 0.5, 5_000, 1e-4))       # output layer at 10× LR

# ── Forward pass — hard BCs on right (x=1), bottom (y=0), top (y=1) ───────────
π  = jno.np.pi
u  = net(jno.np.concat([x,  y ], axis=-1)) * (1 - x)  * y  * (1 - y)
ul = net(jno.np.concat([xl, yl], axis=-1)) * (1 - xl) * yl * (1 - yl)

# ── PDE: −∇²u = 2π²sin(πx)sin(πy),  exact u* = sin(πx)sin(πy) ───────────────
pde     = -(u.dd(x) + u.dd(y)) - 2 * π**2 * jno.np.sin(π * x) * jno.np.sin(π * y)
bc_left = ul                                          # soft: u(0, y) = 0

# ── Integral: ∫u dΩ → 4/π² ≈ 0.405 for the exact solution ────────────────────
vol_tracker = tracker(u.integrate(), interval=500)

# ── Gradient alignment — PDE vs. left-BC loss, output-layer params only ───────
J_pde      = pde.mse.grad(net.mask(out_mask))         # ∂L_pde / ∂θ_out
J_bc       = bc_left.mse.grad(net.mask(out_mask))     # ∂L_bc  / ∂θ_out
grad_align = tracker(jno.np.dot(J_pde, J_bc), interval=500)

# ── Solve ──────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc_left.mse, vol_tracker, grad_align], domain=dom)
crux.solve(20_000).plot(f"{dir}/training.png")
jno.save(crux, f"{dir}/model.pkl")
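The comments in the example rest on three analytic facts: u* = sin(πx)sin(πy) satisfies −∇²u = 2π²sin(πx)sin(πy), the factor (1 − x)·y·(1 − y) forces the output to zero on the right, bottom and top sides, and ∫u* dΩ = 4/π². The following is a minimal sketch in plain JAX that checks them numerically. It does not use jno or foundax; the names u_exact, source, residual and hard_bc, the stand-in function raw, and the midpoint quadrature are assumptions introduced for illustration only.

```python
import jax
import jax.numpy as jnp

# Double precision keeps the second-derivative check tight.
jax.config.update("jax_enable_x64", True)


def u_exact(xy):
    """Manufactured solution u*(x, y) = sin(pi x) sin(pi y)."""
    return jnp.sin(jnp.pi * xy[0]) * jnp.sin(jnp.pi * xy[1])


def source(xy):
    """Right-hand side f(x, y) = 2 pi^2 sin(pi x) sin(pi y)."""
    return 2.0 * jnp.pi**2 * jnp.sin(jnp.pi * xy[0]) * jnp.sin(jnp.pi * xy[1])


def residual(fn, xy):
    """Strong-form residual -(u_xx + u_yy) - f at a single point."""
    laplacian = jnp.trace(jax.hessian(fn)(xy))
    return -laplacian - source(xy)


def hard_bc(raw_fn):
    """Multiply by (1 - x) y (1 - y) so the result vanishes at x=1, y=0, y=1."""
    def wrapped(xy):
        return raw_fn(xy) * (1.0 - xy[0]) * xy[1] * (1.0 - xy[1])
    return wrapped


# 1) PDE residual of the exact solution at random points in the unit square.
pts = jax.random.uniform(jax.random.PRNGKey(0), (256, 2))
max_residual = float(jnp.max(jnp.abs(jax.vmap(lambda p: residual(u_exact, p))(pts))))

# 2) Hard boundary conditions for an arbitrary stand-in for the network output.
raw = lambda xy: jnp.cos(3.0 * xy[0]) + xy[1] ** 2 + 1.0
u_hat = jax.vmap(hard_bc(raw))
s = jnp.linspace(0.0, 1.0, 11)
right = u_hat(jnp.stack([jnp.ones_like(s), s], axis=-1))    # x = 1
bottom = u_hat(jnp.stack([s, jnp.zeros_like(s)], axis=-1))  # y = 0
top = u_hat(jnp.stack([s, jnp.ones_like(s)], axis=-1))      # y = 1
left = u_hat(jnp.stack([jnp.zeros_like(s), s], axis=-1))    # x = 0, not constrained
max_bc = float(jnp.max(jnp.abs(jnp.concatenate([right, bottom, top]))))
left_max = float(jnp.max(jnp.abs(left)))

# 3) Integral of u* over the unit square with a midpoint rule on a 200 x 200 grid.
n = 200
mid = (jnp.arange(n) + 0.5) / n
X, Y = jnp.meshgrid(mid, mid, indexing="ij")
integral = float(jnp.mean(jnp.sin(jnp.pi * X) * jnp.sin(jnp.pi * Y)))

print("max |residual| of u*:", max_residual)
print("max |u| on hard-BC sides:", max_bc, "| max |u| on left side:", left_max)
print("integral:", integral, "reference 4/pi^2:", 4.0 / jnp.pi**2)
```

The left side stays nonzero for a generic raw function, which is why the example adds a separate soft loss term for it.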

Download files

Source Distribution

jax_neural_operators-0.2.2.tar.gz (355.8 kB)

Uploaded Source

Built Distribution

jax_neural_operators-0.2.2-py3-none-any.whl (290.3 kB)

Uploaded Python 3

File details

Details for the file jax_neural_operators-0.2.2.tar.gz.

File metadata

  • Download URL: jax_neural_operators-0.2.2.tar.gz
  • Size: 355.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jax_neural_operators-0.2.2.tar.gz
Algorithm     Hash digest
SHA256        8e7610861de4276a1662806e36bc7491f6af1af7e8e56bfab1a984127e6ca0b3
MD5           ae176647d8e089750e16746cf3f52632
BLAKE2b-256   43b32eeb4b104ad3ae99b3f18dcbc8e69dadb41e95f6b49a7064da7dc8620e9e

File details

Details for the file jax_neural_operators-0.2.2-py3-none-any.whl.

File hashes

Hashes for jax_neural_operators-0.2.2-py3-none-any.whl
Algorithm     Hash digest
SHA256        cf09decfdfbbafc28a8e442f19cbeddc5695cc93f31a00711adf94bf222010c1
MD5           c37eea8552e590c3261a077ba17c8a1a
BLAKE2b-256   082d0ca85197e0e1ddf2166ddeb005d31b7e5bb888f53ac374670e52cfa34037
