
GPU/TPU accelerated nonlinear least-squares curve fitting using JAX


NLSQ: GPU-Accelerated Curve Fitting

Drop-in replacement for scipy.optimize.curve_fit, with 190-270x GPU speedups at 1M-10M data points



What is NLSQ?

NLSQ is a nonlinear least-squares curve-fitting library built on JAX. It provides:

  • SciPy-compatible API - Same function signatures as scipy.optimize.curve_fit
  • GPU/TPU acceleration - JIT-compiled kernels via XLA
  • Automatic differentiation - No manual Jacobian calculations needed
  • Large dataset support - Handles 100M+ data points with streaming optimization
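Automatic differentiation means the solver gets the Jacobian of the model with respect to its parameters from JAX rather than from hand-written derivatives. The following is a minimal sketch of that idea using jax.jacfwd on the exponential model a·exp(-b·x) + c used on this page; the helper name model_of_params is hypothetical, and whether NLSQ uses forward-mode (jacfwd) or reverse-mode (jacrev) differentiation internally is not stated here. Double precision is enabled only so the comparison with the hand-derived Jacobian is tight.

```python
import jax
import jax.numpy as jnp

# Enable float64 so the autodiff and hand-derived Jacobians can be compared closely
jax.config.update("jax_enable_x64", True)


def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c


def model_of_params(params, x):
    # Hypothetical wrapper: the model as a function of the parameter vector,
    # which is the view a least-squares solver needs for the Jacobian
    a, b, c = params
    return exponential(x, a, b, c)


x = jnp.linspace(0.0, 4.0, 5)
params = jnp.array([2.5, 0.5, 1.0])

# Forward-mode autodiff: one column per parameter, shape (len(x), 3)
J_auto = jax.jacfwd(model_of_params)(params, x)

# Hand-derived Jacobian for comparison: columns d/da, d/db, d/dc
a, b, c = params
J_manual = jnp.stack(
    [jnp.exp(-b * x), -a * x * jnp.exp(-b * x), jnp.ones_like(x)], axis=1
)

print(J_auto.shape)  # (5, 3)
print(jnp.allclose(J_auto, J_manual))
```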

Installation

# CPU (all platforms)
pip install nlsq

# GPU (Linux with CUDA 12.1-12.9)
pip install nlsq
pip install "jax[cuda12-local]==0.8.0"

Verify GPU installation:

python -c "import jax; print('Devices:', jax.devices())"
# Expected: Devices: [CudaDevice(id=0)] on GPU, Devices: [CpuDevice(id=0)] on CPU
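Inside a script, the same check can be done programmatically before starting a long fit. This sketch uses only JAX's own jax.default_backend() and jax.devices(); the warning message is illustrative, not something NLSQ prints.

```python
import jax

backend = jax.default_backend()  # platform name, e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()

print(f"Default backend: {backend}")
print(f"Devices: {devices}")

if backend != "gpu":
    # Illustrative warning: JAX falls back to the CPU when no GPU is visible
    print("No GPU visible to JAX; computations will run on the CPU.")
```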

Quick Start

import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit


# Define model function (use jax.numpy for GPU acceleration)
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# Generate data
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * np.random.randn(len(x))

# Fit - same API as scipy.optimize.curve_fit
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])

print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {np.sqrt(np.diag(pcov))}")
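Because the API mirrors scipy.optimize.curve_fit, a reference fit with SciPy itself is a convenient cross-check on the Quick Start result. This sketch assumes SciPy is installed (it is listed under Requirements) and uses a NumPy version of the model, since SciPy evaluates the model with NumPy arrays; the seeded generator is an addition for reproducibility.

```python
import numpy as np
from scipy.optimize import curve_fit as scipy_curve_fit


def exponential_np(x, a, b, c):
    # NumPy version of the Quick Start model, for SciPy
    return a * np.exp(-b * x) + c


rng = np.random.default_rng(0)  # seeded so the run is reproducible
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * rng.standard_normal(len(x))

popt_ref, pcov_ref = scipy_curve_fit(exponential_np, x, y, p0=[2.0, 0.5, 1.0])
perr_ref = np.sqrt(np.diag(pcov_ref))  # 1-sigma parameter uncertainties

print("SciPy parameters:", popt_ref)
print("SciPy uncertainties:", perr_ref)
```

Running NLSQ on the same x and y and checking that popt agrees with popt_ref to within the reported uncertainties is a reasonable sanity check.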

Key Features

| Feature | Description |
|---|---|
| Automatic Jacobians | JAX's autodiff eliminates manual derivatives |
| Bounded optimization | Trust Region Reflective and Levenberg-Marquardt |
| Large datasets | Chunked and streaming optimizers for 100M+ points |
| Multi-start | Global optimization with LHS/Sobol sampling |
| Mixed precision | Automatic float32 → float64 upgrade when needed |
| Workflow system | Auto-selects a strategy based on dataset size |
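For readers unfamiliar with the Levenberg-Marquardt method named above, here is a textbook version in NumPy with Marquardt's diagonal scaling: each iteration solves the damped normal equations (JᵀJ + λ·diag(JᵀJ)) Δp = Jᵀr, shrinks λ after an accepted step (toward Gauss-Newton), and grows it after a rejected one (toward gradient descent). It is an independent sketch with a hand-written Jacobian and no bounds handling, not NLSQ's JIT-compiled implementation; the damping factors 0.3 and 10 are arbitrary choices.

```python
import numpy as np


def model(x, p):
    a, b, c = p
    return a * np.exp(-b * x) + c


def jacobian(x, p):
    # Columns: d(model)/da, d(model)/db, d(model)/dc
    a, b, c = p
    e = np.exp(-b * x)
    return np.stack([e, -a * x * e, np.ones_like(x)], axis=1)


def levenberg_marquardt(x, y, p0, lam=1e-3, max_iter=100, tol=1e-12):
    p = np.asarray(p0, dtype=float)
    r = y - model(x, p)
    cost = 0.5 * r @ r
    for _ in range(max_iter):
        J = jacobian(x, p)
        A = J.T @ J
        g = J.T @ r
        # Damped normal equations with Marquardt's diagonal scaling:
        # (J^T J + lam * diag(J^T J)) dp = J^T r
        dp = np.linalg.solve(A + lam * np.diag(np.diag(A)), g)
        if np.linalg.norm(dp) <= tol * (np.linalg.norm(p) + tol):
            break  # step is negligible: converged
        r_new = y - model(x, p + dp)
        cost_new = 0.5 * r_new @ r_new
        if cost_new < cost:
            # Accept the step and move toward Gauss-Newton
            p, r, cost = p + dp, r_new, cost_new
            lam *= 0.3
        else:
            # Reject the step and move toward gradient descent
            lam *= 10.0
    return p


x = np.linspace(0.0, 4.0, 50)
y = model(x, [2.5, 0.5, 1.0])  # noise-free data, so the true parameters are recoverable
p_hat = levenberg_marquardt(x, y, p0=[2.0, 0.3, 0.5])
print(p_hat)
```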

Performance

NLSQ shows increasing speedups as dataset size grows:

| Dataset size | SciPy (CPU) | NLSQ (GPU) | Speedup |
|---|---|---|---|
| 10K points | 15 ms | 2 ms | 7x |
| 100K points | 150 ms | 3 ms | 50x |
| 1M points | 1.5 s | 8 ms | 190x |
| 10M points | 15 s | 55 ms | 270x |

Note: The first call includes JIT compilation (~500 ms). Subsequent calls with the same input shapes and dtypes reuse the compiled kernels; a new shape or dtype triggers recompilation.
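The compile-once behavior comes from jax.jit, which caches a compiled kernel per combination of input shapes and dtypes. The sketch below makes this visible with a Python side effect that only runs while JAX traces the function; sum_sq_residuals and trace_log are illustrative names, and whether NLSQ pads or buckets inputs to limit recompilation is not covered here.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the x shape each time JAX traces (and compiles) the function


@jax.jit
def sum_sq_residuals(params, x, y):
    # Python side effects run only during tracing, not on cached calls
    trace_log.append(x.shape)
    a, b, c = params
    r = y - (a * jnp.exp(-b * x) + c)
    return jnp.sum(r**2)


params = jnp.array([2.5, 0.5, 1.0])

x1 = jnp.linspace(0.0, 4.0, 1000)
y1 = 2.5 * jnp.exp(-0.5 * x1) + 1.0
sum_sq_residuals(params, x1, y1)        # first call: trace + compile
sum_sq_residuals(params * 1.1, x1, y1)  # same shapes/dtypes: cached kernel reused

x2 = jnp.linspace(0.0, 4.0, 2000)
y2 = 2.5 * jnp.exp(-0.5 * x2) + 1.0
sum_sq_residuals(params, x2, y2)        # new shape: traced and compiled again

print(trace_log)  # [(1000,), (2000,)]
```

For benchmarking, this is why timings should be taken after a warm-up call with the same data shapes.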

See Performance Guide for benchmarks.

Large Dataset Example

from nlsq import fit
import numpy as np
import jax.numpy as jnp


def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# 50 million points
x = np.linspace(0, 10, 50_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))

# Auto-selects optimal strategy (chunked/streaming) based on memory
popt, pcov = fit(model, x, y, p0=[2.5, 0.6, 0.2], show_progress=True)
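Some rough arithmetic shows why chunking matters at this scale: 50 million float64 values take 400 MB, so x and y alone occupy 800 MB, and a dense 50M × 3 Jacobian would need another 1.2 GB. One standard way to bound memory is to accumulate the Gauss-Newton normal equations JᵀJ (3×3) and Jᵀr (length 3) chunk by chunk, so only one chunk of the Jacobian exists at a time. The sketch illustrates that general technique in NumPy with a hand-written Jacobian; the function names are hypothetical and this is not NLSQ's chunked or streaming implementation.

```python
import numpy as np


def model(x, p):
    a, b, c = p
    return a * np.exp(-b * x) + c


def jacobian(x, p):
    # Columns: d(model)/da, d(model)/db, d(model)/dc
    a, b, c = p
    e = np.exp(-b * x)
    return np.stack([e, -a * x * e, np.ones_like(x)], axis=1)


def gauss_newton_step_chunked(x, y, p, chunk_size):
    """One Gauss-Newton step whose Jacobian memory scales with chunk_size, not len(x)."""
    n_params = len(p)
    JTJ = np.zeros((n_params, n_params))
    JTr = np.zeros(n_params)
    for start in range(0, len(x), chunk_size):
        xc = x[start:start + chunk_size]
        yc = y[start:start + chunk_size]
        J = jacobian(xc, p)      # (chunk, n_params)
        r = yc - model(xc, p)    # (chunk,)
        JTJ += J.T @ J           # sums over chunks equal the full-data products
        JTr += J.T @ r
    return np.linalg.solve(JTJ, JTr)


rng = np.random.default_rng(0)
x = np.linspace(0, 10, 10_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + rng.normal(0, 0.05, x.size)
p = np.array([2.5, 0.6, 0.2])

step_chunked = gauss_newton_step_chunked(x, y, p, chunk_size=1_024)
step_full = gauss_newton_step_chunked(x, y, p, chunk_size=x.size)  # one chunk = no chunking
print(np.allclose(step_chunked, step_full))
```

Because JᵀJ and Jᵀr are sums over data points, splitting the data changes only the order of summation, so the chunked step matches the unchunked one up to rounding.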

Advanced Usage

Multi-start global optimization

from nlsq import curve_fit

# model, x, y as defined in the Large Dataset Example above

popt, pcov = curve_fit(
    model,
    x,
    y,
    p0=[1, 1, 1],
    bounds=([0, 0, 0], [10, 5, 10]),
    multistart=True,
    n_starts=20,
    sampler="lhs",
)
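To illustrate what Latin hypercube sampling (sampler="lhs") produces, here is a minimal sampler for starting points inside the box bounds used above: each parameter range is split into n equal strata, one point is drawn per stratum, and the strata are shuffled independently per dimension. This is an independent sketch (latin_hypercube is a hypothetical helper); how NLSQ generates its starts or combines them with p0 is not described here. SciPy's scipy.stats.qmc module also provides LatinHypercube and Sobol samplers.

```python
import numpy as np


def latin_hypercube(n_samples, lower, upper, rng):
    """Latin hypercube sample of starting points inside box bounds."""
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    n_dims = lower.size
    # One uniform draw inside each of n_samples equal-width strata, per dimension
    u = (np.arange(n_samples)[:, None] + rng.random((n_samples, n_dims))) / n_samples
    # Shuffle the strata independently in each dimension
    for d in range(n_dims):
        u[:, d] = rng.permutation(u[:, d])
    # Map from the unit cube to the parameter bounds
    return lower + u * (upper - lower)


rng = np.random.default_rng(42)
starts = latin_hypercube(20, [0, 0, 0], [10, 5, 10], rng)
print(starts.shape)  # (20, 3)
```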

Workflow presets

from nlsq import fit

# Presets: 'fast', 'robust', 'global', 'quality', 'memory_efficient'
popt, pcov = fit(model, x, y, preset="robust")

Numerical stability

from nlsq import curve_fit

p0 = [2.5, 0.6, 0.2]  # initial guess for (a, b, c)

# Auto-detect and fix numerical issues
popt, pcov = curve_fit(model, x, y, p0=p0, stability="auto")

Memory management

from nlsq import MemoryConfig, memory_context, CurveFit

p0 = [2.5, 0.6, 0.2]  # initial guess for (a, b, c)
config = MemoryConfig(memory_limit_gb=8.0, enable_mixed_precision_fallback=True)

with memory_context(config):
    cf = CurveFit()
    popt, pcov = cf.curve_fit(model, x, y, p0=p0)
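When choosing memory_limit_gb, a back-of-the-envelope estimate helps. The hypothetical helper below (not part of NLSQ) counts x, y, the residual vector, and a dense Jacobian, and ignores solver workspace, JIT buffers, and library copies, so it is a lower bound. It also shows why a float32 fallback can matter: halving the bytes per value can bring a fit under the limit.

```python
def estimate_fit_memory_gb(n_points, n_params, bytes_per_value=8):
    """Rough lower bound on memory for x, y, residuals, and a dense Jacobian.

    Hypothetical helper for illustration; it ignores solver workspace,
    JIT buffers, and any copies made by the library.
    """
    # x, y, residual vector: 3 arrays of n_points; Jacobian: n_points * n_params
    n_values = n_points * (3 + n_params)
    return n_values * bytes_per_value / 1e9


memory_limit_gb = 8.0  # same value as in the MemoryConfig example above
for n_points in (1_000_000, 50_000_000, 200_000_000):
    for bytes_per_value, label in ((8, "float64"), (4, "float32")):
        need = estimate_fit_memory_gb(n_points, 3, bytes_per_value)
        fits = need <= memory_limit_gb
        print(f"{n_points:>12,} pts {label}: {need:6.2f} GB, fits in {memory_limit_gb} GB: {fits}")
```

For 50M points and 3 parameters in float64 this gives 50e6 × 6 × 8 bytes = 2.4 GB; at 200M points float64 needs 9.6 GB while float32 needs 4.8 GB.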

See Advanced Features for complete documentation.

Examples

Start with the Interactive Tutorial on Google Colab.

By topic:

See examples/README.md for the full index.

Requirements

  • Python 3.12+
  • JAX 0.8.0 (locked version)
  • NumPy 2.0+
  • SciPy 1.14.0+

GPU support (Linux only): CUDA 12.1-12.9, NVIDIA driver >= 525

Citation

If you use NLSQ in your research, please cite:

@software{nlsq2024,
  title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
  author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  year={2024},
  url={https://github.com/imewei/NLSQ}
}

@article{jaxfit2022,
  title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
  author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  journal={arXiv preprint arXiv:2208.12187},
  year={2022}
}

Acknowledgments

NLSQ is an enhanced fork of JAXFit by Lucas R. Hofer, Milan Krstajić, and Robert P. Smith. We gratefully acknowledge their foundational work.

License

MIT License - see LICENSE for details.


Download files


Source Distribution

nlsq-0.4.0.tar.gz (1.4 MB view details)

Uploaded Source

Built Distribution


nlsq-0.4.0-py3-none-any.whl (347.2 kB view details)

Uploaded Python 3

File details

Details for the file nlsq-0.4.0.tar.gz.

File metadata

  • Download URL: nlsq-0.4.0.tar.gz
  • Upload date:
  • Size: 1.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for nlsq-0.4.0.tar.gz
Algorithm Hash digest
SHA256 1f67dd4888392c27112e5a62bad7689116c6167c830307c97c336e12ddd9dd3a
MD5 9eb9c4acbb00af57259c05f6e488dadd
BLAKE2b-256 ee48b5928d513cdee43c9563f74ff96dff52786be19e76047a14b835ac76e08a


Provenance

The following attestation bundles were made for nlsq-0.4.0.tar.gz:

Publisher: release.yml on imewei/NLSQ

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file nlsq-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: nlsq-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 347.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for nlsq-0.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8f7c0144f54ebb86854e9ac3ece50205a9b104b26a97067b7b280f3ca94a2d5b
MD5 d2e48d3b3e666151b54ad041e5fbe733
BLAKE2b-256 2dbd39c46fa72fcd09c591b2dbd0cda206528beb518537d4514a736901c4fb4c


Provenance

The following attestation bundles were made for nlsq-0.4.0-py3-none-any.whl:

Publisher: release.yml on imewei/NLSQ

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
