JAX/Flax/Optax optimizer manager
OptTx
Research Code: Co-developed with Claude Code, Gemini CLI, Codex CLI, and Cursor. No guarantees provided. Use at your own risk.
JAX/Flax/Optax optimizer library for physics-informed neural networks (PINNs) and second-order methods.
Features
- Multi-term objectives: `Objective` with `TermSpec` for PINNs (PDE, BC, IC terms)
- First-order optimizers: Adam, SGD, AdamW, SOAP, MUON, Shampoo, L-BFGS
- Second-order optimizers: `CGOptimizer` (Fisher/GGN), `CROptimizer` (Hessian)
- Acceleration methods: TGS, NLTGCR, Anderson Acceleration (AA)
- Graph neural networks: GCN and GAT layers for node classification
- Matrix-free curvature: `build_hessian_matvec`, `build_fisher_matvec`
- JIT-stable: works with `jax.jit` and `jax.lax.scan`
Install
```bash
pip install opttx
```
For development (quote the extras so shells such as zsh do not expand the brackets):
```bash
pip install -e ".[dev]"
```
Quickstart
First-order optimizer
```python
import jax
import jax.numpy as jnp
from flax import linen as nn

from opttx import Adam, Objective, TermSpec, TrainState


# Define model
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(32)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


# Define loss
def mse_loss(pred, batch):
    _, y = batch  # inputs are unused here; only the targets are needed
    return jnp.mean((pred - y) ** 2)


# Create objective
term = TermSpec(name="mse", batch_key="data", loss_fn=mse_loss)
objective = Objective(terms=[term])

# Initialize
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))["params"]
state = TrainState(
    step=jnp.array(0),
    params=params,
    opt_state=None,
    apply_fn=lambda v, b: model.apply({"params": v["params"]}, b[0]),
)

# Create optimizer and train
optimizer = Adam(objective, learning_rate=1e-3)
state = optimizer.init(state)
batch = {"data": (jnp.ones((8, 3)), jnp.zeros((8, 1)))}
state, metrics = optimizer.step(state, batch)
print(f"Loss: {metrics['loss']}")
```
Second-order optimizer (CR + Hessian)
```python
from opttx import CROptimizer

# Reuses `objective`, `state`, and `batch` from the first-order example above
optimizer = CROptimizer(
    objective,
    learning_rate=1.0,
    damping=1e-3,
    cr_iters=10,
    curvature_type="hessian",  # or "fisher"
)
state = optimizer.init(state)
state, metrics = optimizer.step(state, batch)
```
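Conceptually, a damped second-order step solves (H + λI) p = −g for a search direction p and then updates θ ← θ + η p. The sketch below illustrates that idea on a tiny quadratic in plain JAX, assuming `learning_rate` plays the role of η and `damping` the role of λ. It is not opttx code: `damped_newton_step` is a hypothetical helper, and it uses a dense `jnp.linalg.solve` where `CROptimizer` would run `cr_iters` matrix-free Conjugate Residual iterations.

```python
import jax
import jax.numpy as jnp

# Strictly convex quadratic: 0.5 * theta^T A theta - b^T theta, minimizer A^{-1} b
A_MAT = jnp.array([[3.0, 1.0], [1.0, 2.0]])
B_VEC = jnp.array([1.0, -1.0])


def loss(theta):
    return 0.5 * theta @ A_MAT @ theta - B_VEC @ theta


def damped_newton_step(loss_fn, theta, learning_rate=1.0, damping=1e-3):
    """Hypothetical helper: one damped Newton step using a dense solve."""
    g = jax.grad(loss_fn)(theta)
    H = jax.hessian(loss_fn)(theta)  # materializes H; fine only for tiny problems
    # Solve (H + damping * I) p = -g for the search direction p
    p = jnp.linalg.solve(H + damping * jnp.eye(theta.shape[0]), -g)
    return theta + learning_rate * p


theta0 = jnp.zeros(2)
# With no damping and learning_rate=1.0, one step lands on A^{-1} b for a quadratic
theta1 = damped_newton_step(loss, theta0, learning_rate=1.0, damping=0.0)
print(theta1)
```

Increasing `damping` shortens the step and biases it toward the gradient direction, which is the usual reason to keep a small positive value when the curvature is poorly conditioned or indefinite.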
Multi-term objective (PINNs)
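```python
# Placeholder losses: a real PINN would build PDE/BC residuals from the
# network output and its derivatives (e.g. via jax.grad or jax.jacfwd).
def pde_loss(pred, batch):
    return jnp.mean(pred ** 2)


def bc_loss(pred, batch):
    return jnp.mean(pred ** 2)


pde_term = TermSpec(name="pde", batch_key="x_pde", loss_fn=pde_loss)
bc_term = TermSpec(name="bc", batch_key="x_bc", loss_fn=bc_loss)
objective = Objective(
    terms=[pde_term, bc_term],
    loss_weights={"pde": 1.0, "bc": 0.1},
)

# Collocation points only (no targets), with 2 input features. The Quickstart
# state expects 3 features and its apply_fn indexes b[0], so reusing it here
# would likely fail; build a model initialized with 2 inputs and a matching
# apply_fn before calling optimizer.step on this batch.
batch = {
    "x_pde": jnp.ones((100, 2)),
    "x_bc": jnp.ones((20, 2)),
}
```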
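The `loss_weights` argument suggests the total objective is a weighted sum of per-term losses, each evaluated on its own batch entry. The plain-JAX sketch below shows that aggregation under this assumption (the README does not spell out how `Objective` combines terms); `model_fn`, `TERMS`, and `weighted_total_loss` are hypothetical names, and a linear model stands in for the MLP.

```python
import jax
import jax.numpy as jnp


def model_fn(params, x):
    # Hypothetical linear model standing in for the MLP
    return x @ params["w"] + params["b"]


def pde_loss(pred, batch):
    return jnp.mean(pred ** 2)


def bc_loss(pred, batch):
    return jnp.mean(pred ** 2)


# (name, batch_key, loss_fn) triples mirroring the TermSpec fields above
TERMS = [("pde", "x_pde", pde_loss), ("bc", "x_bc", bc_loss)]
LOSS_WEIGHTS = {"pde": 1.0, "bc": 0.1}


def weighted_total_loss(params, batch):
    """Return (sum_k w_k * L_k, {name: L_k}) over all terms."""
    per_term = {}
    for name, key, loss_fn in TERMS:
        pred = model_fn(params, batch[key])
        per_term[name] = loss_fn(pred, batch[key])
    total = sum(LOSS_WEIGHTS[name] * value for name, value in per_term.items())
    return total, per_term


params = {"w": jnp.ones((2, 1)), "b": jnp.zeros((1,))}
batch = {"x_pde": jnp.ones((100, 2)), "x_bc": jnp.ones((20, 2))}
total, per_term = weighted_total_loss(params, batch)
# Gradients flow through every term, scaled by its weight
grads = jax.grad(lambda p: weighted_total_loss(p, batch)[0])(params)
```

Returning the per-term values alongside the total makes it easy to log each residual separately, which is usually how PINN loss imbalance is diagnosed.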
API Reference
Optimizers
| Optimizer | Description |
|---|---|
| `Adam` | Adam optimizer |
| `SGD` | SGD with momentum |
| `AdamW` | Adam with decoupled weight decay |
| `SOAP` | Shampoo with Adam in the preconditioner's eigenbasis |
| `MUON` | Momentum with orthogonalized updates |
| `Shampoo` | Shampoo preconditioner |
| `LBFGSOptimizer` | L-BFGS quasi-Newton |
| `CGOptimizer` | Conjugate Gradient (Fisher/GGN) |
| `CROptimizer` | Conjugate Residual (Hessian) |
| `TGSOptimizer` | TGS acceleration |
| `TGSAccelerator` | TGS wrapper for any optimizer |
| `AAAccelerator` | Anderson Acceleration wrapper |
| `NLTGCROptimizer` | Nonlinear truncated generalized conjugate residual (GCR) |
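`AAAccelerator` is described only as an Anderson Acceleration wrapper. For readers new to the technique, here is a minimal, non-jitted sketch of classic type-II Anderson acceleration for a fixed-point map x = g(x). It is not opttx's implementation; `anderson`, the history window `m`, and `tol` are hypothetical names chosen for illustration.

```python
import jax.numpy as jnp


def anderson(g, x0, m=3, iters=20, tol=1e-5):
    """Type-II Anderson acceleration for the fixed-point problem x = g(x).

    m is the history window; iteration stops once ||g(x) - x|| < tol.
    """
    xs, gs = [x0], [g(x0)]
    x = gs[0]  # the first step is a plain fixed-point iteration
    for _ in range(iters):
        gx = g(x)
        xs.append(x)
        gs.append(gx)
        # Keep only the last m + 1 iterates and their images
        xs, gs = xs[-(m + 1):], gs[-(m + 1):]
        # Residuals f_i = g(x_i) - x_i, stacked as columns
        F = jnp.stack([gi - xi for gi, xi in zip(gs, xs)], axis=1)
        if jnp.linalg.norm(F[:, -1]) < tol:
            return gx
        dF = F[:, 1:] - F[:, :-1]
        G = jnp.stack(gs, axis=1)
        dG = G[:, 1:] - G[:, :-1]
        # gamma minimizes ||f_k - dF @ gamma||_2
        gamma, *_ = jnp.linalg.lstsq(dF, F[:, -1])
        x = gx - dG @ gamma
    return x


# Linear contraction g(x) = M x + c with fixed point (I - M)^{-1} c
M = jnp.array([[0.5, 0.2], [0.1, 0.4]])
c = jnp.array([1.0, 2.0])
x_aa = anderson(lambda x: M @ x + c, jnp.zeros(2), m=3, iters=20)
```

When wrapping an optimizer, g would typically be one update of the base method, for example a gradient step g(x) = x − η∇f(x). A JIT-friendly variant would replace the Python lists with fixed-size ring buffers so that shapes stay static.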
Curvature
| Function | Description |
|---|---|
| `build_hessian_matvec` | Matrix-free Hessian-vector product |
| `build_fisher_matvec` | Matrix-free Fisher/GGN-vector product |
| `build_damped_matvec` | Add damping: (H + λI)v |
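The three curvature products above can be built in plain JAX without ever materializing a matrix. The sketch below uses the standard constructions: forward-over-reverse (`jax.jvp` of `jax.grad`) for Hv, and a `jvp`/`vjp` sandwich for the generalized Gauss-Newton product JᵀH_out J v. The `make_*` names are hypothetical; opttx's `build_*` functions presumably work on parameter pytrees and its own `Objective`, and their signatures are not documented here.

```python
import jax
import jax.numpy as jnp


def make_hessian_matvec(loss_fn, params):
    """Return v -> H v for the Hessian H of loss_fn at params (forward-over-reverse)."""
    grad_fn = jax.grad(loss_fn)

    def matvec(v):
        return jax.jvp(grad_fn, (params,), (v,))[1]

    return matvec


def make_ggn_matvec(model_fn, out_loss_fn, params):
    """Return v -> J^T H_out J v, where J is the model Jacobian w.r.t. params
    and H_out is the Hessian of the loss w.r.t. the model output."""
    out, vjp_fn = jax.vjp(model_fn, params)
    out_grad_fn = jax.grad(out_loss_fn)

    def matvec(v):
        _, jv = jax.jvp(model_fn, (params,), (v,))    # J v
        _, hjv = jax.jvp(out_grad_fn, (out,), (jv,))  # H_out J v
        return vjp_fn(hjv)[0]                         # J^T H_out J v

    return matvec


def make_damped_matvec(matvec, damping):
    """Return v -> (A + damping * I) v for any matvec v -> A v."""
    return lambda v: matvec(v) + damping * v


# Linear least squares: for a linear model the GGN equals the Hessian, X^T X
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
Y = jnp.array([1.0, 0.0, 1.0])
w0 = jnp.array([0.1, -0.2])


def model_fn(w):
    return X @ w


def out_loss_fn(z):
    return 0.5 * jnp.sum((z - Y) ** 2)


def total_loss(w):
    return out_loss_fn(model_fn(w))


hvp = make_hessian_matvec(total_loss, w0)
ggn_mv = make_ggn_matvec(model_fn, out_loss_fn, w0)
damped = make_damped_matvec(hvp, 0.1)
v = jnp.array([1.0, 0.5])
```

For nonlinear models the two products differ: the GGN drops the second-derivative terms of the network, which keeps it positive semi-definite whenever the output loss is convex. For squared-error (Gaussian likelihood) losses, the Fisher coincides with the GGN.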
Solvers
| Function | Description |
|---|---|
| `cg_solve` | Conjugate Gradient solver |
| `cr_solve` | Conjugate Residual solver |
| `tgs_solve_fori` | TGS solver (JIT-compatible) |
| `nltgcr_solve_fori` | NLTGCR solver (JIT-compatible) |
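The solver signatures are not documented here. As an illustration of what fixed-iteration, JIT-friendly Krylov loops typically look like, the sketch below implements CG and CR with `jax.lax.fori_loop` over a matrix-free `matvec`. `cg_solve_sketch`, `cr_solve_sketch`, and `_safe_div` are hypothetical names; opttx's `cg_solve`/`cr_solve` may differ in arguments, stopping rules, and return values.

```python
import jax
import jax.numpy as jnp


def _safe_div(num, den):
    # Return 0 instead of NaN/inf once the iteration has converged (den == 0)
    ok = den != 0
    return jnp.where(ok, num / jnp.where(ok, den, 1.0), 0.0)


def cg_solve_sketch(matvec, b, iters):
    """Conjugate Gradient for A x = b, assuming A is symmetric positive definite."""
    def body(_, carry):
        x, r, p, rs = carry
        Ap = matvec(p)
        alpha = _safe_div(rs, p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + _safe_div(rs_new, rs) * p
        return x, r, p, rs_new

    init = (jnp.zeros_like(b), b, b, b @ b)  # x0 = 0, so r0 = b
    return jax.lax.fori_loop(0, iters, body, init)[0]


def cr_solve_sketch(matvec, b, iters):
    """Conjugate Residual for A x = b, assuming only that A is symmetric."""
    def body(_, carry):
        x, r, p, Ap, rAr = carry
        alpha = _safe_div(rAr, Ap @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        Ar = matvec(r)
        rAr_new = r @ Ar
        beta = _safe_div(rAr_new, rAr)
        p = r + beta * p
        Ap = Ar + beta * Ap  # keeps Ap = A p without an extra matvec
        return x, r, p, Ap, rAr_new

    Ab = matvec(b)
    init = (jnp.zeros_like(b), b, b, Ab, b @ Ab)
    return jax.lax.fori_loop(0, iters, body, init)[0]


A_spd = jnp.array([[4.0, 1.0], [1.0, 3.0]])
A_ind = jnp.array([[2.0, 0.0], [0.0, -1.0]])  # symmetric but indefinite
b = jnp.array([1.0, 2.0])

solve_spd = jax.jit(lambda rhs: cg_solve_sketch(lambda v: A_spd @ v, rhs, 2))
x_cg = solve_spd(b)
x_cr_spd = cr_solve_sketch(lambda v: A_spd @ v, b, 2)
x_cr_ind = cr_solve_sketch(lambda v: A_ind @ v, jnp.array([1.0, 1.0]), 2)
```

A static iteration count keeps the loop shape fixed under `jax.jit`, and the `_safe_div` guard avoids NaNs if the residual reaches zero early. CR minimizes the residual norm and does not require positive definiteness in many practical cases, which fits pairing it with a possibly indefinite Hessian, while CG relies on a positive-definite matrix such as a damped GGN/Fisher. CR can still break down if rᵀAr = 0.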
Models
| Model | Description |
|---|---|
| `GCN` | Graph Convolutional Network |
| `GCNLayer` | Single GCN layer |
| `GAT` | Graph Attention Network |
| `GATLayer` | Single GAT layer |
| `normalize_adjacency` | Symmetric adjacency normalization |
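Symmetric adjacency normalization usually refers to Â = D̃^{-1/2}(A + I)D̃^{-1/2}, the propagation matrix of the Kipf and Welling GCN, where D̃ is the degree matrix of A + I. Whether opttx's `normalize_adjacency` adds self-loops is not stated, so the sketch below makes it an option. `normalize_adjacency_sketch` and `gcn_layer_sketch` are hypothetical names, not the library's `normalize_adjacency` or `GCNLayer`.

```python
import jax
import jax.numpy as jnp


def normalize_adjacency_sketch(adj, add_self_loops=True):
    """Return D^{-1/2} (A + I) D^{-1/2}, with D the degree matrix of A + I."""
    if add_self_loops:
        adj = adj + jnp.eye(adj.shape[0], dtype=adj.dtype)
    deg = adj.sum(axis=1)
    # Isolated nodes (degree 0) get a zero scale instead of dividing by zero
    inv_sqrt = jnp.where(deg > 0, 1.0 / jnp.sqrt(deg), 0.0)
    return inv_sqrt[:, None] * adj * inv_sqrt[None, :]


def gcn_layer_sketch(a_hat, h, w, b):
    """One GCN propagation step: relu(A_hat @ H @ W + b)."""
    return jax.nn.relu(a_hat @ h @ w + b)


# Path graph 0 - 1 - 2 with 4 input features per node and 2 output features
adj = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
a_hat = normalize_adjacency_sketch(adj)
k_h, k_w = jax.random.split(jax.random.PRNGKey(0))
h = jax.random.normal(k_h, (3, 4))
w = 0.1 * jax.random.normal(k_w, (4, 2))
h_out = gcn_layer_sketch(a_hat, h, w, jnp.zeros(2))
```

The normalized matrix is symmetric with largest eigenvalue 1 for a connected graph, so repeated propagation does not blow up feature magnitudes.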
Design Constraints
- `state.step` must be a scalar `jax.Array` (never a Python `int`)
- Metrics have static string keys and scalar values
- Metrics must include a `"loss"` key
- Multi-term objectives combined with `batch_stats` are not supported
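The constraints above match the loop shape that `jax.jit` plus `jax.lax.scan` expects: the carry keeps a fixed structure, dtype, and shape, and every step emits a metrics dict with the same static keys and scalar values so that scan can stack them. The minimal pure-JAX sketch below illustrates that pattern with plain gradient descent; it does not use opttx, and `loss_fn`, `train_step`, and `run` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def loss_fn(params):
    # Toy quadratic with its minimum at params == 3.0
    return jnp.sum((params - 3.0) ** 2)


def train_step(carry, _):
    step, params = carry
    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = params - 0.1 * grads
    # Same static string keys and scalar values on every step
    metrics = {"loss": loss, "grad_norm": jnp.linalg.norm(grads)}
    return (step + 1, params), metrics


@jax.jit
def run(params):
    # step starts as a scalar jax.Array, so the carry has a fixed dtype and shape
    init = (jnp.array(0, dtype=jnp.int32), params)
    (step, params), metrics = jax.lax.scan(train_step, init, xs=None, length=50)
    return step, params, metrics


final_step, final_params, metrics = run(jnp.zeros(3))
# metrics["loss"] has shape (50,): one scalar per scanned step
```

Because the metrics pytree has the same structure on every iteration, `jax.lax.scan` returns each entry stacked along a leading axis of length 50; a step that added or dropped keys would fail to trace.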
License
MIT
Download files
Source Distribution
opttx-0.1.0a1.tar.gz (49.6 kB)
Built Distribution
opttx-0.1.0a1-py3-none-any.whl (74.8 kB)
File details
Details for the file opttx-0.1.0a1.tar.gz.
File metadata
- Download URL: opttx-0.1.0a1.tar.gz
- Upload date:
- Size: 49.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `46e820e212d1f2fd22f39fe5bd6c7718853de7d7079b838449ac567182c075af` |
| MD5 | `2e321a81f4df3acd5589d07b054d6d81` |
| BLAKE2b-256 | `9444bd4bfc4e62202e05e0221c002540c7ef261af897769ce17bb8241355800f` |
File details
Details for the file opttx-0.1.0a1-py3-none-any.whl.
File metadata
- Download URL: opttx-0.1.0a1-py3-none-any.whl
- Upload date:
- Size: 74.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `a1c38ff25171869b0f8fafd9a30eb8cb89b3767bdae01cc171d8d7bf23d3fabf` |
| MD5 | `cbc46ec120a71e3c4b11d133c41d6d15` |
| BLAKE2b-256 | `974e8cdcb3b90559d60fdf91f2514ee00228c20ef499d8f41529f89c35083aca` |