# NLSQ: GPU-Accelerated Curve Fitting

GPU/TPU-accelerated nonlinear least-squares curve fitting using JAX. A drop-in replacement for `scipy.optimize.curve_fit`, with speedups of up to 270x on GPU for large datasets.
## What is NLSQ?

NLSQ is a nonlinear least-squares curve fitting library built on JAX. It provides:

- **SciPy-compatible API** - the same function signatures as `scipy.optimize.curve_fit`
- **GPU/TPU acceleration** - JIT-compiled kernels via XLA
- **Automatic differentiation** - no manual Jacobian calculations needed
- **Large dataset support** - handles 100M+ data points with streaming optimization
## Installation

```bash
# CPU (all platforms)
pip install nlsq

# GPU (Linux with CUDA 12.1+)
pip install nlsq
pip install "jax[cuda12-local]==0.8.0"
```

Verify the GPU installation:

```bash
python -c "import jax; print('Devices:', jax.devices())"
# Expected: [cuda(id=0)] for GPU, [CpuDevice(id=0)] for CPU
```
## Quick Start

```python
import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit

# Define the model function (use jax.numpy for GPU acceleration)
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c

# Generate noisy data
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * np.random.randn(len(x))

# Fit - same API as scipy.optimize.curve_fit
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])

print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {np.sqrt(np.diag(pcov))}")
```
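To see the "drop-in" claim concretely, here is the same fit done with SciPy itself; a sketch assuming only NumPy and SciPy, with the call signature unchanged from the NLSQ version above:

```python
import numpy as np
from scipy.optimize import curve_fit  # swap for `from nlsq import curve_fit`

def exponential(x, a, b, c):
    # Plain-NumPy model; the NLSQ version uses jax.numpy instead
    return a * np.exp(-b * x) + c

rng = np.random.default_rng(0)
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * rng.standard_normal(len(x))

# Identical signature: initial guess p0, fitted params popt, covariance pcov
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])
perr = np.sqrt(np.diag(pcov))  # 1-sigma parameter uncertainties
```

Changing the import line is the only edit needed to move this fit onto the GPU.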
## Key Features
| Feature | Description |
|---|---|
| Automatic Jacobians | JAX's autodiff eliminates manual derivatives |
| Bounded optimization | Trust Region Reflective and Levenberg-Marquardt |
| Large datasets | Chunked and streaming optimizers for 100M+ points |
| Multi-start | Global optimization with LHS/Sobol sampling |
| Mixed precision | Automatic float32→float64 upgrade when needed |
| Workflow system | Auto-selects strategy based on dataset size |
| CLI interface | YAML-based workflows with nlsq fit and nlsq batch |
## Performance

NLSQ's speedup over SciPy grows with dataset size:
| Dataset Size | SciPy (CPU) | NLSQ (GPU) | Speedup |
|---|---|---|---|
| 10K points | 15ms | 2ms | 7x |
| 100K points | 150ms | 3ms | 50x |
| 1M points | 1.5s | 8ms | 190x |
| 10M points | 15s | 55ms | 270x |
Note: First call includes JIT compilation (~500ms). Subsequent calls reuse compiled kernels.
See Performance Guide for benchmarks.
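The JIT warm-up cost mentioned above can be observed directly with `jax.jit`; a small timing sketch (plain JAX, not NLSQ-specific) showing that the first call pays for tracing and compilation while later calls reuse the cached kernel:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def residuals(params, x, y):
    a, b, c = params
    return a * jnp.exp(-b * x) + c - y

x = jnp.linspace(0.0, 4.0, 100_000)
y = 2.5 * jnp.exp(-0.5 * x) + 1.0
p = jnp.array([2.0, 0.5, 1.0])

t0 = time.perf_counter()
residuals(p, x, y).block_until_ready()  # first call: trace + compile
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
residuals(p, x, y).block_until_ready()  # subsequent call: cached kernel
t_cached = time.perf_counter() - t0
```

The same caching applies to NLSQ fits: re-fitting with the same model and array shapes skips compilation entirely.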
## Large Dataset Example

```python
from nlsq import fit
import numpy as np
import jax.numpy as jnp

def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c

# 50 million points
x = np.linspace(0, 10, 50_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))

# fit() auto-selects the optimal strategy (chunked/streaming) based on memory
popt, pcov = fit(model, x, y, p0=[2.5, 0.6, 0.2], show_progress=True)
```
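The chunked idea can be sketched for a single Gauss-Newton step: accumulate the normal equations `J^T J` and `J^T r` chunk by chunk, so the full Jacobian never has to fit in memory at once. This is a conceptual NumPy sketch under that assumption, not NLSQ's implementation:

```python
import numpy as np

def model(x, a, b, c):
    return a * np.exp(-b * x) + c

def jac(x, a, b, c):
    # Analytic Jacobian columns: d/da, d/db, d/dc
    e = np.exp(-b * x)
    return np.stack([e, -a * x * e, np.ones_like(x)], axis=1)

x = np.linspace(0, 10, 1_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3   # noiseless data, true params (2.0, 0.5, 0.3)
p = np.array([2.5, 0.6, 0.2])      # initial guess

# Accumulate J^T J (3x3) and J^T r (3,) over 100k-point chunks
JTJ = np.zeros((3, 3))
JTr = np.zeros(3)
for i in range(0, len(x), 100_000):
    xs, ys = x[i:i + 100_000], y[i:i + 100_000]
    J = jac(xs, *p)
    r = model(xs, *p) - ys
    JTJ += J.T @ J
    JTr += J.T @ r

# One Gauss-Newton update using only the small accumulated matrices
p_new = p - np.linalg.solve(JTJ, JTr)
```

The peak memory here is one chunk's Jacobian (100k x 3) rather than the full 1M x 3 matrix, which is what makes 100M+ point fits feasible.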
## Advanced Usage

### Multi-start global optimization

```python
from nlsq import curve_fit

popt, pcov = curve_fit(
    model,
    x,
    y,
    p0=[1, 1, 1],
    bounds=([0, 0, 0], [10, 5, 10]),
    multistart=True,
    n_starts=20,
    sampler="lhs",
)
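What `sampler="lhs"` does can be sketched with SciPy's quasi-Monte Carlo module: Latin hypercube sampling spreads the start points evenly across the bounded parameter box instead of drawing them independently at random. A sketch of the sampling step only (using `scipy.stats.qmc`, not NLSQ internals):

```python
import numpy as np
from scipy.stats import qmc

lower = np.array([0.0, 0.0, 0.0])
upper = np.array([10.0, 5.0, 10.0])

# 20 Latin hypercube samples in [0, 1)^3, scaled to the parameter bounds;
# each of the 20 strata per dimension contains exactly one start point
sampler = qmc.LatinHypercube(d=3, seed=0)
starts = qmc.scale(sampler.random(n=20), lower, upper)  # shape (20, 3)
```

Each row of `starts` would then seed an independent local fit, and the best result is kept.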
### Workflow presets

```python
from nlsq import fit

# Presets: 'fast', 'robust', 'global', 'quality', 'memory_efficient'
popt, pcov = fit(model, x, y, preset="robust")
```
### Numerical stability

```python
from nlsq import curve_fit

# Auto-detect and fix numerical issues
popt, pcov = curve_fit(model, x, y, p0=p0, stability="auto")
```
### Memory management

```python
from nlsq import MemoryConfig, memory_context, CurveFit

config = MemoryConfig(memory_limit_gb=8.0, enable_mixed_precision_fallback=True)

with memory_context(config):
    cf = CurveFit()
    popt, pcov = cf.curve_fit(model, x, y, p0=p0)
```
### Command-line interface

```bash
# Single workflow
nlsq fit experiment.yaml

# Batch processing
nlsq batch configs/*.yaml --summary results.json

# System info
nlsq info
```

See the CLI Reference for YAML configuration.
See Advanced Features for complete documentation.
## Examples

Start with the Interactive Tutorial on Google Colab.

By topic:

- Getting Started - basic usage and quickstart
- Core Tutorials - large datasets, bounded optimization
- Advanced - GPU optimization, streaming, checkpointing
- Applications - physics, chemistry, biology, engineering

See examples/README.md for the full index.
## Requirements

- Python 3.12+
- JAX 0.8.0 (locked version)
- NumPy 2.0+
- SciPy 1.14.0+

GPU support (Linux only): CUDA 12.1-12.9, NVIDIA driver >= 525
## Citation

If you use NLSQ in your research, please cite:

```bibtex
@software{nlsq2024,
  title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
  author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  year={2024},
  url={https://github.com/imewei/NLSQ}
}

@article{jaxfit2022,
  title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
  author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  journal={arXiv preprint arXiv:2208.12187},
  year={2022}
}
```
## Acknowledgments

NLSQ is an enhanced fork of JAXFit by Lucas R. Hofer, Milan Krstajić, and Robert P. Smith. We gratefully acknowledge their foundational work.

## License

MIT License - see LICENSE for details.