
GPU/TPU accelerated nonlinear least-squares curve fitting using JAX

NLSQ: GPU-Accelerated Curve Fitting

Drop-in replacement for scipy.optimize.curve_fit with 150-270x speedup on GPU


Documentation | Examples | API Reference | ArXiv Paper


What is NLSQ?

NLSQ is a nonlinear least squares curve fitting library built on JAX. It provides:

  • SciPy-compatible API - Same function signatures as scipy.optimize.curve_fit
  • GPU/TPU acceleration - JIT-compiled kernels via XLA
  • Automatic differentiation - No manual Jacobian calculations needed
  • Large dataset support - Handles 100M+ data points with streaming optimization
  • Interactive GUI - Native Qt desktop application for no-code curve fitting

Installation

Basic Installation (CPU-only)

pip install nlsq

This installs with CPU-only JAX (works on all platforms: Linux, macOS, Windows).

GPU Installation (Linux + System CUDA)

Performance Impact: 20-100x speedup for large datasets (>1M points)

Prerequisites:

  • NVIDIA GPU with SM >= 5.2 (Maxwell or newer)
  • System CUDA 12.x or 13.x installed
  • nvcc in PATH

Verify Prerequisites

# Check CUDA installation
nvcc --version
# Should show: Cuda compilation tools, release 12.x or 13.x

# Check GPU
nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader
# Should show: GPU name and SM version (e.g., "8.9" for RTX 4090)

Option 1: Quick Install via Makefile (Recommended)

git clone https://github.com/imewei/NLSQ.git
cd NLSQ

# Auto-detect system CUDA version and install matching JAX
make install-jax-gpu

# Or explicitly choose CUDA version:
make install-jax-gpu-cuda13  # Requires system CUDA 13.x + SM >= 7.5
make install-jax-gpu-cuda12  # Requires system CUDA 12.x + SM >= 5.2

Option 2: Manual Installation

# For System CUDA 13.x (Turing and newer GPUs):
pip uninstall -y jax jaxlib
pip install "jax[cuda13-local]"

# For System CUDA 12.x (Maxwell and newer GPUs):
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]"

# Verify
python -c "import jax; print('Backend:', jax.default_backend())"
# Should show: Backend: gpu
GPU Compatibility Guide

| GPU Generation | Example GPUs     | SM Version | CUDA 13 | CUDA 12 |
|----------------|------------------|------------|---------|---------|
| Blackwell      | B100, B200       | 10.0       | Yes     | Yes     |
| Hopper         | H100, H200       | 9.0        | Yes     | Yes     |
| Ada Lovelace   | RTX 40xx, L40    | 8.9        | Yes     | Yes     |
| Ampere         | RTX 30xx, A100   | 8.x        | Yes     | Yes     |
| Turing         | RTX 20xx, T4     | 7.5        | Yes     | Yes     |
| Volta          | V100, Titan V    | 7.0        | No      | Yes     |
| Pascal         | GTX 10xx, P100   | 6.x        | No      | Yes     |
| Maxwell        | GTX 9xx, Titan X | 5.x        | No      | Yes     |
| Kepler         | GTX 7xx, K80     | 3.x        | No      | No      |

Recommendation: for GPUs with SM >= 7.5 (RTX 20xx or newer), install CUDA 13 for best performance.

GPU Troubleshooting

Issue: "nvcc not found"

# Option 1: Install CUDA toolkit
sudo apt install nvidia-cuda-toolkit  # Ubuntu/Debian

# Option 2: Add existing CUDA to PATH
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

Issue: "CUDA version mismatch"

# Check your system CUDA version
nvcc --version

# Reinstall with correct package
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]"  # or cuda13-local

With GUI Support

GUI dependencies (PySide6, pyqtgraph) are included in the main package:

pip install nlsq    # GUI included
nlsq-gui            # Launch the desktop application

Quick Start

import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit


# Define model function (use jax.numpy for GPU acceleration)
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# Generate data
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * np.random.randn(len(x))

# Fit - same API as scipy.optimize.curve_fit
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])

print(f"Parameters: a={popt[0]:.3f}, b={popt[1]:.3f}, c={popt[2]:.3f}")
print(f"Uncertainties: {np.sqrt(np.diag(pcov))}")
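
Because the call signature mirrors SciPy's, the same fit can be cross-checked against scipy.optimize.curve_fit on CPU. A minimal sketch using only NumPy and SciPy (no NLSQ or JAX required), with a fixed seed so the result is reproducible:

```python
import numpy as np
from scipy.optimize import curve_fit as scipy_curve_fit


# Same model as above, written with numpy so it runs without JAX
def exponential(x, a, b, c):
    return a * np.exp(-b * x) + c


rng = np.random.default_rng(0)
x = np.linspace(0, 4, 1000)
y = 2.5 * np.exp(-0.5 * x) + 1.0 + 0.1 * rng.standard_normal(len(x))

# Identical signature: popt holds best-fit parameters, pcov their covariance
popt, pcov = scipy_curve_fit(exponential, x, y, p0=[2.0, 0.5, 1.0])
perr = np.sqrt(np.diag(pcov))  # 1-sigma parameter uncertainties
```

Swapping the import back to `from nlsq import curve_fit` should leave the rest of the script unchanged; that interchangeability is what "drop-in replacement" means here.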

Graphical User Interface

NLSQ includes an interactive GUI for curve fitting without writing code:

# Launch the GUI
nlsq-gui

The GUI provides:

  • Data Loading: Import CSV, ASCII, NPZ, or HDF5 files, or paste from clipboard
  • Model Selection: Choose from 8 built-in models, polynomials, or custom Python functions
  • Fitting Options: Guided presets (Fast/Robust/Quality) or advanced parameter control
  • Interactive Results: GPU-accelerated pyqtgraph plots with confidence bands and residuals
  • Export: Session bundles (ZIP), JSON/CSV results, and reproducible Python code

See the GUI User Guide for detailed documentation.

Key Features

| Feature | Description |
|---------|-------------|
| Automatic Jacobians | JAX's autodiff eliminates manual derivatives |
| Bounded optimization | Trust Region Reflective and Levenberg-Marquardt |
| Large datasets | Chunked and streaming optimizers for 100M+ points |
| Multi-start | Global optimization with LHS/Sobol sampling |
| Workflow system | Three smart workflows (auto, auto_global, hpc) with memory-aware strategy selection |
| CLI interface | YAML-based workflows with nlsq fit and nlsq batch |
| Interactive GUI | No-code curve fitting with a Qt desktop application |
| Model diagnostics | Identifiability analysis, gradient health monitoring, sloppy-model detection |
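
The automatic-Jacobian row can be illustrated with plain JAX: jax.jacfwd differentiates a model with respect to its parameters, which is what removes the need for hand-written derivative code. A standalone sketch (not NLSQ's internal API):

```python
import jax
import jax.numpy as jnp


def exponential(params, x):
    a, b, c = params
    return a * jnp.exp(-b * x) + c


x = jnp.linspace(0.0, 4.0, 5)
params = jnp.array([2.5, 0.5, 1.0])

# Jacobian of the model outputs w.r.t. the parameters: shape (len(x), 3).
# Column 0 is d/da = exp(-b*x), column 2 is d/dc = 1, exactly the
# derivatives one would otherwise code by hand.
jac = jax.jacfwd(exponential)(params, x)
```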

Architecture

NLSQ is organized into well-separated layers (~74,000 lines):

┌─────────────────────────────────────────────────────────────────────────────┐
│                             USER INTERFACES                                 │
│  Qt GUI (PySide6)       CLI (Click)            Python API                   │
│  ├── 5-page workflow    ├── Model validation   ├── curve_fit(), fit()       │
│  ├── pyqtgraph plots    ├── Security auditing  ├── CurveFit class           │
│  └── Native desktop     └── Export formats     └── LargeDatasetFitter       │
├─────────────────────────────────────────────────────────────────────────────┤
│                        OPTIMIZATION ORCHESTRATION                           │
│  Workflow System         Global Optimization      Streaming Optimizer       │
│  ├── MemoryBudgetSel.    ├── MultiStartOrch.     ├── AdaptiveHybrid         │
│  ├── Strategy: STANDARD/ ├── TournamentSelect    ├── 4-Phase Pipeline:      │
│  │   CHUNKED/STREAMING   ├── LHS/Sobol/Halton    │   0: Normalization       │
│  └── Memory-based auto   └── Sampling            │   1: L-BFGS warmup       │
│                                                  │   2: Gauss-Newton        │
│  Orchestration (v0.6.4)                          └── 3: Denormalization     │
│  ├── DataPreprocessor    ├── OptimizationSel.                               │
│  ├── CovarianceComputer  └── StreamingCoord.                                │
├─────────────────────────────────────────────────────────────────────────────┤
│                          CORE OPTIMIZATION ENGINE                           │
│  curve_fit() ─→ CurveFit ─→ LeastSquares ─→ TrustRegionReflective           │
│       │              │            │                │                        │
│       ▼              ▼            ▼                ▼                        │
│  API Wrapper    Cache+State  Orchestrator+AD   SVD-based TRF                │
│  (SciPy-compat) (UnifiedCache) (AutoDiffJac)   (JIT-compiled)               │
├─────────────────────────────────────────────────────────────────────────────┤
│                          SUPPORT SUBSYSTEMS                                 │
│  stability/           precision/          caching/          diagnostics/    │
│  ├── NumericalGuard   ├── AlgorithmSel.   ├── UnifiedCache  ├── Identifiab. │
│  ├── SVD fallback     ├── ParamNormalizer  ├── SmartCache    ├── GradientMon │
│  └── Recovery         └── BoundsInfer.    └── MemoryMgr     └── PluginSys.  │
│  facades/ (v0.6.4)                                                          │
│  ├── OptimizationFacade  ├── StabilityFacade  └── DiagnosticsFacade         │
├─────────────────────────────────────────────────────────────────────────────┤
│                            INFRASTRUCTURE                                   │
│  interfaces/ (Protocols)     config.py (Singleton)    Security              │
│  ├── OptimizerProtocol       ├── JAXConfig            ├── safe_serialize    │
│  ├── CurveFitProtocol        ├── MemoryConfig         ├── model_validation  │
│  └── CacheProtocol           └── LargeDatasetConfig   └── resource_limits   │
├─────────────────────────────────────────────────────────────────────────────┤
│                         JAX RUNTIME (≥0.8.0)                                │
│  x64 enabled │ JIT compilation │ Autodiff │ GPU/TPU backend (optional)      │
└─────────────────────────────────────────────────────────────────────────────┘

See Architecture Documentation for detailed design.

Performance

NLSQ shows increasing speedups as dataset size grows:

| Dataset Size | SciPy (CPU) | NLSQ (GPU) | Speedup |
|--------------|-------------|------------|---------|
| 10K points   | 15 ms       | 2 ms       | 7x      |
| 100K points  | 150 ms      | 3 ms       | 50x     |
| 1M points    | 1.5 s       | 8 ms       | 190x    |
| 10M points   | 15 s        | 55 ms      | 270x    |

Note: First call includes JIT compilation (~500ms). Subsequent calls reuse compiled kernels.
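
The compile-once behavior can be observed directly with jax.jit. This is a generic JAX sketch, not a benchmark of NLSQ itself (the ~500 ms figure above is for NLSQ's full fitting kernels):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def residuals(params, x, y):
    a, b, c = params
    return a * jnp.exp(-b * x) + c - y


x = jnp.linspace(0.0, 4.0, 100_000)
y = 2.5 * jnp.exp(-0.5 * x) + 1.0
p = jnp.array([2.0, 0.6, 0.9])

t0 = time.perf_counter()
residuals(p, x, y).block_until_ready()  # first call: trace + XLA compile
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
residuals(p, x, y).block_until_ready()  # reuses the cached compiled kernel
t_cached = time.perf_counter() - t0
```

The same caching is why repeated fits with the same model and data shapes amortize the one-time compilation cost.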

See Performance Guide for benchmarks.

Large Dataset Example

from nlsq import fit
import numpy as np
import jax.numpy as jnp


def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c


# 50 million points
x = np.linspace(0, 10, 50_000_000)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))

# Auto-selects optimal strategy (chunked/streaming) based on memory
popt, pcov = fit(model, x, y, p0=[2.5, 0.6, 0.2], show_progress=True)

Advanced Usage

Multi-start global optimization

from nlsq import curve_fit

popt, pcov = curve_fit(
    model,
    x,
    y,
    p0=[1, 1, 1],
    bounds=([0, 0, 0], [10, 5, 10]),
    multistart=True,
    n_starts=20,
    sampler="lhs",
)

CMA-ES global optimization

For multi-scale parameter problems (parameters spanning >1000x scale ratio):

from nlsq.global_optimization import CMAESOptimizer, CMAESConfig

# Basic usage with BIPOP restart strategy
optimizer = CMAESOptimizer()
result = optimizer.fit(model, x, y, bounds=bounds)

# Memory-efficient configuration for large datasets (>10M points)
config = CMAESConfig(
    population_batch_size=4,  # Batch population evaluation
    data_chunk_size=50000,  # Stream data in chunks
)
optimizer = CMAESOptimizer(config=config)

Requires: pip install nlsq

Workflow presets (v0.6.3)

NLSQ v0.6.3 simplifies workflows to 3 smart options:

from nlsq import fit

# workflow="auto" (default) - memory-aware local optimization
result = fit(model, x, y, p0=[1, 1, 1])

# workflow="auto_global" - memory-aware global optimization (requires bounds)
result = fit(
    model, x, y, p0=[1, 1, 1], workflow="auto_global", bounds=([0, 0, 0], [10, 5, 10])
)

# workflow="hpc" - auto_global + checkpointing for HPC jobs
result = fit(
    model,
    x,
    y,
    p0=[1, 1, 1],
    workflow="hpc",
    bounds=([0, 0, 0], [10, 5, 10]),
    checkpoint_dir="/scratch/checkpoints",
)

| Workflow    | Bounds   | Use Case                                                   |
|-------------|----------|------------------------------------------------------------|
| auto        | Optional | Default. Local optimization with automatic memory strategy |
| auto_global | Required | Multi-modal problems, unknown initial guess                |
| hpc         | Required | Long-running HPC jobs with checkpointing                   |

Memory-based strategy selection

from nlsq.core.workflow import MemoryBudget, MemoryBudgetSelector

# Compute memory budget for your dataset
budget = MemoryBudget.compute(n_points=10_000_000, n_params=10)
print(f"Peak memory: {budget.peak_gb:.2f} GB")
print(f"Fits in memory: {budget.fits_in_memory}")

# Let selector choose strategy
selector = MemoryBudgetSelector(safety_factor=0.75)
strategy, config = selector.select(
    n_points=10_000_000,
    n_params=10,
    memory_limit_gb=16.0,  # Optional override
)
print(f"Selected: {strategy}")  # "streaming", "chunked", or "standard"

Numerical stability

from nlsq import curve_fit

# Auto-detect and fix numerical issues
popt, pcov = curve_fit(model, x, y, p0=p0, stability="auto")

Memory management

from nlsq import MemoryConfig, memory_context, CurveFit

config = MemoryConfig(memory_limit_gb=8.0)

with memory_context(config):
    cf = CurveFit()
    popt, pcov = cf.curve_fit(model, x, y, p0=p0)

Command-line interface

# Launch GUI
nlsq gui

# Copy configuration templates
nlsq config

# Single workflow
nlsq fit experiment.yaml

# Output results to stdout (for piping)
nlsq fit experiment.yaml --stdout | jq '.popt'

# Batch processing
nlsq batch configs/*.yaml --summary results.json --workers 4

# System info
nlsq info
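
The fit and batch commands consume YAML workflow files. Purely as an illustration of the shape such a file might take (every key below is a guess, not NLSQ's documented schema; run nlsq config to copy the real templates):

```yaml
# Hypothetical sketch only - see the templates from `nlsq config`
# for the actual schema.
model: exponential
data: measurements.csv
p0: [2.0, 0.5, 1.0]
workflow: auto
```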

See CLI Reference for full documentation.

Model health diagnostics

from nlsq.diagnostics import (
    DiagnosticsConfig,
    IdentifiabilityAnalyzer,
    GradientMonitor,
    ParameterSensitivityAnalyzer,
    create_health_report,
)

# Analyze parameter identifiability from Jacobian
config = DiagnosticsConfig(condition_threshold=1e8, correlation_threshold=0.95)
analyzer = IdentifiabilityAnalyzer(config)
report = analyzer.analyze(result.jac)

print(f"Condition number: {report.condition_number:.2e}")
print(f"Numerical rank: {report.numerical_rank}/{report.n_params}")

# Monitor gradient health during optimization
monitor = GradientMonitor(config)
callback = monitor.create_callback()
result = curve_fit(model, x, y, p0=p0, callback=callback)
grad_report = monitor.get_report()

# Analyze parameter sensitivity spectrum
sensitivity_analyzer = ParameterSensitivityAnalyzer(config)
sensitivity_report = sensitivity_analyzer.analyze(result.jac)
print(f"Is sloppy: {sensitivity_report.is_sloppy}")
print(f"Effective dimensionality: {sensitivity_report.effective_dimensionality:.1f}")

# Create aggregated health report
health_report = create_health_report(
    identifiability=report,
    gradient_health=grad_report,
    sloppy_model=sensitivity_report,
)
print(health_report.summary())

See Diagnostics API for complete documentation.

See Advanced User Guide for complete documentation.

Examples

Start with the Interactive Tutorial on Google Colab.

See examples/README.md for the full index.

Requirements

  • Python 3.12+
  • JAX 0.8.0+
  • NumPy 2.2+
  • SciPy 1.16.0+

GUI requirements: PySide6 6.10.0+, pyqtgraph 0.14.0+

GPU support (Linux only): CUDA 12.x-13.x, NVIDIA driver >= 525

Citation

If you use NLSQ in your research, please cite:

@software{nlsq2025,
  title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
  author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  year={2025},
  url={https://github.com/imewei/NLSQ}
}

@article{jaxfit2022,
  title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
  author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  journal={arXiv preprint arXiv:2208.12187},
  year={2022}
}

Acknowledgments

NLSQ is an enhanced fork of JAXFit by Lucas R. Hofer, Milan Krstajic, and Robert P. Smith. We gratefully acknowledge their foundational work.

License

MIT License - see LICENSE for details.
