
Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX/Flax NNX models

Project description

nlls_gram



GramLevenbergMarquardt minimizes ||r(theta)||^2 for a residual function r defined over an nnx.Module. It follows the optax/nnx init/update protocol, so its steps are applied through nnx.Optimizer(model, optax.identity(), wrt=...). For overparameterized systems (many more parameters p than residual rows n), it factors the small n x n Gram (dual) system built from J J^T instead of the p x p normal equations built from J^T J, where J is the n x p Jacobian of r.
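The dual form rests on the push-through identity (J^T J + lambda I)^{-1} J^T = J^T (J J^T + lambda I)^{-1}, which holds for any damping lambda > 0. The Levenberg-Marquardt step -(J^T J + lambda I)^{-1} J^T r can therefore be computed as -J^T (J J^T + lambda I)^{-1} r. The plain-JAX sketch below checks this numerically on a random overparameterized Jacobian. It assumes the damping term is lambda times the identity, and the function names are hypothetical; it illustrates the linear algebra only and is not nlls_gram's internal implementation, which may scale, damp, or factor differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def lm_step_primal(J, r, lam):
    """Hypothetical helper: LM step from the p x p normal equations."""
    p = J.shape[1]
    A = J.T @ J + lam * jnp.eye(p)          # p x p damped normal matrix
    return -jnp.linalg.solve(A, J.T @ r)


def lm_step_gram(J, r, lam):
    """Hypothetical helper: same step via the n x n Gram (dual) system."""
    n = J.shape[0]
    G = J @ J.T + lam * jnp.eye(n)          # n x n, symmetric positive definite for lam > 0
    alpha = cho_solve(cho_factor(G), r)     # solve G alpha = r with a Cholesky factorization
    return -J.T @ alpha                     # map the dual solution back to parameter space


key_J, key_r = jax.random.split(jax.random.PRNGKey(0))
n, p = 5, 200                               # overparameterized: p >> n
J = jax.random.normal(key_J, (n, p))
r = jax.random.normal(key_r, (n,))
lam = 1e-2

d_primal = lm_step_primal(J, r, lam)
d_gram = lm_step_gram(J, r, lam)
print(jnp.max(jnp.abs(d_primal - d_gram)))  # agreement up to floating-point round-off
```

The saving comes from the sizes involved: forming J J^T costs O(n^2 p) and factoring it O(n^3), whereas the primal route forms a p x p matrix at O(n p^2) and factors it at O(p^3).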

Install

pip install nlls-gram

Minimal example

Fit y = a * exp(b * x) to noise-free data generated from (a, b) = (2, -1):

import jax
import jax.numpy as jnp
import optax
from flax import nnx

from nlls_gram import GramLevenbergMarquardt

jax.config.update("jax_enable_x64", True)


class ExpModel(nnx.Module):
    def __init__(self, a, b):
        self.a = nnx.Param(jnp.asarray(a))
        self.b = nnx.Param(jnp.asarray(b))

    def __call__(self, x):
        return self.a * jnp.exp(self.b * x)


x = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.0 * x)

model = ExpModel(a=1.0, b=0.0)
solver = GramLevenbergMarquardt(lambda m, batch: m(batch[0]) - batch[1])
optimizer = nnx.Optimizer(model, optax.identity(), wrt=nnx.Param)
lm_state = solver.init()


@jax.jit
def train_step(graphdef, state, lm_state, batch):
    m, opt = nnx.merge(graphdef, state)
    updates, lm_state, info = solver.update(m, lm_state, batch)
    opt.update(m, updates)
    return lm_state, info, nnx.state((m, opt))


graphdef, state = nnx.split((model, optimizer))
for _ in range(50):
    lm_state, info, state = train_step(graphdef, state, lm_state, (x, y))
nnx.update((model, optimizer), state)

print(model.a[...], model.b[...])  # ~2.0, ~-1.0

Documentation

Full docs: https://highdimensionaleconlab.github.io/nlls_gram/

License

MIT



Download files


Source Distribution

nlls_gram-0.1.0.tar.gz (81.5 kB)

Built Distribution


nlls_gram-0.1.0-py3-none-any.whl (5.2 kB)

File details

Details for the file nlls_gram-0.1.0.tar.gz.

File metadata

  • Download URL: nlls_gram-0.1.0.tar.gz
  • Size: 81.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for nlls_gram-0.1.0.tar.gz
Algorithm Hash digest
SHA256 3ad9966a526cb249755e54dc4be14b1437e4f487dbf42e1b6f78b71a78303097
MD5 01c2dcda704af8e48909fdd6bd45bd8f
BLAKE2b-256 4956c87766a06063c25d833fadd79a2ca8ac5cc8fc440d29ba5c3c66343ad0ac
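A downloaded file can be checked against the SHA256 digest listed above with a few lines of standard-library Python. This is a minimal sketch: the helper name sha256_of is hypothetical, and the commented-out call assumes the sdist was saved to the current directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hypothetical helper: hex SHA256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with Path(path).open("rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 published above for nlls_gram-0.1.0.tar.gz.
EXPECTED = "3ad9966a526cb249755e54dc4be14b1437e4f487dbf42e1b6f78b71a78303097"
# Assumes the file is in the current directory:
# print(sha256_of("nlls_gram-0.1.0.tar.gz") == EXPECTED)
```

pip can enforce the same check at install time in its hash-checking mode, by pinning the requirement with a --hash=sha256:<digest> option in a requirements file.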


Provenance

The following attestation bundles were made for nlls_gram-0.1.0.tar.gz:

Publisher: publish.yml on HighDimensionalEconLab/nlls_gram

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file nlls_gram-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: nlls_gram-0.1.0-py3-none-any.whl
  • Size: 5.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for nlls_gram-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d43e9462f4605834ca00709aadf73993a8b53dc03feec876ca1ba1ae074df3cd
MD5 6005178b2b069c107814f4b19f78ecc6
BLAKE2b-256 02ccf405f71cdd3a7b1647df53723e94a1bfc98562b38aa6927d46faf91932af


Provenance

The following attestation bundles were made for nlls_gram-0.1.0-py3-none-any.whl:

Publisher: publish.yml on HighDimensionalEconLab/nlls_gram


Supported by
