GPU/TPU accelerated nonlinear least-squares curve fitting using JAX
NLSQ: GPU-Accelerated Curve Fitting
Drop-in replacement for scipy.optimize.curve_fit with up to 270x speedup on GPU for large datasets
What is NLSQ?
NLSQ is a nonlinear least squares curve fitting library built on JAX. It provides:
- SciPy-compatible API - same function signatures as scipy.optimize.curve_fit
- GPU/TPU acceleration - JIT-compiled kernels via XLA
- Automatic differentiation - no manual Jacobian calculations needed
- Large dataset support - handles 100M+ data points with streaming optimization
- Interactive GUI - native Qt desktop application for no-code curve fitting
Installation
Basic Installation (CPU-only)
pip install nlsq
This installs with CPU-only JAX (works on all platforms: Linux, macOS, Windows).
GPU Installation (Linux + System CUDA)
Performance Impact: 20-100x speedup for large datasets (>1M points)
Prerequisites:
- NVIDIA GPU with SM >= 5.2 (Maxwell or newer)
- System CUDA 12.x or 13.x installed
- nvcc in PATH
Verify Prerequisites
# Check CUDA installation
nvcc --version
# Should show: Cuda compilation tools, release 12.x or 13.x
# Check GPU
nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader
# Should show: GPU name and SM version (e.g., "8.9" for RTX 4090)
Option 1: Quick Install via Makefile (Recommended)
git clone https://github.com/imewei/NLSQ.git
cd NLSQ
# Auto-detect system CUDA version and install matching JAX
make install-jax-gpu
# Or explicitly choose CUDA version:
make install-jax-gpu-cuda13 # Requires system CUDA 13.x + SM >= 7.5
make install-jax-gpu-cuda12 # Requires system CUDA 12.x + SM >= 5.2
Option 2: Manual Installation
# For System CUDA 13.x (Turing and newer GPUs):
pip uninstall -y jax jaxlib
pip install "jax[cuda13-local]"
# For System CUDA 12.x (Maxwell and newer GPUs):
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]"
# Verify
python -c "import jax; print('Backend:', jax.default_backend())"
# Should show: Backend: gpu
GPU Compatibility Guide
| GPU Generation | Example GPUs | SM Version | CUDA 13 | CUDA 12 |
|---|---|---|---|---|
| Blackwell | B100, B200 | 10.0 | Yes | Yes |
| Hopper | H100, H200 | 9.0 | Yes | Yes |
| Ada Lovelace | RTX 40xx, L40 | 8.9 | Yes | Yes |
| Ampere | RTX 30xx, A100 | 8.x | Yes | Yes |
| Turing | RTX 20xx, T4 | 7.5 | Yes | Yes |
| Volta | V100, Titan V | 7.0 | No | Yes |
| Pascal | GTX 10xx, P100 | 6.x | No | Yes |
| Maxwell | GTX 9xx, Titan X | 5.x | No | Yes |
| Kepler | GTX 7xx, K80 | 3.x | No | No |
Recommendation: for SM >= 7.5 (RTX 20xx or newer), install CUDA 13 for best performance.
GPU Troubleshooting
Issue: "nvcc not found"
# Option 1: Install CUDA toolkit
sudo apt install nvidia-cuda-toolkit # Ubuntu/Debian
# Option 2: Add existing CUDA to PATH
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
Issue: "CUDA version mismatch"
# Check your system CUDA version
nvcc --version
# Reinstall with correct package
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]" # or cuda13-local
With GUI Support
GUI dependencies (PySide6, pyqtgraph) are included in the main package:
pip install nlsq # GUI included
nlsq-gui # Launch the desktop application
Quick Start
import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit
# Define model function (use jax.numpy for GPU acceleration)
def exponential(x, a, b, c):
return a * jnp.exp(-b * x) + c
# Generate data
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * np.random.randn(len(x))
# Fit - same API as scipy.optimize.curve_fit
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])
print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {np.sqrt(np.diag(pcov))}")
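Because the call signature matches SciPy's, an existing script can be ported by changing only the import. For reference, here is the same fit done with SciPy itself (plain NumPy in the model, since SciPy runs on CPU); swapping the import line back to nlsq is the only change needed for GPU execution:

```python
import numpy as np
from scipy.optimize import curve_fit  # swap for `from nlsq import curve_fit` to use NLSQ


def exponential(x, a, b, c):
    # Plain NumPy here; with NLSQ you would use jax.numpy instead
    return a * np.exp(-b * x) + c


rng = np.random.default_rng(0)
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * rng.standard_normal(len(x))

popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])
perr = np.sqrt(np.diag(pcov))  # 1-sigma parameter uncertainties
```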
Graphical User Interface
NLSQ includes an interactive GUI for curve fitting without writing code:
# Launch the GUI
nlsq-gui
The GUI provides:
- Data Loading: Import CSV, ASCII, NPZ, or HDF5 files, or paste from clipboard
- Model Selection: Choose from 8 built-in models, polynomials, or custom Python functions
- Fitting Options: Guided presets (Fast/Robust/Quality) or advanced parameter control
- Interactive Results: GPU-accelerated pyqtgraph plots with confidence bands and residuals
- Export: Session bundles (ZIP), JSON/CSV results, and reproducible Python code
See the GUI User Guide for detailed documentation.
Key Features
| Feature | Description |
|---|---|
| Automatic Jacobians | JAX's autodiff eliminates manual derivatives |
| Bounded optimization | Trust Region Reflective and Levenberg-Marquardt |
| Large datasets | Chunked and streaming optimizers for 100M+ points |
| Multi-start | Global optimization with LHS/Sobol sampling |
| Workflow system | 3 smart workflows: auto, auto_global, hpc with memory-aware strategy |
| CLI interface | YAML-based workflows with nlsq fit and nlsq batch |
| Interactive GUI | No-code curve fitting with Qt desktop application |
| Model Diagnostics | Identifiability analysis, gradient health monitoring, sloppy model detection |
Architecture
NLSQ is organized into well-separated layers (~74,000 lines):
┌─────────────────────────────────────────────────────────────────────────────┐
│ USER INTERFACES │
│ Qt GUI (PySide6) CLI (Click) Python API │
│ ├── 5-page workflow ├── Model validation ├── curve_fit(), fit() │
│ ├── pyqtgraph plots ├── Security auditing ├── CurveFit class │
│ └── Native desktop └── Export formats └── LargeDatasetFitter │
├─────────────────────────────────────────────────────────────────────────────┤
│ OPTIMIZATION ORCHESTRATION │
│ Workflow System Global Optimization Streaming Optimizer │
│ ├── MemoryBudgetSel. ├── MultiStartOrch. ├── AdaptiveHybrid │
│ ├── Strategy: STANDARD/ ├── TournamentSelect ├── 4-Phase Pipeline: │
│ │ CHUNKED/STREAMING ├── LHS/Sobol/Halton │ 0: Normalization │
│ └── Memory-based auto └── Sampling │ 1: L-BFGS warmup │
│ │ 2: Gauss-Newton │
│ Orchestration (v0.6.4) └── 3: Denormalization │
│ ├── DataPreprocessor ├── OptimizationSel. │
│ ├── CovarianceComputer └── StreamingCoord. │
├─────────────────────────────────────────────────────────────────────────────┤
│ CORE OPTIMIZATION ENGINE │
│ curve_fit() ─→ CurveFit ─→ LeastSquares ─→ TrustRegionReflective │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ API Wrapper Cache+State Orchestrator+AD SVD-based TRF │
│ (SciPy-compat) (UnifiedCache) (AutoDiffJac) (JIT-compiled) │
├─────────────────────────────────────────────────────────────────────────────┤
│ SUPPORT SUBSYSTEMS │
│ stability/ precision/ caching/ diagnostics/ │
│ ├── NumericalGuard ├── AlgorithmSel. ├── UnifiedCache ├── Identifiab. │
│ ├── SVD fallback ├── ParamNormalizer ├── SmartCache ├── GradientMon │
│ └── Recovery └── BoundsInfer. └── MemoryMgr └── PluginSys. │
│ facades/ (v0.6.4) │
│ ├── OptimizationFacade ├── StabilityFacade └── DiagnosticsFacade │
├─────────────────────────────────────────────────────────────────────────────┤
│ INFRASTRUCTURE │
│ interfaces/ (Protocols) config.py (Singleton) Security │
│ ├── OptimizerProtocol ├── JAXConfig ├── safe_serialize │
│ ├── CurveFitProtocol ├── MemoryConfig ├── model_validation │
│ └── CacheProtocol └── LargeDatasetConfig └── resource_limits │
├─────────────────────────────────────────────────────────────────────────────┤
│ JAX RUNTIME (≥0.8.0) │
│ x64 enabled │ JIT compilation │ Autodiff │ GPU/TPU backend (optional) │
└─────────────────────────────────────────────────────────────────────────────┘
See Architecture Documentation for detailed design.
Performance
NLSQ shows increasing speedups as dataset size grows:
| Dataset Size | SciPy (CPU) | NLSQ (GPU) | Speedup |
|---|---|---|---|
| 10K points | 15ms | 2ms | 7x |
| 100K points | 150ms | 3ms | 50x |
| 1M points | 1.5s | 8ms | 190x |
| 10M points | 15s | 55ms | 270x |
Note: First call includes JIT compilation (~500ms). Subsequent calls reuse compiled kernels.
See Performance Guide for benchmarks.
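Because of that one-time compilation cost, always benchmark after a warm-up call. A minimal sketch of the pattern using jax.jit directly (the same compile-once, reuse-thereafter behavior applies to NLSQ's fit kernels):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def residuals(params, x, y):
    # Residuals of an exponential model; traced and compiled on first call
    a, b, c = params
    return a * jnp.exp(-b * x) + c - y


x = jnp.linspace(0.0, 4.0, 100_000)
y = 2.5 * jnp.exp(-0.5 * x) + 1.0
p = jnp.array([2.0, 0.5, 1.0])

t0 = time.perf_counter()
r1 = residuals(p, x, y).block_until_ready()  # includes tracing + XLA compilation
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
r2 = residuals(p, x, y).block_until_ready()  # reuses the cached compiled kernel
t_second = time.perf_counter() - t0
```

On typical hardware the second call is orders of magnitude faster than the first.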
Large Dataset Example
from nlsq import fit
import numpy as np
import jax.numpy as jnp
def model(x, a, b, c):
return a * jnp.exp(-b * x) + c
# 50 million points
x = np.linspace(0, 10, 50_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))
# Auto-selects optimal strategy (chunked/streaming) based on memory
popt, pcov = fit(model, x, y, p0=[2.5, 0.6, 0.2], show_progress=True)
Advanced Usage
Multi-start global optimization
from nlsq import curve_fit
popt, pcov = curve_fit(
model,
x,
y,
p0=[1, 1, 1],
bounds=([0, 0, 0], [10, 5, 10]),
multistart=True,
n_starts=20,
sampler="lhs",
)
CMA-ES global optimization
For multi-scale parameter problems (parameters spanning >1000x scale ratio):
from nlsq.global_optimization import CMAESOptimizer, CMAESConfig
# Basic usage with BIPOP restart strategy
optimizer = CMAESOptimizer()
result = optimizer.fit(model, x, y, bounds=bounds)
# Memory-efficient configuration for large datasets (>10M points)
config = CMAESConfig(
population_batch_size=4, # Batch population evaluation
data_chunk_size=50000, # Stream data in chunks
)
optimizer = CMAESOptimizer(config=config)
CMA-ES support is included in the base package (pip install nlsq).
Workflow presets (v0.6.3)
NLSQ v0.6.3 simplifies workflows to 3 smart options:
from nlsq import fit
# workflow="auto" (default) - memory-aware local optimization
result = fit(model, x, y, p0=[1, 1, 1])
# workflow="auto_global" - memory-aware global optimization (requires bounds)
result = fit(
model, x, y, p0=[1, 1, 1], workflow="auto_global", bounds=([0, 0, 0], [10, 5, 10])
)
# workflow="hpc" - auto_global + checkpointing for HPC jobs
result = fit(
model,
x,
y,
p0=[1, 1, 1],
workflow="hpc",
bounds=([0, 0, 0], [10, 5, 10]),
checkpoint_dir="/scratch/checkpoints",
)
| Workflow | Bounds | Use Case |
|---|---|---|
| auto | Optional | Default. Local optimization with auto memory strategy |
| auto_global | Required | Multi-modal problems, unknown initial guess |
| hpc | Required | Long-running HPC jobs with checkpointing |
Memory-based strategy selection
from nlsq.core.workflow import MemoryBudget, MemoryBudgetSelector
# Compute memory budget for your dataset
budget = MemoryBudget.compute(n_points=10_000_000, n_params=10)
print(f"Peak memory: {budget.peak_gb:.2f} GB")
print(f"Fits in memory: {budget.fits_in_memory}")
# Let selector choose strategy
selector = MemoryBudgetSelector(safety_factor=0.75)
strategy, config = selector.select(
n_points=10_000_000,
n_params=10,
memory_limit_gb=16.0, # Optional override
)
print(f"Selected: {strategy}") # "streaming", "chunked", or "standard"
Numerical stability
from nlsq import curve_fit
# Auto-detect and fix numerical issues
popt, pcov = curve_fit(model, x, y, p0=p0, stability="auto")
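To see the kind of issue the stability layer targets, consider parameters on very different scales: the Jacobian of a model like a * exp(-b * x) with a ~ 1e6 and b ~ 1e-3 is badly conditioned, and simple column scaling (in the spirit of the parameter normalization NLSQ applies) repairs it. A NumPy illustration:

```python
import numpy as np

x = np.linspace(0.0, 100.0, 200)
a, b = 1.0e6, 1.0e-3  # parameter scales differ by ~1e9

# Analytic Jacobian of f(x; a, b) = a * exp(-b * x)
J = np.column_stack([
    np.exp(-b * x),            # df/da, magnitude O(1)
    -a * x * np.exp(-b * x),   # df/db, magnitude O(1e7)
])

cond_raw = np.linalg.cond(J)

# Normalize each column to unit norm (equivalent to rescaling the parameters)
scale = np.linalg.norm(J, axis=0)
cond_scaled = np.linalg.cond(J / scale)
```

The raw condition number is astronomically large, while the scaled one is small; stability="auto" handles this kind of rescaling (and other guards) without user intervention.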
Memory management
from nlsq import MemoryConfig, memory_context, CurveFit
config = MemoryConfig(memory_limit_gb=8.0)
with memory_context(config):
cf = CurveFit()
popt, pcov = cf.curve_fit(model, x, y, p0=p0)
Command-line interface
# Launch GUI
nlsq gui
# Copy configuration templates
nlsq config
# Single workflow
nlsq fit experiment.yaml
# Output results to stdout (for piping)
nlsq fit experiment.yaml --stdout | jq '.popt'
# Batch processing
nlsq batch configs/*.yaml --summary results.json --workers 4
# System info
nlsq info
See CLI Reference for full documentation.
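A fit configuration is a YAML file describing the data, model, and fit options. The exact schema lives in the CLI Reference; the field names below are illustrative assumptions only, so copy the real templates with nlsq config rather than this sketch:

```yaml
# experiment.yaml - hypothetical sketch; field names are assumptions
data:
  path: data.csv
model:
  name: exponential
fit:
  p0: [2.0, 0.5, 1.0]
  workflow: auto
output:
  results: results.json
```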
Model health diagnostics
from nlsq.diagnostics import (
DiagnosticsConfig,
IdentifiabilityAnalyzer,
GradientMonitor,
ParameterSensitivityAnalyzer,
create_health_report,
)
# Analyze parameter identifiability from Jacobian
config = DiagnosticsConfig(condition_threshold=1e8, correlation_threshold=0.95)
analyzer = IdentifiabilityAnalyzer(config)
report = analyzer.analyze(result.jac)
print(f"Condition number: {report.condition_number:.2e}")
print(f"Numerical rank: {report.numerical_rank}/{report.n_params}")
# Monitor gradient health during optimization
monitor = GradientMonitor(config)
callback = monitor.create_callback()
result = curve_fit(model, x, y, p0=p0, callback=callback)
grad_report = monitor.get_report()
# Analyze parameter sensitivity spectrum
sensitivity_analyzer = ParameterSensitivityAnalyzer(config)
sensitivity_report = sensitivity_analyzer.analyze(result.jac)
print(f"Is sloppy: {sensitivity_report.is_sloppy}")
print(f"Effective dimensionality: {sensitivity_report.effective_dimensionality:.1f}")
# Create aggregated health report
health_report = create_health_report(
identifiability=report,
gradient_health=grad_report,
sloppy_model=sensitivity_report,
)
print(health_report.summary())
See Diagnostics API for complete documentation.
See Advanced User Guide for complete documentation.
Examples
Start with the Interactive Tutorial on Google Colab.
By topic:
- Getting Started - Basic usage and quickstart
- Core Tutorials - Large datasets, bounded optimization
- Advanced - GPU optimization, streaming, checkpointing
- Applications - Physics, chemistry, biology, engineering
See examples/README.md for the full index.
Requirements
- Python 3.12+
- JAX 0.8.0+
- NumPy 2.2+
- SciPy 1.16.0+
GUI requirements: PySide6 6.10.0+, pyqtgraph 0.14.0+
GPU support (Linux only): CUDA 12.x-13.x, NVIDIA driver >= 525
Citation
If you use NLSQ in your research, please cite:
@software{nlsq2025,
title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
year={2025},
url={https://github.com/imewei/NLSQ}
}
@article{jaxfit2022,
title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
journal={arXiv preprint arXiv:2208.12187},
year={2022}
}
Acknowledgments
NLSQ is an enhanced fork of JAXFit by Lucas R. Hofer, Milan Krstajic, and Robert P. Smith. We gratefully acknowledge their foundational work.
License
MIT License - see LICENSE for details.