
GPU/TPU accelerated nonlinear least-squares curve fitting using JAX



NLSQ: GPU-Accelerated Curve Fitting

Drop-in replacement for scipy.optimize.curve_fit with up to 270x speedup on GPU for large datasets


Documentation | Examples | API Reference | ArXiv Paper


What is NLSQ?

NLSQ is a nonlinear least squares curve fitting library built on JAX. It provides:

  • SciPy-compatible API - Same function signatures as scipy.optimize.curve_fit
  • GPU/TPU acceleration - JIT-compiled kernels via XLA
  • Automatic differentiation - No manual Jacobian calculations needed
  • Large dataset support - Handles 100M+ data points with streaming optimization

Installation

# CPU (all platforms)
pip install nlsq

# GPU (Linux with CUDA 12.1+)
pip install nlsq
pip install "jax[cuda12-local]==0.8.0"

# Verify the installation
python -c "import jax; print('Devices:', jax.devices())"
# Expected: [cuda(id=0)] for GPU, [CpuDevice(id=0)] for CPU

Quick Start

import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit


# Define model function (use jax.numpy for GPU acceleration)
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# Generate data
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * np.random.randn(len(x))

# Fit - same API as scipy.optimize.curve_fit
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])

print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {np.sqrt(np.diag(pcov))}")

Key Features

| Feature | Description |
| --- | --- |
| Automatic Jacobians | JAX's autodiff eliminates manual derivatives |
| Bounded optimization | Trust Region Reflective and Levenberg-Marquardt |
| Large datasets | Chunked and streaming optimizers for 100M+ points |
| Multi-start | Global optimization with LHS/Sobol sampling |
| Mixed precision | Automatic float32 → float64 upgrade when needed |
| Workflow system | Auto-selects strategy based on dataset size |
| CLI interface | YAML-based workflows with `nlsq fit` and `nlsq batch` |

Performance

NLSQ shows increasing speedups as dataset size grows:

| Dataset Size | SciPy (CPU) | NLSQ (GPU) | Speedup |
| --- | --- | --- | --- |
| 10K points | 15 ms | 2 ms | 7x |
| 100K points | 150 ms | 3 ms | 50x |
| 1M points | 1.5 s | 8 ms | 190x |
| 10M points | 15 s | 55 ms | 270x |

Note: First call includes JIT compilation (~500ms). Subsequent calls reuse compiled kernels.

See Performance Guide for benchmarks.

Large Dataset Example

from nlsq import fit
import numpy as np
import jax.numpy as jnp


def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# 50 million points
x = np.linspace(0, 10, 50_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))

# Auto-selects optimal strategy (chunked/streaming) based on memory
popt, pcov = fit(model, x, y, p0=[2.5, 0.6, 0.2], show_progress=True)

Advanced Usage

Multi-start global optimization
from nlsq import curve_fit

popt, pcov = curve_fit(
    model,
    x,
    y,
    p0=[1, 1, 1],
    bounds=([0, 0, 0], [10, 5, 10]),
    multistart=True,
    n_starts=20,
    sampler="lhs",
)

Workflow presets
from nlsq import fit

# Presets: 'fast', 'robust', 'global', 'quality', 'memory_efficient'
popt, pcov = fit(model, x, y, preset="robust")

Numerical stability
from nlsq import curve_fit

# Auto-detect and fix numerical issues
popt, pcov = curve_fit(model, x, y, p0=p0, stability="auto")

Memory management
from nlsq import MemoryConfig, memory_context, CurveFit

config = MemoryConfig(memory_limit_gb=8.0, enable_mixed_precision_fallback=True)

with memory_context(config):
    cf = CurveFit()
    popt, pcov = cf.curve_fit(model, x, y, p0=p0)

Command-line interface
# Single workflow
nlsq fit experiment.yaml

# Batch processing
nlsq batch configs/*.yaml --summary results.json

# System info
nlsq info

See CLI Reference for YAML configuration.

See Advanced Features for complete documentation.

Examples

Start with the Interactive Tutorial on Google Colab.

Examples are organized by topic; see examples/README.md for the full index.

Requirements

  • Python 3.12+
  • JAX 0.8.0 (locked version)
  • NumPy 2.0+
  • SciPy 1.14.0+

GPU support (Linux only): CUDA 12.1-12.9, NVIDIA driver >= 525

Citation

If you use NLSQ in your research, please cite:

@software{nlsq2024,
  title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
  author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  year={2024},
  url={https://github.com/imewei/NLSQ}
}

@article{jaxfit2022,
  title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
  author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  journal={arXiv preprint arXiv:2208.12187},
  year={2022}
}

Acknowledgments

NLSQ is an enhanced fork of JAXFit by Lucas R. Hofer, Milan Krstajić, and Robert P. Smith. We gratefully acknowledge their foundational work.

License

MIT License - see LICENSE for details.
