GPU/TPU-accelerated nonlinear least-squares curve fitting using JAX
NLSQ: GPU-Accelerated Curve Fitting
Drop-in replacement for scipy.optimize.curve_fit with speedups of up to 270x on GPU for large datasets
What is NLSQ?
NLSQ is a nonlinear least squares curve fitting library built on JAX. It provides:
- SciPy-compatible API - Same function signatures as scipy.optimize.curve_fit
- GPU/TPU acceleration - JIT-compiled kernels via XLA
- Automatic differentiation - No manual Jacobian calculations needed
- Large dataset support - Handles 100M+ data points with streaming optimization
- Interactive GUI - Streamlit-based graphical interface for no-code curve fitting
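Automatic differentiation is what removes hand-written Jacobians. The sketch below uses plain JAX (`jax.jacfwd`), not NLSQ's internals, to differentiate the Quick Start's exponential model with respect to its parameters and compares the result with the analytic derivatives. The wrapper name `model_from_params` is hypothetical.

```python
import jax
import jax.numpy as jnp

def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c

def model_from_params(params, x):
    # Pack the scalar parameters into one vector so JAX can differentiate w.r.t. it
    a, b, c = params
    return exponential(x, a, b, c)

x = jnp.linspace(0.0, 4.0, 5)
params = jnp.array([2.5, 0.5, 1.0])

# Jacobian of the model output (shape (5,)) w.r.t. params (shape (3,)) -> shape (5, 3)
jac = jax.jacfwd(model_from_params)(params, x)

# Analytic columns: d/da = exp(-b x), d/db = -a x exp(-b x), d/dc = 1
a, b, c = params
analytic = jnp.stack(
    [jnp.exp(-b * x), -a * x * jnp.exp(-b * x), jnp.ones_like(x)], axis=1
)
print("max abs difference:", jnp.max(jnp.abs(jac - analytic)))
```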
Installation
Basic Installation (CPU-only)
pip install nlsq
This installs with CPU-only JAX (works on all platforms: Linux, macOS, Windows).
GPU Installation (Linux + System CUDA)
Performance Impact: 20-100x speedup for large datasets (>1M points)
Prerequisites:
- NVIDIA GPU with SM >= 5.2 (Maxwell or newer)
- System CUDA 12.x or 13.x installed
- nvcc in PATH
Verify Prerequisites
# Check CUDA installation
nvcc --version
# Should show: Cuda compilation tools, release 12.x or 13.x
# Check GPU
nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader
# Should show: GPU name and SM version (e.g., "8.9" for RTX 4090)
Option 1: Quick Install via Makefile (Recommended)
git clone https://github.com/imewei/NLSQ.git
cd NLSQ
# Auto-detect system CUDA version and install matching JAX
make install-jax-gpu
# Or explicitly choose CUDA version:
make install-jax-gpu-cuda13 # Requires system CUDA 13.x + SM >= 7.5
make install-jax-gpu-cuda12 # Requires system CUDA 12.x + SM >= 5.2
Option 2: Manual Installation
# For System CUDA 13.x (Turing and newer GPUs):
pip uninstall -y jax jaxlib
pip install "jax[cuda13-local]"
# For System CUDA 12.x (Maxwell and newer GPUs):
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]"
# Verify
python -c "import jax; print('Backend:', jax.default_backend())"
# Should show: Backend: gpu
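The one-liner above prints only the backend name. From a script or notebook, a slightly fuller check lists every device JAX can see. This sketch uses only `jax.default_backend()` and `jax.devices()` from JAX's public API; the warning message is illustrative.

```python
import jax

backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()          # devices of the default backend

print("Backend:", backend)
for d in devices:
    # platform is the backend name, device_kind the hardware model
    print(" ", d.platform, d.device_kind)

if backend != "gpu":
    print("JAX is not using a GPU; recheck the CUDA-enabled jax install above.")
```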
GPU Compatibility Guide
| GPU Generation | Example GPUs | SM Version | CUDA 13 | CUDA 12 |
|---|---|---|---|---|
| Blackwell | B100, B200 | 10.0 | Yes | Yes |
| Hopper | H100, H200 | 9.0 | Yes | Yes |
| Ada Lovelace | RTX 40xx, L40 | 8.9 | Yes | Yes |
| Ampere | RTX 30xx, A100 | 8.x | Yes | Yes |
| Turing | RTX 20xx, T4 | 7.5 | Yes | Yes |
| Volta | V100, Titan V | 7.0 | No | Yes |
| Pascal | GTX 10xx, P100 | 6.x | No | Yes |
| Maxwell | GTX 9xx, Titan X | 5.x | No | Yes |
| Kepler | GTX 7xx, K80 | 3.x | No | No |
Recommendation: for GPUs with SM >= 7.5 (RTX 20xx or newer), install CUDA 13 for best performance.
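The table reduces to two thresholds: SM >= 7.5 for CUDA 13 and SM >= 5.2 for CUDA 12 (the figure given under Prerequisites). The sketch below applies them to one line of the `nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader` output shown earlier. Both helper names are hypothetical, and this is not the Makefile's actual detection logic.

```python
def parse_compute_cap(csv_line: str) -> tuple[int, int]:
    """Parse 'GPU name, X.Y' into (major, minor)."""
    # Split from the right so a comma inside the GPU name cannot break parsing
    _, cap = csv_line.rsplit(",", 1)
    major, minor = cap.strip().split(".")
    return int(major), int(minor)

def jax_cuda_extras(sm: tuple[int, int]) -> list[str]:
    """Compatible JAX extras per the table above, preferred option first."""
    extras = []
    if sm >= (7, 5):
        extras.append("jax[cuda13-local]")  # Turing and newer
    if sm >= (5, 2):
        extras.append("jax[cuda12-local]")  # Maxwell (SM 5.2) and newer
    return extras  # empty: GPU too old (e.g. Kepler), stay on CPU-only JAX

print(jax_cuda_extras(parse_compute_cap("NVIDIA GeForce RTX 4090, 8.9")))
# ['jax[cuda13-local]', 'jax[cuda12-local]']
```

Tuples compare element by element, which avoids the pitfalls of comparing versions as floats.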
GPU Troubleshooting
Issue: "nvcc not found"
# Option 1: Install CUDA toolkit
sudo apt install nvidia-cuda-toolkit # Ubuntu/Debian
# Option 2: Add existing CUDA to PATH
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
Issue: "CUDA version mismatch"
# Check your system CUDA version
nvcc --version
# Reinstall with correct package
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]" # or cuda13-local
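To pick the matching package after a mismatch, the CUDA major version can be read from the `nvcc --version` output. The helper names below are hypothetical, and the sample string is illustrative of the "Cuda compilation tools, release X.Y" line mentioned above.

```python
import re

def cuda_major_from_nvcc(output: str) -> int:
    """Extract the CUDA major version from `nvcc --version` output."""
    match = re.search(r"release (\d+)\.(\d+)", output)
    if match is None:
        raise ValueError("could not find 'release X.Y' in nvcc output")
    return int(match.group(1))

def jax_extra_for_cuda(major: int) -> str:
    """Map a system CUDA major version to the jax extra used in this guide."""
    if major not in (12, 13):
        raise ValueError(f"CUDA {major}.x is not covered here (12.x or 13.x only)")
    return f"jax[cuda{major}-local]"

# In practice: subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
sample = "Cuda compilation tools, release 12.6, V12.6.77"
print(jax_extra_for_cuda(cuda_major_from_nvcc(sample)))  # jax[cuda12-local]
```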
With GUI Support
pip install "nlsq[gui]"
Quick Start
import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit
# Define model function (use jax.numpy for GPU acceleration)
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c
# Generate data
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * np.random.randn(len(x))
# Fit - same API as scipy.optimize.curve_fit
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])
print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {np.sqrt(np.diag(pcov))}")
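Because the call signature mirrors SciPy, you can cross-check a fit by running the same problem through `scipy.optimize.curve_fit` and comparing `popt` and the standard errors. The sketch below covers only the SciPy side. It uses a NumPy version of the model (hypothetical name `exponential_np`) and a seeded generator so the data are reproducible.

```python
import numpy as np
from scipy.optimize import curve_fit as scipy_curve_fit

def exponential_np(x, a, b, c):
    # NumPy version of the model: SciPy evaluates it on the CPU
    return a * np.exp(-b * x) + c

rng = np.random.default_rng(0)
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * rng.standard_normal(len(x))

# Same call shape as the NLSQ example: (model, x, y, p0) -> (popt, pcov)
popt_ref, pcov_ref = scipy_curve_fit(exponential_np, x, y, p0=[2.0, 0.5, 1.0])
perr_ref = np.sqrt(np.diag(pcov_ref))  # one-sigma parameter uncertainties
print("SciPy popt:", popt_ref)
print("SciPy perr:", perr_ref)
```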
Graphical User Interface
NLSQ includes an interactive GUI for curve fitting without writing code:
# Launch the GUI
nlsq gui
The GUI provides:
- Data Loading: Import CSV, ASCII, NPZ, or HDF5 files, or paste from clipboard
- Model Selection: Choose from 7 built-in models, polynomials, or custom Python functions
- Fitting Options: Guided presets (Fast/Robust/Quality) or advanced parameter control
- Interactive Results: Plotly-based visualizations with confidence bands and residuals
- Export: Session bundles (ZIP), JSON/CSV results, and reproducible Python code
See the GUI User Guide for detailed documentation.
Key Features
| Feature | Description |
|---|---|
| Automatic Jacobians | JAX's autodiff eliminates manual derivatives |
| Bounded optimization | Trust Region Reflective and Levenberg-Marquardt |
| Large datasets | Chunked and streaming optimizers for 100M+ points |
| Multi-start | Global optimization with LHS/Sobol sampling |
| Mixed precision | Automatic float32→float64 upgrade when needed |
| Workflow system | Auto-selects strategy based on dataset size |
| CLI interface | YAML-based workflows with nlsq fit and nlsq batch |
| Interactive GUI | No-code curve fitting with Streamlit interface |
Performance
NLSQ shows increasing speedups as dataset size grows:
| Dataset Size | SciPy (CPU) | NLSQ (GPU) | Speedup |
|---|---|---|---|
| 10K points | 15ms | 2ms | 7x |
| 100K points | 150ms | 3ms | 50x |
| 1M points | 1.5s | 8ms | 190x |
| 10M points | 15s | 55ms | 270x |
Note: First call includes JIT compilation (~500ms). Subsequent calls reuse compiled kernels.
See Performance Guide for benchmarks.
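The compile-once, reuse-later behaviour comes from `jax.jit`. The plain-JAX sketch below (not NLSQ's code) times a jitted residual function three times: the first call, a repeat call with identical shapes, and a call with a new input shape, which plain JAX recompiles. Whether NLSQ pads inputs to avoid such recompilation is not covered here. Timings depend on your hardware and will not match the ~500 ms figure exactly.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def residuals(params, x, y):
    a, b, c = params
    return a * jnp.exp(-b * x) + c - y

def timed_call(params, x, y):
    start = time.perf_counter()
    out = residuals(params, x, y).block_until_ready()  # wait for async dispatch
    return out, time.perf_counter() - start

params = jnp.array([2.5, 0.5, 1.0])
x = jnp.linspace(0.0, 4.0, 1000)
y = 2.5 * jnp.exp(-0.5 * x) + 1.0

r1, t_first = timed_call(params, x, y)   # trace + compile + run
r2, t_cached = timed_call(params, x, y)  # same shapes/dtypes: cached executable

x2 = jnp.linspace(0.0, 4.0, 2000)        # new shape triggers a fresh compilation
y2 = 2.5 * jnp.exp(-0.5 * x2) + 1.0
r3, t_new_shape = timed_call(params, x2, y2)

print(f"first: {t_first * 1e3:.1f} ms, cached: {t_cached * 1e3:.3f} ms, "
      f"new shape: {t_new_shape * 1e3:.1f} ms")
```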
Large Dataset Example
from nlsq import fit
import numpy as np
import jax.numpy as jnp
def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c
# 50 million points
x = np.linspace(0, 10, 50_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))
# Auto-selects optimal strategy (chunked/streaming) based on memory
popt, pcov = fit(model, x, y, p0=[2.5, 0.6, 0.2], show_progress=True)
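The principle behind chunked processing can be shown independently of NLSQ. For Gauss-Newton-type steps, the normal-equation terms J^T J and J^T r are sums over data points, so they can be accumulated chunk by chunk without ever holding the full Jacobian in memory. The plain-JAX sketch below illustrates this with hypothetical helpers (`chunk_normal_terms`, `accumulate_normal_equations`). NLSQ's chunked and streaming optimizers may work differently.

```python
import jax
import jax.numpy as jnp
import numpy as np

def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c

def residuals(params, x, y):
    return model(x, *params) - y

# Jacobian of the residual vector w.r.t. the 3 parameters, shape (len(x), 3)
jac_fn = jax.jit(jax.jacfwd(residuals))

@jax.jit
def chunk_normal_terms(params, x, y):
    J = jac_fn(params, x, y)
    r = residuals(params, x, y)
    return J.T @ J, J.T @ r  # shapes (3, 3) and (3,)

def accumulate_normal_equations(params, x, y, chunk_size):
    """Sum J^T J and J^T r over chunks so the full Jacobian is never stored."""
    JTJ = jnp.zeros((3, 3))
    JTr = jnp.zeros(3)
    for start in range(0, len(x), chunk_size):
        stop = start + chunk_size
        # A smaller final chunk would trigger one extra compilation
        A, g = chunk_normal_terms(params, x[start:stop], y[start:stop])
        JTJ, JTr = JTJ + A, JTr + g
    return JTJ, JTr

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 100_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + rng.normal(0, 0.05, len(x))
params = jnp.array([2.5, 0.6, 0.2])

JTJ, JTr = accumulate_normal_equations(params, x, y, chunk_size=25_000)
step = jnp.linalg.solve(JTJ, -JTr)  # one Gauss-Newton step from params
print("Gauss-Newton step:", step)
```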
Advanced Usage
Multi-start global optimization
from nlsq import curve_fit
popt, pcov = curve_fit(
model,
x,
y,
p0=[1, 1, 1],
bounds=([0, 0, 0], [10, 5, 10]),
multistart=True,
n_starts=20,
sampler="lhs",
)
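Latin hypercube sampling (`sampler="lhs"`) spreads start points through the box defined by `bounds`. The sketch below generates 20 such starts for the bounds above using SciPy's `scipy.stats.qmc` module. It illustrates the sampling idea only and is not NLSQ's internal sampler; `qmc.Sobol` would be the SciPy counterpart for Sobol sampling.

```python
import numpy as np
from scipy.stats import qmc

# Bounds from the multi-start example: ([0, 0, 0], [10, 5, 10])
lower = np.array([0.0, 0.0, 0.0])
upper = np.array([10.0, 5.0, 10.0])
n_starts = 20

# Latin hypercube: each parameter's range is split into n_starts equal strata,
# and every stratum receives exactly one sample.
unit = qmc.LatinHypercube(d=len(lower)).random(n=n_starts)  # shape (20, 3), in [0, 1)
starts = qmc.scale(unit, lower, upper)                      # rescale to the bounds

# Each row is one candidate p0; a multi-start driver would run a local fit
# from every row and keep the result with the lowest cost.
for p0_candidate in starts[:3]:
    print(np.round(p0_candidate, 3))
```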
Workflow presets
from nlsq import fit
# Presets: 'fast', 'robust', 'global', 'quality', 'memory_efficient'
popt, pcov = fit(model, x, y, preset="robust")
Numerical stability
from nlsq import curve_fit
# Auto-detect and fix numerical issues
popt, pcov = curve_fit(model, x, y, p0=[2.5, 0.6, 0.2], stability="auto")
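What `stability="auto"` checks internally is not described here. A common source of numerical trouble in least squares, independent of NLSQ, is an ill-conditioned Jacobian caused by poorly scaled inputs. The sketch below uses plain JAX and NumPy to measure the Jacobian's condition number for a hypothetical quadratic model, first with raw x values far from zero and then after centering and scaling x. It enables float64 globally so the raw case is measured meaningfully.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # global switch: use float64 arrays

def quadratic(x, params):
    a, b, c = params
    return a + b * x + c * x**2

def jacobian_condition(x, params):
    # Jacobian w.r.t. params (argument 1), shape (len(x), 3)
    J = jax.jacfwd(quadratic, argnums=1)(x, params)
    return float(np.linalg.cond(np.asarray(J)))

params = jnp.array([1.0, 0.5, 0.01])
x_raw = jnp.linspace(1000.0, 1010.0, 200)        # far from 0, narrow range
x_scaled = (x_raw - x_raw.mean()) / x_raw.std()  # centered, unit spread

cond_raw = jacobian_condition(x_raw, params)
cond_scaled = jacobian_condition(x_scaled, params)
print(f"raw cond:    {cond_raw:.3e}")
print(f"scaled cond: {cond_scaled:.3e}")
```

Rescaling changes the meaning of the fitted parameters, so results have to be mapped back to the original x units afterwards.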
Memory management
from nlsq import MemoryConfig, memory_context, CurveFit
config = MemoryConfig(memory_limit_gb=8.0, enable_mixed_precision_fallback=True)
with memory_context(config):
    cf = CurveFit()
    popt, pcov = cf.curve_fit(model, x, y, p0=[2.5, 0.6, 0.2])
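A rough size estimate shows why a memory limit matters. A dense Jacobian stores one value per (data point, parameter) pair. The hypothetical helpers below compute that footprint for the 50-million-point, three-parameter example and the largest chunk whose Jacobian fits a given budget. They treat 1 GB as 10^9 bytes, an assumption, since how `MemoryConfig` interprets `memory_limit_gb` is not specified here. Real usage also includes the data, residuals, and solver workspace, so actual limits are tighter.

```python
import numpy as np

def jacobian_bytes(n_points: int, n_params: int, dtype=np.float64) -> int:
    # One value per (data point, parameter) pair
    return n_points * n_params * np.dtype(dtype).itemsize

def max_chunk_rows(memory_limit_gb: float, n_params: int, dtype=np.float64) -> int:
    # Largest number of rows whose Jacobian alone fits the budget (1 GB = 1e9 bytes)
    return int(memory_limit_gb * 1e9) // (n_params * np.dtype(dtype).itemsize)

n, p = 50_000_000, 3
print(jacobian_bytes(n, p) / 1e9)              # 1.2  (GB, float64)
print(jacobian_bytes(n, p, np.float32) / 1e9)  # 0.6  (GB, float32)
print(max_chunk_rows(8.0, p))                  # 333333333
```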
Command-line interface
# Launch GUI
nlsq gui
# Single workflow
nlsq fit experiment.yaml
# Batch processing
nlsq batch configs/*.yaml --summary results.json
# System info
nlsq info
See CLI Reference for YAML configuration.
See Advanced Features for complete documentation.
Examples
Start with the Interactive Tutorial on Google Colab.
By topic:
- Getting Started - Basic usage and quickstart
- Core Tutorials - Large datasets, bounded optimization
- Advanced - GPU optimization, streaming, checkpointing
- Applications - Physics, chemistry, biology, engineering
See examples/README.md for the full index.
Requirements
- Python 3.12+
- JAX 0.8.0 (locked version)
- NumPy 2.0+
- SciPy 1.14.0+
GUI requirements: Streamlit 1.52+, Plotly 6.5+
GPU support (Linux only): CUDA 12.1-12.9 (NVIDIA driver >= 525) or CUDA 13.x (SM >= 7.5; see the GPU Compatibility Guide)
Citation
If you use NLSQ in your research, please cite:
@software{nlsq2025,
title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
year={2025},
url={https://github.com/imewei/NLSQ}
}
@article{jaxfit2022,
title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
journal={arXiv preprint arXiv:2208.12187},
year={2022}
}
Acknowledgments
NLSQ is an enhanced fork of JAXFit by Lucas R. Hofer, Milan Krstajic, and Robert P. Smith. We gratefully acknowledge their foundational work.
License
MIT License - see LICENSE for details.
File details
Details for the file nlsq-0.4.2.tar.gz.
File metadata
- Download URL: nlsq-0.4.2.tar.gz
- Upload date:
- Size: 2.3 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7d18f38e17cebe8c7948919b0fffc9131c65ec6d5765fcd83aae1de6ff9594a3 |
| MD5 | 244e768a645924f041aee12807245dbf |
| BLAKE2b-256 | 26682e3ae3c7bafc416303038061b86d6cc947a446879b164798210bb115d6a9 |
Provenance
The following attestation bundles were made for nlsq-0.4.2.tar.gz:
Publisher: release.yml on imewei/NLSQ
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nlsq-0.4.2.tar.gz
- Subject digest: 7d18f38e17cebe8c7948919b0fffc9131c65ec6d5765fcd83aae1de6ff9594a3
- Sigstore transparency entry: 782132756
- Sigstore integration time:
- Permalink: imewei/NLSQ@66a9dd294512aebba1daf323136e2b433b63e414
- Branch / Tag: refs/tags/v0.4.2
- Owner: https://github.com/imewei
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@66a9dd294512aebba1daf323136e2b433b63e414
- Trigger Event: push
File details
Details for the file nlsq-0.4.2-py3-none-any.whl.
File metadata
- Download URL: nlsq-0.4.2-py3-none-any.whl
- Upload date:
- Size: 571.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 21d8eca811b9b7b0fbc14130c866f3577837161f47e61ff3957dcb84e3179ff9 |
| MD5 | b0cff31e695a505787dafa37b51edb8f |
| BLAKE2b-256 | 85dd77843b21569884f86cb1ff55f08b237fcfbfb7f0b55a63201a8ec1ef619a |
Provenance
The following attestation bundles were made for nlsq-0.4.2-py3-none-any.whl:
Publisher: release.yml on imewei/NLSQ
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: nlsq-0.4.2-py3-none-any.whl
- Subject digest: 21d8eca811b9b7b0fbc14130c866f3577837161f47e61ff3957dcb84e3179ff9
- Sigstore transparency entry: 782132757
- Sigstore integration time:
- Permalink: imewei/NLSQ@66a9dd294512aebba1daf323136e2b433b63e414
- Branch / Tag: refs/tags/v0.4.2
- Owner: https://github.com/imewei
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@66a9dd294512aebba1daf323136e2b433b63e414
- Trigger Event: push