
anchorax

Fixed-point solvers and differentiation utilities for implicit neural networks in JAX/Equinox.

  • Solvers: Picard, Anderson acceleration, and limited-memory Broyden.
  • Backpropagation: JFB (Jacobian-Free Backprop) and GDEQ-style adjoint preconditioning.

Install

pip install anchorax

Requirements

  • einops>=0.8.1
  • equinox>=0.12.1
  • jax>=0.6.0

Quickstart

Define your implicit update f(u, Qd) as an Equinox module. The solvers will call it as f(u, Qd, inference=..., key=...).

import equinox as eqx
import jax
import jax.numpy as jnp

class ImplicitBlock(eqx.Module):
    w: jnp.ndarray
    b: jnp.ndarray

    def __call__(self, u, Qd, *, inference: bool = False, key=None):
        # example: contractive residual layer
        return jnp.tanh(u @ self.w + Qd + self.b)

# Instantiate and partition into trainable params and static parts
model = ImplicitBlock(
    w=jax.random.normal(jax.random.PRNGKey(0), (128, 128)) * 0.1,
    b=jnp.zeros((128,)),
)
params, static = eqx.partition(model, eqx.is_array)
Qd = jnp.zeros((128,))
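
For reference, eqx.partition and eqx.combine are inverses, so the full module can be reassembled from (params, static). A minimal sketch of one application of f (eqx.combine itself is standard Equinox; that the solvers recombine this way internally is our assumption):

# Recombine the partitions and apply f once (illustrative only).
model_again = eqx.combine(params, static)
u = jnp.zeros((128,))
u_next = model_again(u, Qd, inference=True)  # one fixed-point update f(u, Qd)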

Solve with any solver (same API and outputs):

from anchorax.solvers import broyden, picard, anderson

key = jax.random.PRNGKey(42)

u_star, u_prev, tnstep = broyden(params, static, Qd, inference=True, key=key)
# Or:
u_star, u_prev, tnstep = picard(params, static, Qd, inference=True, key=key)
u_star, u_prev, tnstep = anderson(params, static, Qd, inference=True, key=key)

Train end-to-end with GDEQ or JFB:

import optax
from anchorax.backprop import gdeq, jfb

opt = optax.adam(1e-3)
opt_state = opt.init(params)  # params already holds only arrays after eqx.partition

def loss_fn(train_params, key):
    u_star, _, _ = gdeq((train_params, Qd), static, solver=broyden, inference=False, key=key)
    return jnp.mean(u_star**2)

loss_and_grad = eqx.filter_value_and_grad(loss_fn)

for step in range(1000):
    key, key_step = jax.random.split(key)
    loss, grads = loss_and_grad(params, key_step)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = eqx.apply_updates(params, updates)

API summary

All solvers share the same signature and semantics:

u_star, u_prev, tnstep = solver(
    params,
    static,
    Qd,
    *,
    u0=None,
    eps=1e-3,
    max_depth=100,
    stop_mode="abs",        # "abs" or "rel", on residual g(u)=f(u)-u
    LBFGS_thres=3,          # used by broyden; accepted by others for API uniformity
    ls=False,               # Armijo line search (broyden only)
    return_final=False,     # if False, return best-so-far; if True, last iterate
    return_B=False,         # if True, also return (Us, VTs, valid_mask)
    inference=False,
    key=...,
    **kwargs,
)
  • Returns:
    • u_star: final implicit state (best-so-far unless return_final=True)
    • u_prev: the iterate immediately preceding u_star (used for accurate autodiff)
    • tnstep: total forward evaluations of f performed by the solver
    • Optionally (Us, VTs, valid_mask) when return_B=True (Broyden’s limited-memory inverse-Jacobian factors)
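
For example, requesting the Broyden factors (a sketch; that the factors arrive as a fourth tuple element is our assumption based on the signature above):

# Ask broyden for its limited-memory inverse-Jacobian factors as well.
u_star, u_prev, tnstep, (Us, VTs, valid_mask) = broyden(
    params, static, Qd, return_B=True, inference=True, key=key,
)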

Backprop wrappers:

# JFB: last-step VJP only (baseline)
u_star, u_prev, tnstep = jfb((params, Qd), static, solver=picard, inference=False, key=key)

# GDEQ: last-step VJP + adjoint preconditioning with inverse-Jacobian factors
u_star, u_prev, tnstep = gdeq((params, Qd), static, solver=broyden, inference=False, key=key)
  • With Picard or Anderson under gdeq, the stored factors act as a no-op preconditioner, so gdeq reduces to JFB; every solver/backprop combination stays compatible (see the example below).
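
For instance (a usage sketch based on the calls above):

# With picard, gdeq stores no inverse-Jacobian factors, so its gradients
# match jfb's; the two calls below are interchangeable.
u_star, u_prev, tnstep = gdeq((params, Qd), static, solver=picard, inference=False, key=key)
u_star, u_prev, tnstep = jfb((params, Qd), static, solver=picard, inference=False, key=key)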

Design notes

  • Residual and stopping:
    • Residual is g(u) = f(u, Qd) - u.
    • stop_mode="abs" uses ||g(u)|| < eps.
    • stop_mode="rel" uses ||g(u)|| / (||g(u)+u|| + 1e-9) < eps.
  • Determinism and randomness:
    • Each solver accepts key and folds the iteration index into it for every call to f, so stochastic layers (e.g. dropout) draw distinct yet reproducible randomness at each iteration.
  • JIT and static args:
    • The following are treated as static for compilation: static, eps, max_depth, stop_mode, LBFGS_thres, ls, return_final, return_B. Changing them triggers recompilation.
  • Autodiff accuracy:
    • Both jfb and gdeq recompute the final explicit step as u_star = f(stop_gradient(u_prev), Qd) under inference mode and use its VJP for gradients. This aligns training and inference and avoids off-by-one iteration effects.
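
A minimal sketch of the stopping rule and the per-iteration key folding described above (illustrative; the function name and exact norm are our assumptions, not the library's internals):

import jax
import jax.numpy as jnp

def converged(f, u, Qd, stop_mode="abs", eps=1e-3):
    # Residual g(u) = f(u, Qd) - u
    g = f(u, Qd) - u
    abs_norm = jnp.linalg.norm(g)
    if stop_mode == "abs":
        return abs_norm < eps
    # "rel": residual relative to the output norm, g(u) + u = f(u, Qd)
    return abs_norm / (jnp.linalg.norm(g + u) + 1e-9) < eps

# Deterministic per-iteration keys for stochastic layers:
key = jax.random.PRNGKey(0)
key_i = jax.random.fold_in(key, 3)  # key for iteration index i = 3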

References

  • Broyden, C. G. (1965). "A Class of Methods for Solving Nonlinear Simultaneous Equations". Mathematics of Computation 19(92): 577–593. doi:10.1090/S0025-5718-1965-0198670-6.
  • Anderson, D. G. (1965). "Iterative Procedures for Nonlinear Integral Equations". Journal of the ACM 12(4): 547–560. doi:10.1145/321296.321305.
  • Fung et al. (2021). "JFB: Jacobian-Free Backpropagation for Implicit Networks". arXiv:2103.12803.
  • Nguyen, B. et al. (2023). "Efficient Training of Deep Equilibrium Models". arXiv:2304.11663.
