nlls_gram
Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX.
GramLevenbergMarquardt minimizes ||r(params)||^2 for a user-supplied
residual_fn(params, batch), where params is any JAX pytree (a flat array, a
dict, nnx.state(model, nnx.Param), ...). It follows an init/update protocol:
update(params, state, batch) returns a tuple of the new params pytree (same
structure), the next state, and an LMInfo. For overparameterized systems (many
more parameters p than residual rows n), it factors the small n x n Gram (dual)
system instead of the p x p normal equations.
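Why a dual form is possible can be checked numerically. For a Jacobian J of shape (n, p) and damping lam > 0, the push-through identity (J^T J + lam I_p)^-1 J^T = J^T (J J^T + lam I_n)^-1 means the textbook damped step delta = -(J^T J + lam I_p)^-1 J^T r can be obtained from an n x n solve instead. The plain-JAX sketch below assumes the classic lam * I damping and a dense jnp.linalg.solve; it does not call nlls_gram, whose actual damping rule and factorization may differ.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n, p = 5, 200  # few residual rows, many parameters (the overparameterized case)
key_J, key_r = jax.random.split(jax.random.PRNGKey(0))
J = jax.random.normal(key_J, (n, p))  # stand-in Jacobian dr/dparams, shape (n, p)
r = jax.random.normal(key_r, (n,))  # stand-in residual vector, shape (n,)
lam = 1e-2  # damping parameter

# Primal: solve the p x p system (J^T J + lam I_p) delta = -J^T r.
delta_primal = -jnp.linalg.solve(J.T @ J + lam * jnp.eye(p), J.T @ r)

# Dual (Gram): solve the n x n system (J J^T + lam I_n) alpha = r,
# then map back to parameter space with delta = -J^T alpha.
alpha = jnp.linalg.solve(J @ J.T + lam * jnp.eye(n), r)
delta_dual = -J.T @ alpha

print(jnp.max(jnp.abs(delta_primal - delta_dual)))
```

When p is much larger than n, the dual route only ever factors an n x n matrix, which is where the savings come from.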
The solver depends only on jax; it knows nothing about flax, nnx, or optax. It
performs no float casts: dtypes flow from your params and residual, and JAX
decides between float32 and float64 via jax_enable_x64.
Install
pip install nlls-gram
Minimal example
Fit y = a * exp(b * x) to noise-free data generated from (a, b) = (2, -1),
using a plain dict pytree of parameters:
```python
import jax
import jax.numpy as jnp
from nlls_gram import GramLevenbergMarquardt

jax.config.update("jax_enable_x64", True)

# residual_fn(params, batch) -> 1-D residual array; the solver minimizes its SSQ.
def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y

x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
params = {"a": jnp.asarray(1.0), "b": jnp.asarray(0.0)}

solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()

# The solver does not jit internally; wrap the train step yourself.
@jax.jit
def train_step(params, lm_state, batch):
    return solver.update(params, lm_state, batch)

for _ in range(50):
    params, lm_state, info = train_step(params, lm_state, (x, y))

print(params["a"], params["b"])  # ~2.0, ~-1.0
```
params can be any pytree. With Flax NNX, pass nnx.state(model, nnx.Param) as
params and write residual_fn(state, batch) using nnx.merge; the solver itself
stays NNX-agnostic.
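For intuition, here is a minimal plain-JAX sketch of what one Gram-form LM step on a pytree could look like for the same exponential fit. It is not nlls_gram's implementation: gram_lm_step is a hypothetical helper that flattens the pytree with jax.flatten_util.ravel_pytree, builds the n x p Jacobian with jax.jacfwd, solves the damped n x n Gram system, and applies an assumed accept/reject rule (halve lam after a successful step, double it after a rejected one) that need not match the library's damping schedule. The step is jit-compiled from the outside, as the README recommends for the real solver.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)


def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


def gram_lm_step(params, lam, batch):
    """One damped Gauss-Newton (LM) step solved in Gram/dual form.

    Hypothetical sketch, not nlls_gram's update: simple halve/double damping rule.
    """
    flat, unravel = ravel_pytree(params)  # pytree -> 1-D vector of length p
    r_of_flat = lambda v: residual_fn(unravel(v), batch)
    r = r_of_flat(flat)  # residuals, shape (n,)
    J = jax.jacfwd(r_of_flat)(flat)  # Jacobian, shape (n, p)
    n = r.shape[0]
    # Dual solve: (J J^T + lam I_n) alpha = r, then delta = -J^T alpha.
    alpha = jnp.linalg.solve(J @ J.T + lam * jnp.eye(n, dtype=r.dtype), r)
    candidate = flat - J.T @ alpha
    old_loss = jnp.sum(r**2)
    new_loss = jnp.sum(r_of_flat(candidate) ** 2)
    accept = new_loss < old_loss  # NaN/inf candidates compare False and are rejected
    new_flat = jnp.where(accept, candidate, flat)
    # Relax damping after a good step (floored so J J^T + lam I stays well posed),
    # raise it after a rejected one.
    new_lam = jnp.where(accept, jnp.maximum(0.5 * lam, 1e-9), 2.0 * lam)
    loss = jnp.where(accept, new_loss, old_loss)
    return unravel(new_flat), new_lam, loss


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
params = {"a": jnp.asarray(1.0), "b": jnp.asarray(0.0)}
lam = jnp.asarray(1e-2)

step = jax.jit(gram_lm_step)
for _ in range(50):
    params, lam, loss = step(params, lam, (x, y))

print(params["a"], params["b"], loss)
```

For this two-parameter problem the 20 x 20 Gram system is actually larger than the 2 x 2 normal equations; the dual form pays off in the overparameterized regime the solver targets, where p is much larger than n.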
Filtering / freezing parameters
update optimizes exactly the params pytree you pass, so freezing is just
"pass fewer params": keep the frozen values in residual_fn's closure and hand
the solver only the trainable subset. Frozen leaves get no Jacobian column and
never move, so no wrt or masking argument is needed.
```python
# Continues the minimal example (reuses jnp, x, y and GramLevenbergMarquardt).
# Optimize only "a"; "b" is frozen at its current value.
frozen = {"b": jnp.asarray(-1.0)}

def residual_fn(trainable, batch):
    x, y = batch
    params = {**frozen, **trainable}  # frozen from the closure, trainable optimized
    return params["a"] * jnp.exp(params["b"] * x) - y

trainable = {"a": jnp.asarray(1.0)}
solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()

for _ in range(50):
    trainable, lm_state, info = solver.update(trainable, lm_state, (x, y))
# trainable["a"] -> ~2.0; "b" stayed -1.0
```
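To see why frozen leaves get no Jacobian column, this plain-JAX sketch differentiates the residual only with respect to the trainable subset. It does not call nlls_gram, and assuming the library builds its Jacobian from the params argument in a similar way is an assumption here. Because jax.jacfwd differentiates only the argument pytree, the frozen "b" captured in the closure contributes no column.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
frozen = {"b": jnp.asarray(-1.0)}


def residual_fn(trainable, batch):
    x, y = batch
    params = {**frozen, **trainable}  # frozen leaves come from the closure
    return params["a"] * jnp.exp(params["b"] * x) - y


trainable = {"a": jnp.asarray(1.0)}

# Flatten the trainable subset and differentiate the residual w.r.t. it only.
flat, unravel = ravel_pytree(trainable)
J = jax.jacfwd(lambda v: residual_fn(unravel(v), (x, y)))(flat)
print(J.shape)  # (20, 1): one column for "a", none for the frozen "b"

# Pytree-level view: jacfwd returns a dict with the same keys as `trainable`.
J_tree = jax.jacfwd(residual_fn)(trainable, (x, y))
print(sorted(J_tree))  # ['a']
```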
With Flax NNX, split the model into frozen and trainable states with a filter and
merge them back inside residual_fn. freeze_filter is any nnx Filter (a type,
path, or predicate) selecting the params to hold fixed; the trailing ...
(Ellipsis) filter captures everything else as trainable:
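```python
from flax import nnx

graphdef, frozen, trainable = nnx.split(model, freeze_filter, ...)

def residual_fn(trainable, batch):
    m = nnx.merge(graphdef, frozen, trainable)
    ...  # compute and return the 1-D residual array from m

solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()
trainable, lm_state, info = solver.update(trainable, lm_state, batch)
new_model = nnx.merge(graphdef, frozen, trainable)
```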
Documentation
Full docs: https://highdimensionaleconlab.github.io/nlls_gram/
License
MIT
File details
Details for the file nlls_gram-0.2.0.tar.gz.
File metadata
- Download URL: nlls_gram-0.2.0.tar.gz
- Upload date:
- Size: 59.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e95eac429faa1366eb516ac8788364e6520de26a316a8352d3ed227d87877605 |
| MD5 | aa34b9b3eb9170dbb45165007d98cf42 |
| BLAKE2b-256 | d9f7eeea7167cb32588e12eac5284967a379bfd1a362bf2916b562cef3c413cc |
Provenance
The following attestation bundles were made for nlls_gram-0.2.0.tar.gz:
Publisher: publish.yml on HighDimensionalEconLab/nlls_gram
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nlls_gram-0.2.0.tar.gz
- Subject digest: e95eac429faa1366eb516ac8788364e6520de26a316a8352d3ed227d87877605
- Sigstore transparency entry: 1846873388
- Sigstore integration time:
- Permalink: HighDimensionalEconLab/nlls_gram@fcb3e898ed73ed944330da92e19f35fdad01943e
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/HighDimensionalEconLab
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@fcb3e898ed73ed944330da92e19f35fdad01943e
- Trigger Event: release
File details
Details for the file nlls_gram-0.2.0-py3-none-any.whl.
File metadata
- Download URL: nlls_gram-0.2.0-py3-none-any.whl
- Upload date:
- Size: 5.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 19a89e63c2bd04fdb6e74120b3280520fcf69672b499ea92013e70e9b5cc3696 |
| MD5 | 931c59430d64c34cf3d634619a0e83da |
| BLAKE2b-256 | c32080718707a6650e651f173078d9002bc21d878e88e4f6305d5765be02360e |
Provenance
The following attestation bundles were made for nlls_gram-0.2.0-py3-none-any.whl:
Publisher: publish.yml on HighDimensionalEconLab/nlls_gram
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nlls_gram-0.2.0-py3-none-any.whl
- Subject digest: 19a89e63c2bd04fdb6e74120b3280520fcf69672b499ea92013e70e9b5cc3696
- Sigstore transparency entry: 1846873499
- Sigstore integration time:
- Permalink: HighDimensionalEconLab/nlls_gram@fcb3e898ed73ed944330da92e19f35fdad01943e
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/HighDimensionalEconLab
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@fcb3e898ed73ed944330da92e19f35fdad01943e
- Trigger Event: release