anchorax

Fixed-point solvers and differentiation utilities for implicit neural networks in JAX/Equinox.
- Solvers: Picard, Anderson acceleration, and limited-memory Broyden.
- Backpropagation: JFB (Jacobian-Free Backprop) and GDEQ-style adjoint preconditioning.
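For intuition, Picard iteration is the simplest of the three: it applies f repeatedly until successive iterates agree. The sketch below is a standalone illustration, not the library's implementation (the helper name `picard_sketch` is made up):

```python
import jax.numpy as jnp

def picard_sketch(f, u0, max_depth=100, eps=1e-6):
    # Repeatedly apply f until successive iterates agree to within eps.
    u = u0
    for step in range(max_depth):
        u_next = f(u)
        if jnp.max(jnp.abs(u_next - u)) < eps:
            return u_next, step + 1
        u = u_next
    return u, max_depth

# Classic example: the fixed point of cos(u) is u* ≈ 0.739085.
u_star, nsteps = picard_sketch(lambda u: jnp.cos(u), jnp.float32(1.0))
```

Anderson and Broyden reach the same fixed point in far fewer evaluations of f by extrapolating over past iterates.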
Install
```
pip install anchorax
```
Requirements
- einops>=0.8.1
- equinox>=0.12.1
- jax>=0.6.0
Quickstart
Define your implicit update f(u, Qd) as an Equinox module. The solvers will call it as f(u, Qd, inference=..., key=...).
```python
import equinox as eqx
import jax
import jax.numpy as jnp

class ImplicitBlock(eqx.Module):
    w: jnp.ndarray
    b: jnp.ndarray

    def __call__(self, u, Qd, *, inference: bool = False, key=None):
        # example: contractive residual layer
        return jnp.tanh(u @ self.w + Qd + self.b)

# Instantiate and partition into trainable params and static parts
model = ImplicitBlock(
    w=jax.random.normal(jax.random.PRNGKey(0), (128, 128)) * 0.1,
    b=jnp.zeros((128,)),
)
params, static = eqx.partition(model, eqx.is_array)
Qd = jnp.zeros((128,))
```
Solve with any solver (same API and outputs):
```python
from anchorax.solvers import broyden, picard, anderson

key = jax.random.PRNGKey(42)
u_star, u_prev, tnstep = broyden(params, static, Qd, inference=True, key=key)
# Or:
u_star, u_prev, tnstep = picard(params, static, Qd, inference=True, key=key)
u_star, u_prev, tnstep = anderson(params, static, Qd, inference=True, key=key)
```
Train end-to-end with GDEQ or JFB:
```python
import optax
from anchorax.backprop import gdeq, jfb

opt = optax.adam(1e-3)
opt_state = opt.init(eqx.filter(params, eqx.is_array))

def loss_fn(train_params, key):
    u_star, _, _ = gdeq((train_params, Qd), static, solver=broyden, inference=False, key=key)
    return jnp.mean(u_star**2)

loss_and_grad = eqx.filter_value_and_grad(loss_fn)

for step in range(1000):
    key, key_step = jax.random.split(key)
    loss, grads = loss_and_grad(params, key_step)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = eqx.apply_updates(params, updates)
```
API summary
All solvers share the same signature and semantics:
```python
u_star, u_prev, tnstep = solver(
    params,
    static,
    Qd,
    *,                    # remaining arguments are keyword-only
    u0=None,
    eps=1e-3,
    max_depth=100,
    stop_mode="abs",      # "abs" or "rel", on residual g(u) = f(u) - u
    LBFGS_thres=3,        # used by broyden; accepted by others for API uniformity
    ls=False,             # Armijo line search (broyden only)
    return_final=False,   # if False, return best-so-far; if True, the last iterate
    return_B=False,       # if True, also return (Us, VTs, valid_mask)
    inference=False,
    key=...,
    **kwargs,
)
```
- Returns:
  - `u_star`: final implicit state (best-so-far unless `return_final=True`)
  - `u_prev`: the iterate immediately preceding `u_star` (used for accurate autodiff)
  - `tnstep`: total forward evaluations of `f` performed by the solver
  - Optionally `(Us, VTs, valid_mask)` when `return_B=True` (Broyden's limited-memory inverse-Jacobian factors)
Backprop wrappers:
```python
# JFB: last-step VJP only (baseline)
u_star, u_prev, tnstep = jfb((params, Qd), static, solver=picard, inference=False, key=key)

# GDEQ: last-step VJP + adjoint preconditioning with inverse-Jacobian factors
u_star, u_prev, tnstep = gdeq((params, Qd), static, solver=broyden, inference=False, key=key)
```
- With Picard/Anderson plus `gdeq`, the stored factors act as a no-op preconditioner (degrading to JFB), which keeps the two wrappers interchangeable.
Design notes
- Residual and stopping:
  - The residual is `g(u) = f(u, Qd) - u`. `stop_mode="abs"` uses `||g(u)|| < eps`; `stop_mode="rel"` uses `||g(u)|| / (||g(u) + u|| + 1e-9) < eps`.
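As a standalone sketch of these two stopping rules (the helper `converged` is illustrative, not part of the package):

```python
import jax.numpy as jnp

def converged(g, u, eps=1e-3, stop_mode="abs"):
    # g = f(u) - u is the fixed-point residual at the current iterate u.
    if stop_mode == "abs":
        return jnp.linalg.norm(g) < eps
    # "rel" normalizes by ||f(u)|| = ||g + u||, guarded against division by zero.
    return jnp.linalg.norm(g) / (jnp.linalg.norm(g + u) + 1e-9) < eps
```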
- Determinism and randomness:
  - Each solver accepts `key` and folds the iteration index into it for every call to `f`, so stochastic layers behave deterministically per iteration.
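`jax.random.fold_in` is the standard JAX way to derive a per-iteration key from a base key; a minimal sketch of the pattern described above (the helper name `iter_key` is made up):

```python
import jax

def iter_key(key, step):
    # Derive a distinct but reproducible PRNG key for iteration `step`.
    return jax.random.fold_in(key, step)

base = jax.random.PRNGKey(0)
k0 = iter_key(base, 0)
k1 = iter_key(base, 1)  # different from k0, but deterministic given `base`
```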
- JIT and static args:
  - The following are treated as static for compilation: `static`, `eps`, `max_depth`, `stop_mode`, `LBFGS_thres`, `ls`, `return_final`, `return_B`. Changing any of them triggers recompilation.
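This mirrors how `jax.jit` itself handles static arguments. A small illustration on an unrelated toy function (not anchorax code):

```python
import functools

import jax

# `max_depth` is static, so the Python loop below is unrolled at trace time,
# and calling with a new max_depth triggers a fresh compilation.
@functools.partial(jax.jit, static_argnames=("max_depth",))
def halve(x, max_depth):
    for _ in range(max_depth):
        x = x * 0.5
    return x
```

`halve(8.0, max_depth=3)` returns `1.0`; calling with `max_depth=4` recompiles before returning `0.5`.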
- Autodiff accuracy:
  - Both `jfb` and `gdeq` recompute the final explicit step as `u_star = f(stop_gradient(u_prev), Qd)` in inference mode and use its VJP for gradients. This aligns training with inference and avoids off-by-one iteration effects.
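The last-step VJP idea can be sketched on a scalar toy problem (a from-scratch illustration, not the library's code; `f` and `last_step` are made-up names):

```python
import jax
import jax.numpy as jnp

def f(u, theta):
    # Toy contractive update; its fixed point depends on theta.
    return jnp.tanh(theta * u + 1.0)

def last_step(theta, u_prev):
    # Differentiate only the final explicit step; stop_gradient cuts the
    # backward pass through the iteration history.
    return f(jax.lax.stop_gradient(u_prev), theta)

u = jnp.zeros(())
for _ in range(50):              # plain forward iteration to the fixed point
    u = f(u, 0.5)

grad = jax.grad(lambda th: last_step(th, u))(0.5)
```

Here `grad` is the derivative of only the last explicit application of `f` with respect to `theta`, evaluated at the converged state.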
References
- Broyden, C. G. (1965). "A Class of Methods for Solving Nonlinear Simultaneous Equations". Mathematics of Computation 19 (92): 577–593. doi:10.1090/S0025-5718-1965-0198670-6.
- Anderson, D. G. (1965). "Iterative Procedures for Nonlinear Integral Equations". Journal of the ACM 12 (4): 547–560. doi:10.1145/321296.321305.
- Fung et al. (2021). "JFB: Jacobian-Free Backpropagation for Implicit Networks". arXiv:2103.12803.
- Nguyen, B. et al. (2023). "Efficient Training of Deep Equilibrium Models". arXiv:2304.11663.