
GPU/TPU accelerated nonlinear least-squares curve fitting using JAX



NLSQ: GPU-Accelerated Curve Fitting

Drop-in replacement for scipy.optimize.curve_fit with 150-270x speedups on GPU for large datasets

Python 3.12+ | JAX 0.8.0 | License: MIT

Documentation | Examples | API Reference | ArXiv Paper


What is NLSQ?

NLSQ is a nonlinear least squares curve fitting library built on JAX. It provides:

  • SciPy-compatible API - Same function signatures as scipy.optimize.curve_fit
  • GPU/TPU acceleration - JIT-compiled kernels via XLA
  • Automatic differentiation - No manual Jacobian calculations needed
  • Large dataset support - Handles 100M+ data points with streaming optimization
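As a rough illustration of what automatic differentiation replaces, the sketch below uses plain JAX rather than NLSQ internals: `jax.jacfwd` builds the Jacobian of the exponential model from the Quick Start with respect to its parameters, and the result is checked against hand-derived partial derivatives. Enabling 64-bit mode is a choice made here so the comparison is tight; the `model_params` wrapper is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # float64 so the comparison is tight


def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c


def model_params(params, x):
    # Hypothetical wrapper: pack (a, b, c) into one vector so jacfwd
    # returns an (n_points, n_params) matrix
    a, b, c = params
    return exponential(x, a, b, c)


# Forward-mode autodiff with respect to the first argument (the parameters)
jacobian_fn = jax.jit(jax.jacfwd(model_params, argnums=0))

x = jnp.linspace(0.0, 4.0, 5)
params = jnp.array([2.5, 0.5, 1.0])
J_auto = jacobian_fn(params, x)

# Hand-derived partial derivatives: d/da, d/db, d/dc
a, b, c = params
J_manual = jnp.stack(
    [jnp.exp(-b * x), -a * x * jnp.exp(-b * x), jnp.ones_like(x)], axis=1
)

print(J_auto.shape)                   # (5, 3)
print(np.allclose(J_auto, J_manual))  # True
```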

Installation

# CPU (all platforms)
pip install nlsq

# GPU (Linux with CUDA 12.1+)
pip install nlsq
pip install "jax[cuda12-local]==0.8.0"

Verify GPU installation:
python -c "import jax; print('Devices:', jax.devices())"
# Expected: [CudaDevice(id=0)] for GPU, [CpuDevice(id=0)] for CPU
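The same check can be run inside a Python session or notebook. This minimal sketch uses only standard JAX calls (`jax.default_backend()` and `jax.devices()`) and does not touch NLSQ itself.

```python
import jax

# Platform name of the default backend: "gpu" when a CUDA device is visible, else "cpu" or "tpu"
backend = jax.default_backend()
devices = jax.devices()

print(f"Default backend: {backend}")
print(f"Devices: {devices}")

if backend == "cpu":
    print("No accelerator detected: JAX computations, including NLSQ fits, will run on the CPU.")
```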

Quick Start

import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit


# Define model function (use jax.numpy for GPU acceleration)
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# Generate data
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * np.random.randn(len(x))

# Fit - same API as scipy.optimize.curve_fit
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])

print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {np.sqrt(np.diag(pcov))}")
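Because the API mirrors SciPy, the Quick Start can be cross-checked against `scipy.optimize.curve_fit` directly. In this sketch the only change to the model is NumPy in place of `jax.numpy`; the seeded random generator is an addition for reproducibility.

```python
import numpy as np
from scipy.optimize import curve_fit


# With SciPy, the model uses NumPy instead of jax.numpy
def exponential(x, a, b, c):
    return a * np.exp(-b * x) + c


rng = np.random.default_rng(0)  # seeded so the fit is reproducible
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * rng.standard_normal(len(x))

popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])
perr = np.sqrt(np.diag(pcov))  # 1-sigma parameter uncertainties

print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {perr}")
```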

Key Features

  • Automatic Jacobians - JAX's autodiff eliminates manual derivatives
  • Optimization methods - Trust Region Reflective (supports bounds) and Levenberg-Marquardt
  • Large datasets - Chunked and streaming optimizers for 100M+ points
  • Multi-start - Global optimization with LHS/Sobol sampling
  • Mixed precision - Automatic float32→float64 upgrade when needed
  • Workflow system - Auto-selects a strategy based on dataset size
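For intuition about bounded fitting, the sketch below uses SciPy's own Trust Region Reflective solver (`scipy.optimize.least_squares` with `method="trf"`) rather than NLSQ. The data, bounds, and starting point are illustrative assumptions; the upper bound on `b` is deliberately set below its true value so the constraint becomes active.

```python
import numpy as np
from scipy.optimize import least_squares

rng = np.random.default_rng(1)
x = np.linspace(0, 4, 500)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.05 * rng.standard_normal(x.size)


def residuals(params, x, y):
    # Model minus data for a * exp(-b * x) + c
    a, b, c = params
    return a * np.exp(-b * x) + c - y


# Cap b at 0.4 (true value 0.5) so the solution is pushed onto the bound
lower = [0.0, 0.0, 0.0]
upper = [10.0, 0.4, 10.0]
result = least_squares(
    residuals, x0=[2.0, 0.3, 1.0], bounds=(lower, upper), method="trf", args=(x, y)
)

print(result.x)  # fitted a, b, c; b ends up just below 0.4
```

TRF keeps every iterate strictly feasible, so the reported `b` approaches the 0.4 bound from below instead of overshooting it.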

Performance

NLSQ shows increasing speedups as dataset size grows:

Dataset size   SciPy (CPU)   NLSQ (GPU)   Speedup
10K points     15 ms         2 ms         7x
100K points    150 ms        3 ms         50x
1M points      1.5 s         8 ms         190x
10M points     15 s          55 ms        270x

Note: The first call includes JIT compilation (~500ms). Subsequent calls with the same input shapes and dtypes reuse the compiled kernels.

See Performance Guide for benchmarks.
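The compile-once, reuse-afterwards behavior described in the note above is standard `jax.jit` caching. The sketch below uses plain JAX, with a made-up sum-of-squares objective standing in for NLSQ's kernels, and times a first and a second call; absolute timings depend entirely on your hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def sum_sq_residuals(params, x, y):
    # Illustrative objective: sum of squared residuals of a * exp(-b * x) + c
    a, b, c = params
    r = a * jnp.exp(-b * x) + c - y
    return jnp.sum(r**2)


x = jnp.linspace(0.0, 4.0, 100_000)
y = 2.5 * jnp.exp(-0.5 * x) + 1.0
params = jnp.array([2.0, 0.5, 1.0])


def timed_call():
    start = time.perf_counter()
    # block_until_ready waits for JAX's asynchronous dispatch to finish
    value = sum_sq_residuals(params, x, y).block_until_ready()
    return value, time.perf_counter() - start


first_value, first_time = timed_call()    # traces and compiles, then runs
second_value, second_time = timed_call()  # reuses the cached executable
print(f"first call:  {first_time * 1e3:.1f} ms")
print(f"second call: {second_time * 1e3:.1f} ms")
```

Calling the jitted function with arrays of a different shape or dtype triggers a fresh compilation, so the warm-up cost is paid once per distinct input signature.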

Large Dataset Example

from nlsq import fit
import numpy as np
import jax.numpy as jnp


def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# 50 million points
x = np.linspace(0, 10, 50_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))

# Auto-selects optimal strategy (chunked/streaming) based on memory
popt, pcov = fit(model, x, y, p0=[2.5, 0.6, 0.2], show_progress=True)
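One way to fit data that is too large to process at once is to accumulate the Gauss-Newton normal equations chunk by chunk, since both J^T J and J^T r are sums over data points. The NumPy sketch below illustrates that general idea for the same exponential model with an analytic Jacobian; it is a simplified assumption about how chunking can work, not NLSQ's actual chunked or streaming optimizer, and `chunked_gauss_newton_step` is a hypothetical helper. It uses 1 million points to stay light.

```python
import numpy as np


def jacobian(x, a, b, c):
    # Analytic Jacobian of a * exp(-b * x) + c with respect to (a, b, c)
    e = np.exp(-b * x)
    return np.column_stack([e, -a * x * e, np.ones_like(x)])


def residual(x, y, a, b, c):
    return a * np.exp(-b * x) + c - y


def chunked_gauss_newton_step(x, y, params, chunk_size):
    """Hypothetical helper: one Gauss-Newton step built from data chunks."""
    n_params = len(params)
    JtJ = np.zeros((n_params, n_params))
    Jtr = np.zeros(n_params)
    for start in range(0, len(x), chunk_size):
        xs = x[start:start + chunk_size]
        ys = y[start:start + chunk_size]
        J = jacobian(xs, *params)
        r = residual(xs, ys, *params)
        JtJ += J.T @ J  # sums over points, so per-chunk terms add up exactly
        Jtr += J.T @ r
    # Solve (J^T J) delta = -J^T r for the parameter update
    return np.linalg.solve(JtJ, -Jtr)


rng = np.random.default_rng(0)
x = np.linspace(0, 10, 1_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + rng.normal(0, 0.05, x.size)
params = np.array([2.5, 0.6, 0.2])

for _ in range(10):
    params = params + chunked_gauss_newton_step(x, y, params, chunk_size=100_000)
print(params)  # close to the true values (2.0, 0.5, 0.3)
```

Only one chunk's Jacobian (here 100,000 x 3) is held in memory at a time, while the accumulated 3 x 3 system is identical to the one built from the full dataset.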

Advanced Usage

Multi-start global optimization
from nlsq import curve_fit

popt, pcov = curve_fit(
    model,
    x,
    y,
    p0=[1, 1, 1],
    bounds=([0, 0, 0], [10, 5, 10]),
    multistart=True,
    n_starts=20,
    sampler="lhs",
)
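The idea behind `multistart=True` with `sampler="lhs"` can be sketched with SciPy alone: draw starting points from a Latin hypercube inside the bounds, fit from each one, and keep the lowest-cost result. This is an assumption about the general technique, not NLSQ's implementation; `multistart_fit` is a hypothetical helper, and the model is rewritten with NumPy so SciPy can evaluate it.

```python
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import qmc


def model(x, a, b, c):
    return a * np.exp(-b * x) + c


def multistart_fit(f, x, y, lower, upper, n_starts=20):
    """Hypothetical helper: fit from LHS-sampled starts and keep the best."""
    sampler = qmc.LatinHypercube(d=len(lower))
    # Map unit-cube samples into the parameter bounds
    starts = qmc.scale(sampler.random(n_starts), lower, upper)
    best = None
    for p0 in starts:
        try:
            popt, pcov = curve_fit(f, x, y, p0=p0, bounds=(lower, upper))
        except RuntimeError:
            continue  # this start hit the evaluation limit; try the next one
        cost = np.sum((f(x, *popt) - y) ** 2)
        if best is None or cost < best[0]:
            best = (cost, popt, pcov)
    if best is None:
        raise RuntimeError("no start converged")
    return best[1], best[2]


rng = np.random.default_rng(0)
x = np.linspace(0, 10, 2000)
y = model(x, 2.0, 0.5, 0.3) + rng.normal(0, 0.05, x.size)

popt, pcov = multistart_fit(model, x, y, lower=[0, 0, 0], upper=[10, 5, 10])
print(popt)
```

Latin hypercube sampling places exactly one start in each of the `n_starts` equal-width slices of every parameter's range, so the starts cover the bounded box more evenly than independent uniform draws.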

Workflow presets
from nlsq import fit

# Presets: 'fast', 'robust', 'global', 'quality', 'memory_efficient'
popt, pcov = fit(model, x, y, preset="robust")

Numerical stability
from nlsq import curve_fit

p0 = [2.5, 0.6, 0.2]  # initial guess, as in the large dataset example

# Auto-detect and fix numerical issues
popt, pcov = curve_fit(model, x, y, p0=p0, stability="auto")

Memory management
from nlsq import MemoryConfig, memory_context, CurveFit

p0 = [2.5, 0.6, 0.2]  # initial guess, as in the large dataset example
config = MemoryConfig(memory_limit_gb=8.0, enable_mixed_precision_fallback=True)

with memory_context(config):
    cf = CurveFit()
    popt, pcov = cf.curve_fit(model, x, y, p0=p0)
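A back-of-the-envelope calculation shows why a memory limit matters: a dense float64 Jacobian needs n_points x n_params x 8 bytes. The sketch below applies that arithmetic to the 50-million-point example and derives the largest chunk whose Jacobian fits in half of an 8 GB budget. The `jacobian_bytes` and `max_chunk_points` helpers, the 50% safety factor, and treating 1 GB as 1024**3 bytes are all assumptions; NLSQ's own memory estimator likely also accounts for residuals and solver buffers.

```python
def jacobian_bytes(n_points, n_params, bytes_per_value=8):
    """Memory for a dense Jacobian (float64 = 8 bytes, float32 = 4 bytes)."""
    return n_points * n_params * bytes_per_value


def max_chunk_points(memory_limit_gb, n_params, bytes_per_value=8, safety=0.5):
    """Largest chunk whose Jacobian uses at most `safety` of the budget."""
    budget = memory_limit_gb * 1024**3 * safety
    return int(budget // (n_params * bytes_per_value))


full = jacobian_bytes(50_000_000, 3)
half = jacobian_bytes(50_000_000, 3, bytes_per_value=4)

print(f"float64 Jacobian: {full / 1024**3:.2f} GiB")  # 1.12 GiB
print(f"float32 Jacobian: {half / 1024**3:.2f} GiB")  # 0.56 GiB
print(f"Chunk size under an 8 GB limit: {max_chunk_points(8.0, 3):,} points")  # 178,956,970 points
```

Halving the bytes per value is also why a float32 path is attractive for very large fits, at the cost of precision.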

See Advanced Features for complete documentation.

Examples

Start with the Interactive Tutorial on Google Colab.

See examples/README.md for the full index.

Requirements

  • Python 3.12+
  • JAX 0.8.0 (pinned version)
  • NumPy 2.0+
  • SciPy 1.14.0+

GPU support (Linux only): CUDA 12.1-12.9, NVIDIA driver >= 525

Citation

If you use NLSQ in your research, please cite:

@software{nlsq2024,
  title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
  author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  year={2024},
  url={https://github.com/imewei/NLSQ}
}

@article{jaxfit2022,
  title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
  author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  journal={arXiv preprint arXiv:2208.12187},
  year={2022}
}

Acknowledgments

NLSQ is an enhanced fork of JAXFit by Lucas R. Hofer, Milan Krstajić, and Robert P. Smith. We gratefully acknowledge their foundational work.

License

MIT License - see LICENSE for details.

Download files

Source Distribution

nlsq-0.3.9.tar.gz (1.4 MB)

Uploaded Source

Built Distribution

nlsq-0.3.9-py3-none-any.whl (347.1 kB)

Uploaded Python 3

File details

Details for the file nlsq-0.3.9.tar.gz.

File metadata

  • Download URL: nlsq-0.3.9.tar.gz
  • Size: 1.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for nlsq-0.3.9.tar.gz
Algorithm Hash digest
SHA256 77e15710aab85d271fac284303554f253073ba86bf3d841f7cd65d57968829d5
MD5 2557b1f1ae3841acb49f20a0faa561db
BLAKE2b-256 4f5a9b1b9034cea26c45bbf1ab8f8b1e55917d93e59609aadaec8690147c0455

Provenance

The following attestation bundles were made for nlsq-0.3.9.tar.gz:

Publisher: release.yml on imewei/NLSQ

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file nlsq-0.3.9-py3-none-any.whl.

File metadata

  • Download URL: nlsq-0.3.9-py3-none-any.whl
  • Size: 347.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for nlsq-0.3.9-py3-none-any.whl
Algorithm Hash digest
SHA256 f03bae84b57d4b74e6fb0aaf682dd187e3abadebfb7c283677b1949acfdf9596
MD5 2e0eea2002ca5ee9377e62ee9d69c735
BLAKE2b-256 e0b17029ea0fd7253658a2f6201244f3aa833d47caaa140a853a3b06152c8bf4

Provenance

The following attestation bundles were made for nlsq-0.3.9-py3-none-any.whl:

Publisher: release.yml on imewei/NLSQ

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
