Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX

nlls_gram

GramLevenbergMarquardt minimizes ||r(params)||^2 for a user-supplied residual_fn(params, batch), where params is any JAX pytree (a flat array, a dict, nnx.state(model, nnx.Param), ...). It follows an init/update protocol: update(params, state, batch) returns the new params pytree (with the same structure as the input), the next state, and an LMInfo. For overparameterized systems (many more parameters p than residual rows n), it factors the small n x n Gram (dual) system instead of the p x p normal equations.
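To make the primal/dual distinction concrete, the sketch below checks numerically that a damped step computed from the p x p normal equations matches the one computed from the n x n Gram system. It is an illustration of the underlying linear-algebra identity only, not the library's implementation: the random J and r, the scaled-identity damping, and all variable names are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n, p = 5, 40  # few residual rows, many parameters
key_j, key_r = jax.random.split(jax.random.PRNGKey(0))
J = jax.random.normal(key_j, (n, p))  # stand-in Jacobian of the residual, shape (n, p)
r = jax.random.normal(key_r, (n,))    # stand-in residual vector, shape (n,)
damping = 1e-2                        # assumed scaled-identity damping

# Primal (normal-equations) step: solve a p x p system.
step_primal = -jnp.linalg.solve(J.T @ J + damping * jnp.eye(p), J.T @ r)

# Dual (Gram) step: solve an n x n system, then map back through J^T.
step_dual = -J.T @ jnp.linalg.solve(J @ J.T + damping * jnp.eye(n), r)

max_diff = jnp.max(jnp.abs(step_primal - step_dual))
print(max_diff)  # agreement up to floating-point round-off
```

Both steps agree because (J^T J + c I)^{-1} J^T = J^T (J J^T + c I)^{-1} for any c > 0; the dual form only changes which matrix is factored, which is why it pays off when n is much smaller than p.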

The solver depends only on jax; it knows nothing about flax/nnx/optax. It performs no float casts: dtypes flow from your params and residual, and JAX chooses float32 or float64 according to jax_enable_x64.
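As a small illustration of dtypes flowing from the inputs, the snippet below evaluates a residual function on its own (no solver involved) and inspects the result. The residual function is the exponential model used in the example further down; nothing here is specific to nlls_gram.

```python
import jax.numpy as jnp


def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
params = {"a": jnp.asarray(1.0), "b": jnp.asarray(0.0)}

r = residual_fn(params, (x, y))
# float32 by default; float64 if jax_enable_x64 was enabled before creating the arrays
print(params["a"].dtype, r.dtype)
```

The residual carries whatever float dtype the params and batch were created with, so switching precision is a matter of setting jax_enable_x64 (or building the arrays with an explicit dtype) before calling the solver.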

Install

pip install nlls-gram

Minimal example

Fit y = a * exp(b * x) to noise-free data generated from (a, b) = (2, -1), using a plain dict pytree of parameters:

import jax
import jax.numpy as jnp

from nlls_gram import GramLevenbergMarquardt

jax.config.update("jax_enable_x64", True)


# residual_fn(params, batch) -> 1-D residual array; the solver minimizes its sum of squares.
def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)

params = {"a": jnp.asarray(1.0), "b": jnp.asarray(0.0)}
solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()


# The solver does not jit internally; wrap the train step yourself.
@jax.jit
def train_step(params, lm_state, batch):
    return solver.update(params, lm_state, batch)


for _ in range(50):
    params, lm_state, info = train_step(params, lm_state, (x, y))

print(params["a"], params["b"])  # ~2.0, ~-1.0

params can be any pytree. With Flax NNX, pass nnx.state(model, nnx.Param) as params and write residual_fn(state, batch) using nnx.merge; the solver itself stays NNX-agnostic.
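One common way to support arbitrary pytrees is to flatten the params to a vector, take the step there, and unflatten the result. The sketch below does this with jax.flatten_util.ravel_pytree for a single damped step in dual form. It is a hand-written illustration under stated assumptions, not nlls_gram's internals: damped_step is a hypothetical helper, the damping is a fixed scaled identity, and there is no step acceptance or damping update as a real Levenberg-Marquardt solver would have.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def residual_fn(params, batch):
    x, y = batch
    return params["a"] * jnp.exp(params["b"] * x) - y


def damped_step(params, batch, damping):
    """One damped step in dual form (hypothetical helper, not nlls_gram API)."""
    flat, unravel = ravel_pytree(params)  # pytree -> 1-D vector of length p

    def flat_residual(v):
        return residual_fn(unravel(v), batch)

    r = flat_residual(flat)                # residuals, shape (n,)
    J = jax.jacobian(flat_residual)(flat)  # Jacobian, shape (n, p)
    gram = J @ J.T + damping * jnp.eye(r.shape[0])  # n x n dual system
    delta = -J.T @ jnp.linalg.solve(gram, r)        # step, shape (p,)
    return unravel(flat + delta)           # same pytree structure as params


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
params = {"a": jnp.asarray(1.0), "b": jnp.asarray(0.0)}

new_params = damped_step(params, (x, y), damping=1e-2)
print(jax.tree_util.tree_structure(new_params))
```

Because the step is taken on the flattened vector, the same code works for a flat array, a nested dict, or any other pytree, and the output always has the structure of the input. In this toy problem n is larger than p, so the dual form brings no saving; it is used here only to keep the example short.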

Filtering / freezing parameters

update optimizes exactly the params pytree you pass, so freezing is just "pass fewer params": keep the frozen values in residual_fn's closure and hand the solver only the trainable subset. Frozen leaves get no Jacobian column and never move, so no wrt or masking argument is needed.

# Optimize only "a"; "b" is frozen at its current value.
frozen = {"b": jnp.asarray(-1.0)}


def residual_fn(trainable, batch):
    x, y = batch
    params = {**frozen, **trainable}  # frozen from the closure, trainable optimized
    return params["a"] * jnp.exp(params["b"] * x) - y


trainable = {"a": jnp.asarray(1.0)}
solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
lm_state = solver.init()
for _ in range(50):
    trainable, lm_state, info = solver.update(trainable, lm_state, (x, y))
# trainable["a"] -> ~2.0; "b" stayed -1.0
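The claim that frozen leaves get no Jacobian column can be checked directly with jax.jacobian, independently of the solver. The sketch below differentiates the closure-based residual function with respect to the trainable pytree only; the variable names follow the example above, and jac is an illustrative name.

```python
import jax
import jax.numpy as jnp

frozen = {"b": jnp.asarray(-1.0)}


def residual_fn(trainable, batch):
    x, y = batch
    params = {**frozen, **trainable}  # "b" comes from the closure
    return params["a"] * jnp.exp(params["b"] * x) - y


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)
trainable = {"a": jnp.asarray(1.0)}

# jax.jacobian differentiates with respect to the first argument only,
# so the result mirrors the structure of `trainable`.
jac = jax.jacobian(residual_fn)(trainable, (x, y))
print(jax.tree_util.tree_map(jnp.shape, jac))  # {'a': (20,)}
```

The Jacobian pytree has an entry for "a" and none for "b", so any step computed from it cannot change the frozen value.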

With Flax NNX, split the model into frozen and trainable states with a filter and merge them back inside residual_fn. freeze_filter is any nnx Filter (a type, path, or predicate) that picks the params to hold fixed; the literal ... (Python's Ellipsis) captures everything else as trainable:

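from flax import nnx

# model, freeze_filter, solver, lm_state, and batch are assumed to be defined already.
graphdef, frozen, trainable = nnx.split(model, freeze_filter, ...)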


def residual_fn(trainable, batch):
    m = nnx.merge(graphdef, frozen, trainable)
    ...  # compute residuals from m


trainable, lm_state, info = solver.update(trainable, lm_state, batch)
new_model = nnx.merge(graphdef, frozen, trainable)

Documentation

Full docs: https://highdimensionaleconlab.github.io/nlls_gram/

License

MIT
