GPU/TPU accelerated nonlinear least-squares curve fitting using JAX
NLSQ: Nonlinear Least Squares Curve Fitting
Quickstart | Install guide | ArXiv Paper | Documentation | Examples
Acknowledgments
NLSQ is an enhanced fork of JAXFit, originally developed by Lucas R. Hofer, Milan Krstajić, and Robert P. Smith. We gratefully acknowledge their foundational work on GPU-accelerated curve fitting with JAX. The original JAXFit paper: arXiv:2208.12187.
What is NLSQ?
NLSQ builds upon JAXFit's foundation, implementing SciPy's nonlinear least squares curve fitting algorithms using JAX for GPU/TPU acceleration. This fork adds significant optimizations, enhanced testing, improved API design, and advanced features for production use. Fit functions are written in Python without CUDA programming.
NLSQ uses JAX's automatic differentiation to calculate Jacobians automatically, eliminating the need for manual partial derivatives or numerical approximation.
NLSQ provides a drop-in replacement for SciPy's curve_fit with production-focused performance and scale.
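For example, a SciPy-style call typically works unchanged after swapping the import (a minimal sketch; the quadratic model here is illustrative):
import numpy as np
from nlsq import curve_fit  # drop-in for scipy.optimize.curve_fit

def quadratic(x, a, b, c):
    return a * x**2 + b * x + c

x = np.linspace(-2, 2, 25)
y = 1.5 * x**2 - 0.4 * x + 0.8
popt, pcov = curve_fit(quadratic, x, y, p0=[1.0, 0.0, 0.0])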
Key Capabilities
- Fast fits with JAX: JIT-compiled kernels and automatic differentiation for Jacobians.
- SciPy-compatible APIs: Trust Region Reflective and Levenberg-Marquardt, plus bounds.
- Robustness tools: robust losses, stability checks, and recovery strategies.
- Scale to huge datasets: chunked and streaming optimizers with progress reporting.
- Workflow automation: `fit()` chooses tiers and presets based on data size/memory.
- Caching & memory controls: JIT cache reuse, mixed precision fallback, and cleanup.
Global Optimization (v0.3.3+)
- Multi-start optimization with LHS/Sobol/Halton sampling.
- Presets: fast, robust, global, thorough, streaming.
- Automatic bounds inference and convergence analysis.
Workflow System (v0.3.4+)
- Tiers: STANDARD, CHUNKED, STREAMING, STREAMING_CHECKPOINT.
- Goals: FAST, ROBUST, GLOBAL, MEMORY_EFFICIENT, QUALITY.
- Config: YAML + env overrides, checkpointing, auto-resume.
Basic Usage
import numpy as np
from nlsq import CurveFit
# Define your fit function
def linear(x, m, b):
    return m * x + b
# Prepare data
x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = np.array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20])
# Perform the fit
cf = CurveFit()
popt, pcov = cf.curve_fit(linear, x, y)
print(f"Fitted parameters: m={popt[0]:.2f}, b={popt[1]:.2f}")
NLSQ leverages JAX's just-in-time (JIT) compilation to XLA for GPU/TPU acceleration. Fit functions must be JIT-compilable. For functions using special operations, use JAX's numpy:
import jax.numpy as jnp
import numpy as np
from nlsq import CurveFit
# Define exponential fit function using JAX numpy
def exponential(x, a, b):
    return jnp.exp(a * x) + b
# Generate synthetic data
x = np.linspace(0, 4, 50)
y_true = np.exp(0.5 * x) + 2.0
y = y_true + 0.1 * np.random.normal(size=len(x))
# Fit with initial guess
cf = CurveFit()
popt, pcov = cf.curve_fit(exponential, x, y, p0=[0.5, 2.0])
print(f"Fitted: a={popt[0]:.3f}, b={popt[1]:.3f}")
# Get parameter uncertainties from covariance
perr = np.sqrt(np.diag(pcov))
print(f"Uncertainties: σ_a={perr[0]:.3f}, σ_b={perr[1]:.3f}")
For more complex fit functions there are a few JIT caveats (see Current gotchas), such as avoiding Python control flow within the fit function; see JAX's sharp edges article for a more in-depth look at JAX-specific caveats.
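For instance, a data-dependent Python if/else inside the model won't trace under JIT; the usual pattern is a vectorized jnp.where (a minimal sketch):
import jax.numpy as jnp

# A Python `if x < x0:` on a traced array would fail under jit;
# jnp.where evaluates both branches and selects elementwise instead
def piecewise_linear(x, a, b, x0):
    return jnp.where(x < x0, a * x, a * x0 + b * (x - x0))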
Contents
- Quickstart: Colab in the Cloud
- Large Dataset Support
- Current gotchas
- Installation
- Citing NLSQ
- Reference documentation
- Examples Gallery
Quickstart: Colab in the Cloud
The easiest way to test out NLSQ is using a Colab notebook connected to a Google Cloud GPU. JAX comes pre-installed so you'll be able to start fitting right away.
Tutorial notebooks:
- Interactive Tutorial: Beginner to Advanced (recommended start! ⭐)
- The basics: fitting basic functions with NLSQ
- Fitting 2D images with NLSQ
- Large dataset fitting demonstration
Performance Benchmarks
NLSQ is optimized for large datasets and shows increasing speedups as data size grows. For full benchmark tables and methodology, see the Performance Guide.
What to expect:
- First call includes JIT compilation; subsequent calls are much faster (see the timing sketch below).
- GPU speedups grow with dataset size (10K+ points show the biggest wins).
- Chunked/streaming workflows avoid memory blowups for very large datasets.
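A quick way to see the warmup effect (a minimal timing sketch; exact numbers depend on hardware):
import time
import numpy as np
import jax.numpy as jnp
from nlsq import CurveFit

def exponential(x, a, b):
    return jnp.exp(a * x) + b

x = np.linspace(0, 4, 100_000)
y = np.exp(0.5 * x) + 2.0

cf = CurveFit()
t0 = time.perf_counter()
cf.curve_fit(exponential, x, y, p0=[0.5, 2.0])  # first call: includes JIT compilation
t1 = time.perf_counter()
cf.curve_fit(exponential, x, y, p0=[0.5, 2.0])  # second call: reuses the compiled kernel
t2 = time.perf_counter()
print(f"first: {t1 - t0:.3f}s, second: {t2 - t1:.3f}s")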
Examples Gallery
Start with the category list below, or browse the full index in examples/README.md.
Browse by category:
- Application gallery (biology, chemistry, engineering, physics)
- Feature demos (callbacks, diagnostics, result helpers)
- Streaming & fault tolerance workflows
- Scripts mirroring every notebook
Large Dataset Support
Note: The examples below are tested with NLSQ v0.1.1+ (NumPy 2.0+, JAX 0.8.0, Python 3.12+). Last validated: 2025-11-19.
NLSQ includes advanced features for handling very large datasets (20M+ points) that may not fit in memory:
Automatic Large Dataset Handling with curve_fit_large
from nlsq import curve_fit_large, estimate_memory_requirements
import jax.numpy as jnp
import numpy as np
# Check memory requirements for your dataset
n_points = 50_000_000 # 50 million points
n_params = 3
stats = estimate_memory_requirements(n_points, n_params)
print(f"Memory required: {stats.total_memory_estimate_gb:.2f} GB")
print(f"Recommended chunks: {stats.n_chunks}")
# Generate large dataset
x = np.linspace(0, 10, n_points)
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, n_points)
# Define fit function using JAX numpy
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c
# Use curve_fit_large for automatic dataset size detection and chunking
popt, pcov = curve_fit_large(
    exponential,
    x,
    y,
    p0=[2.5, 0.6, 0.2],
    memory_limit_gb=4.0,  # Automatic chunking if needed
    show_progress=True,  # Progress bar for large datasets
)
print(f"Fitted parameters: {popt}")
print(f"Parameter uncertainties: {np.sqrt(np.diag(pcov))}")
Advanced Large Dataset Fitting Options
from nlsq import LargeDatasetFitter, fit_large_dataset, LDMemoryConfig
import jax.numpy as jnp
# Option 1: Use the convenience function for simple cases
result = fit_large_dataset(
    exponential,
    x,
    y,
    p0=[2.5, 0.6, 0.2],
    memory_limit_gb=4.0,
    show_progress=True,  # Progress bar for long fits
)
# Option 2: Use LargeDatasetFitter for more control
config = LDMemoryConfig(
    memory_limit_gb=4.0,
    min_chunk_size=10000,
    max_chunk_size=1000000,
    # Streaming kicks in automatically for very large datasets;
    # the two fields below tune it explicitly
    use_streaming=True,  # Enable streaming for datasets > memory limit
    streaming_batch_size=50000,  # Mini-batch size for streaming optimizer
)
fitter = LargeDatasetFitter(config=config)
result = fitter.fit_with_progress(
    exponential,
    x,
    y,
    p0=[2.5, 0.6, 0.2],
)
print(f"Fitted parameters: {result.popt}")
# Note: success_rate and n_chunks are only available for multi-chunk fits
# print(f"Covariance matrix: {result.pcov}")
Sparse Jacobian Optimization
For problems with sparse Jacobian structure (e.g., fitting multiple independent components):
from nlsq import SparseJacobianComputer
# ... (assumes func, p0, x_sample defined from previous example)
# Automatically detect and exploit sparsity
sparse_computer = SparseJacobianComputer(sparsity_threshold=0.01)
pattern, sparsity = sparse_computer.detect_sparsity_pattern(func, p0, x_sample)
if sparsity > 0.1:  # If more than 10% sparse
    print(f"Jacobian is {sparsity:.1%} sparse")
    # Optimization will automatically use sparse methods
Streaming Optimizer for Unlimited Datasets
For datasets that don't fit in memory or are generated on-the-fly (the generator below is an illustrative placeholder):
from nlsq import StreamingOptimizer, StreamingConfig

# Configure streaming optimization
config = StreamingConfig(batch_size=10000, max_epochs=100, convergence_tol=1e-6)
optimizer = StreamingOptimizer(config)

# Stream data from a file or generator; shown here chunking in-memory arrays for illustration
def data_generator():
    for start in range(0, len(x), config.batch_size):
        yield x[start : start + config.batch_size], y[start : start + config.batch_size]

result = optimizer.fit_streaming(func, data_generator, p0=p0)
Adaptive Hybrid Streaming Optimizer (v0.3.0+)
Four-phase hybrid optimizer combining parameter normalization, Adam warmup, streaming Gauss-Newton, and exact covariance computation:
from nlsq import AdaptiveHybridStreamingOptimizer, HybridStreamingConfig
import jax.numpy as jnp
# Configure with presets: aggressive, conservative, or memory_optimized
config = HybridStreamingConfig.aggressive() # Fast convergence
# config = HybridStreamingConfig.conservative() # Higher quality
# config = HybridStreamingConfig.memory_optimized() # Lower memory
optimizer = AdaptiveHybridStreamingOptimizer(config)
# Define model (assumes x_data and y_data are defined)
def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c

# Fit with bounds-based normalization (addresses gradient imbalance)
result = optimizer.fit(
    model=model,
    x_data=x_data,
    y_data=y_data,
    p0=[2.0, 0.5, 0.3],
    bounds=(jnp.array([1.0, 0.1, 0.0]), jnp.array([10.0, 1.0, 2.0])),
)
When to use Adaptive Hybrid Streaming:
- Parameters span many orders of magnitude (gradient imbalance)
- Large datasets (100K+ points) with memory constraints
- Need production-quality uncertainty estimates
- Standard optimizers converge slowly near optimum
4-Layer Defense Strategy (v0.3.6+)
The hybrid streaming optimizer includes a 4-layer defense strategy that prevents Adam warmup divergence when initial parameters are already near optimal:
from nlsq import (
    curve_fit,
    HybridStreamingConfig,
    get_defense_telemetry,
    reset_defense_telemetry,
)
import jax.numpy as jnp

def model(x, a, b, c):
    return a * jnp.exp(-b * x) + c

# Defense layers are enabled by default
popt, pcov = curve_fit(
    model,
    x,
    y,
    p0=[2.0, 0.5, 1.0],
    method="hybrid_streaming",
)

# Monitor defense layer activations
telemetry = get_defense_telemetry()
print(telemetry.get_summary())
print(telemetry.get_trigger_rates())
The 4 Defense Layers:
| Layer | Name | Description |
|---|---|---|
| 1 | Warm Start Detection | Skips warmup if initial loss < 1% of data variance |
| 2 | Adaptive Learning Rate | Scales LR based on initial fit quality (1e-6 to 0.001) |
| 3 | Cost-Increase Guard | Aborts if loss increases > 5% from initial |
| 4 | Step Clipping | Limits parameter update magnitude (max norm 0.1) |
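As an illustration of the Layer 1 criterion (a sketch, assuming "initial loss" means the mean squared residual at p0 and the threshold is 1% of the variance of y; the library's actual internals may differ):
import numpy as np

def warm_start_detected(model, x, y, p0, threshold=0.01):
    # Layer 1 heuristic: skip Adam warmup when p0 already fits well
    residuals = y - model(x, *p0)
    initial_loss = np.mean(residuals**2)
    return initial_loss < threshold * np.var(y)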
Defense Presets:
# For warm-start refinement (stricter thresholds)
config = HybridStreamingConfig.defense_strict()
# For exploration from poor initial guess
config = HybridStreamingConfig.defense_relaxed()
# For regression testing (pre-0.3.6 behavior)
config = HybridStreamingConfig.defense_disabled()
# For scientific computing
config = HybridStreamingConfig.scientific_default()
popt, pcov = curve_fit(model, x, y, method="hybrid_streaming", config=config)
Production Monitoring:
# Reset telemetry before batch processing
reset_defense_telemetry()
for x, y in datasets:  # each dataset provides its own (x, y) arrays
    curve_fit(model, x, y, method="hybrid_streaming")
# Export Prometheus-compatible metrics
metrics = get_defense_telemetry().export_metrics()
See the Defense Layers Guide for detailed documentation.
Key Features for Large Datasets:
- Automatic Size Detection: `curve_fit_large` automatically switches between standard and chunked fitting
- Memory Estimation: Predict memory requirements before fitting
- Intelligent Chunking: Improved algorithm with <1% error for well-conditioned problems
- Progress Reporting: Track progress during long-running fits
- JAX Tracing Support: Compatible with functions having 15+ parameters
- Sparse Optimization: Exploit sparsity in Jacobian matrices
- Streaming Support: Process data that doesn't fit in memory
- Memory-Efficient Solvers: CG and LSQR solvers for reduced memory usage
- Adaptive Convergence: Early stopping when parameters stabilize
For more details, see the large dataset guide and API documentation.
Advanced Features
Memory Management & Configuration
NLSQ provides memory management with context-based configuration:
from nlsq import CurveFit, MemoryConfig, memory_context, get_memory_config

# Configure memory settings
config = MemoryConfig(
    memory_limit_gb=8.0,
    enable_mixed_precision_fallback=True,
    safety_factor=0.8,
    progress_reporting=True,
)

# Use memory context for temporary settings
with memory_context(config):
    # Memory-optimized fitting (assumes func, x, y, p0 are defined)
    cf = CurveFit()
    popt, pcov = cf.curve_fit(func, x, y, p0=p0)

# Check current memory configuration
current_config = get_memory_config()
print(f"Memory limit: {current_config.memory_limit_gb} GB")
print(f"Mixed precision fallback: {current_config.enable_mixed_precision_fallback}")
Mixed Precision Fallback
NLSQ includes automatic mixed precision management that provides up to 50% memory savings while maintaining numerical accuracy:
from nlsq import curve_fit
from nlsq.config import configure_mixed_precision
import jax.numpy as jnp
# Enable mixed precision with custom settings
configure_mixed_precision(
    enable=True,
    max_degradation_iterations=5,  # Grace period before upgrading
    gradient_explosion_threshold=1e10,
    verbose=True,  # Show precision upgrades in logs
)

# Define model function
def exponential(x, a, b):
    return a * jnp.exp(-b * x)

# Fit - starts in float32, automatically upgrades to float64 if needed
popt, pcov = curve_fit(exponential, x, y, p0=[2.0, 0.5])
Key Features:
- Automatic float32 → float64 upgrade when precision issues detected
- 50% memory savings when using float32
- Zero-iteration loss during precision upgrades (state fully preserved)
- Intelligent fallback to relaxed float32 if float64 fails
- Environment variable configuration for production deployment
Configuration Options:
# Programmatic configuration
configure_mixed_precision(
    enable=True,
    max_degradation_iterations=5,
    gradient_explosion_threshold=1e10,
    precision_limit_threshold=1e-7,
    tolerance_relaxation_factor=10.0,
    verbose=False,
)
# Or use environment variables
# export NLSQ_MIXED_PRECISION_VERBOSE=1
# export NLSQ_GRADIENT_EXPLOSION_THRESHOLD=1e8
# export NLSQ_MAX_DEGRADATION_ITERATIONS=3
Algorithm Selection
NLSQ can select the best algorithm based on problem characteristics:
from nlsq.algorithm_selector import AlgorithmSelector, auto_select_algorithm
from nlsq import curve_fit
import jax.numpy as jnp
# Define your model
def model_nonlinear(x, a, b, c):
    return a * jnp.exp(-b * x) + c

# Auto-select best algorithm
recommendations = auto_select_algorithm(
    f=model_nonlinear, xdata=x, ydata=y, p0=[1.0, 0.5, 0.1]
)
# Use recommended algorithm
method = recommendations.get("algorithm", "trf")
popt, pcov = curve_fit(model_nonlinear, x, y, p0=[1.0, 0.5, 0.1], method=method)
print(f"Selected algorithm: {method}")
print(f"Fitted parameters: {popt}")
Multi-Start Global Optimization (v0.3.3+)
NLSQ provides multi-start optimization with Latin Hypercube Sampling (LHS) for finding global optima in problems with multiple local minima:
from nlsq import fit, curve_fit
from nlsq.global_optimization import MultiStartOrchestrator, GlobalOptimizationConfig
import jax.numpy as jnp
import numpy as np
# Define model with multiple local minima
def multimodal_model(x, a, b, c):
    return a * jnp.exp(-b * x) + c

# Generate data
x = np.linspace(0, 5, 100)
y = 3.0 * np.exp(-0.5 * x) + 1.0 + np.random.normal(0, 0.1, 100)

# Option 1: Use fit() with preset (simplest)
popt, pcov = fit(
    multimodal_model, x, y, preset="robust", bounds=([0, 0, 0], [10, 5, 10])
)

# Option 2: Use curve_fit() with multi-start parameters
popt, pcov = curve_fit(
    multimodal_model,
    x,
    y,
    p0=[1, 1, 1],
    bounds=([0, 0, 0], [10, 5, 10]),
    multistart=True,
    n_starts=10,
    sampler="lhs",
)

# Option 3: Use MultiStartOrchestrator for full control
orchestrator = MultiStartOrchestrator.from_preset("global")
result = orchestrator.fit(multimodal_model, x, y, bounds=([0, 0, 0], [10, 5, 10]))
print(f"Best params: {result.popt}")
print(f"Multi-start diagnostics: {result.multistart_diagnostics}")
Preset Configurations:
| Preset | n_starts | Description |
|---|---|---|
| `'fast'` | 0 | Single-start (disabled) for maximum speed |
| `'robust'` | 5 | Light multi-start for robustness |
| `'global'` | 20 | Thorough global search |
| `'thorough'` | 50 | Exhaustive search |
| `'streaming'` | 10 | Tournament selection for large datasets |
Sampling Strategies:
- `'lhs'` (Latin Hypercube Sampling): Best coverage guarantees, recommended for most cases
- `'sobol'`: Sobol quasi-random sequence, deterministic and reproducible
- `'halton'`: Halton sequence using prime bases
Tournament Selection for Large Datasets:
For streaming/large datasets where evaluating all candidates is expensive:
from nlsq.global_optimization import TournamentSelector, GlobalOptimizationConfig
from nlsq.global_optimization import latin_hypercube_sample
# Generate candidates using LHS
candidates = latin_hypercube_sample(20, 3)  # 20 candidates, 3 parameters

# Configure tournament
config = GlobalOptimizationConfig(
    n_starts=20,
    elimination_rounds=3,
    elimination_fraction=0.5,  # Eliminate half each round
    batches_per_round=50,
)
selector = TournamentSelector(candidates, config)

# Run tournament on streaming data (x_batch and y_batch are placeholders for your batches)
def data_generator():
    for _ in range(200):
        yield x_batch, y_batch

best_candidates = selector.run_tournament(data_generator(), model, top_m=1)
Workflow System (v0.3.4+)
The unified fit() function automatically selects the optimal fitting strategy based on dataset size and available memory:
from nlsq import fit, WorkflowConfig, WorkflowTier, OptimizationGoal
import jax.numpy as jnp
import numpy as np
# Define model
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c

# Generate data
x = np.linspace(0, 10, 1_000_000)  # 1M points
y = 2.0 * np.exp(-0.5 * x) + 0.3 + np.random.normal(0, 0.05, len(x))

# Option 1: Auto-select workflow based on dataset/memory (simplest)
popt, pcov = fit(exponential, x, y, p0=[2.5, 0.6, 0.2])

# Option 2: Use a named preset
popt, pcov = fit(exponential, x, y, p0=[2.5, 0.6, 0.2], preset="robust")

# Option 3: Specify workflow explicitly
popt, pcov = fit(exponential, x, y, p0=[2.5, 0.6, 0.2], workflow=WorkflowTier.STREAMING)

# Option 4: Custom configuration
config = WorkflowConfig(
    tier=WorkflowTier.CHUNKED,
    goal=OptimizationGoal.QUALITY,
    memory_limit_gb=8.0,
    enable_checkpointing=True,
)
popt, pcov = fit(exponential, x, y, p0=[2.5, 0.6, 0.2], config=config)
Workflow Tiers:
| Tier | Description | Dataset Size |
|---|---|---|
| `STANDARD` | Direct optimization, no chunking | < 100K points |
| `CHUNKED` | Memory-efficient chunked processing | 100K - 10M points |
| `STREAMING` | Streaming optimization for huge datasets | > 10M points |
| `STREAMING_CHECKPOINT` | Streaming with automatic checkpointing | > 10M points, long jobs |
Workflow Presets:
| Preset | Goal | Description |
|---|---|---|
| `'fast'` | FAST | Minimum iterations, relaxed tolerances |
| `'robust'` | ROBUST | Multi-start with 5 starting points |
| `'global'` | GLOBAL | Thorough global search with 20 starts |
| `'memory_efficient'` | MEMORY_EFFICIENT | Aggressive chunking, streaming fallback |
| `'quality'` | QUALITY | Tight tolerances, validation passes |
| `'hpc'` | ROBUST | PBS Pro cluster configuration |
| `'streaming'` | MEMORY_EFFICIENT | Tournament selection for streaming |
YAML Configuration:
Create nlsq.yaml in your project directory:
workflow:
  goal: robust
  memory_limit_gb: 16.0
  enable_checkpointing: true
  checkpoint_dir: ./checkpoints
  tolerances:
    ftol: 1e-10
    xtol: 1e-10
    gtol: 1e-10
  cluster:
    type: pbs
    nodes: 4
    gpus_per_node: 2
Environment Variable Overrides:
export NLSQ_WORKFLOW_GOAL=quality
export NLSQ_MEMORY_LIMIT_GB=32.0
export NLSQ_CHECKPOINT_DIR=/scratch/checkpoints
Diagnostics & Monitoring
Monitor optimization progress:
from nlsq import ConvergenceMonitor, CurveFit
from nlsq.diagnostics import OptimizationDiagnostics
import numpy as np
# Create convergence monitor
monitor = ConvergenceMonitor(window_size=10, sensitivity=1.0)
# Use CurveFit with stability features
cf = CurveFit(enable_stability=True, enable_recovery=True)
# Perform fitting
popt, pcov = cf.curve_fit(func, x, y, p0=p0)
print(f"Fitted parameters: {popt}")
# For detailed diagnostics, create separate diagnostics object
diagnostics = OptimizationDiagnostics()
# (diagnostics would be populated during optimization)
Numerical Stability Mode
NLSQ provides automatic numerical stability monitoring and correction to prevent optimization divergence:
from nlsq import curve_fit
import jax.numpy as jnp
import numpy as np
# Define a model function
def exponential(x, a, b, c):
    return a * jnp.exp(-b * x) + c
# Generate data with challenging characteristics
x = np.linspace(0, 1e6, 1000) # Large x-range can cause ill-conditioning
y = 2.5 * np.exp(-0.5 * x) + 1.0
# Option 1: stability='check' - warn about issues but don't fix
popt, pcov = curve_fit(exponential, x, y, p0=[2.5, 0.5, 1.0], stability="check")
# Option 2: stability='auto' - automatically detect and fix issues (recommended)
popt, pcov = curve_fit(exponential, x, y, p0=[2.5, 0.5, 1.0], stability="auto")
# Option 3: stability=False - disable stability checks (default)
popt, pcov = curve_fit(exponential, x, y, p0=[2.5, 0.5, 1.0], stability=False)
Stability Modes:
| Mode | Behavior | Use Case |
|---|---|---|
| `stability=False` | No checks (default) | Simple problems, maximum speed |
| `stability='check'` | Warn about issues | Debugging, identify problems |
| `stability='auto'` | Auto-detect and fix | Production use, challenging problems |
Key Features:
- NaN/Inf Detection: Automatically replaces invalid values in Jacobian
- Condition Number Monitoring: Detects ill-conditioned problems
- Data Rescaling: Optional rescaling of data to improve conditioning
- SVD Skip for Large Jacobians: Avoids expensive SVD computation for >10M elements
Physics Applications (XPCS, scattering, etc.):
For physics applications where data must maintain physical units:
# Preserve physical units (don't rescale time delays, scattering vectors, etc.)
popt, pcov = curve_fit(
    g2_model,
    tau,
    y,
    p0=[1.0, 0.3, 100.0],
    stability="auto",
    rescale_data=False,  # Preserve physical units
)
Large Jacobian Optimization:
For large datasets (>10M Jacobian elements), SVD computation is automatically skipped to prevent performance degradation:
# Custom SVD threshold (default: 10M elements)
popt, pcov = curve_fit(
    model,
    x_large,
    y_large,
    p0=p0,
    stability="auto",
    max_jacobian_elements_for_svd=5_000_000,  # Skip SVD above 5M elements
)
Performance Impact:
- `stability=False`: No overhead
- `stability='check'`: ~1ms overhead for 1M points
- `stability='auto'`: ~1-5ms overhead, prevents divergence
For more details, see the Stability Guide.
Caching System
Optimize performance with caching:
from nlsq import SmartCache, cached_function, curve_fit
import jax.numpy as jnp
# Configure caching
cache = SmartCache(max_memory_items=1000, disk_cache_enabled=True)
# Define fit function (caching happens at the JIT level)
def exponential(x, a, b):
    return a * jnp.exp(-b * x)

# First fit - compiles the function (assumes x1, y1, x2, y2 are defined)
popt1, pcov1 = curve_fit(exponential, x1, y1, p0=[1.0, 0.1])
# Second fit - reuses JIT compilation from the first fit
popt2, pcov2 = curve_fit(exponential, x2, y2, p0=[1.2, 0.15])
# Check cache statistics
stats = cache.get_stats()
print(f"Cache hit rate: {stats['hit_rate']:.1%}")
Optimization Recovery & Fallback
Error handling with recovery from failed optimizations:
from nlsq import OptimizationRecovery, CurveFit, curve_fit
import numpy as np
# CurveFit with built-in recovery enabled
cf = CurveFit(enable_recovery=True)
try:
    popt, pcov = cf.curve_fit(func, x, y, p0=p0_initial)
    print(f"Fitted parameters: {popt}")
except Exception as e:
    print(f"Optimization failed: {e}")
# Manual recovery with OptimizationRecovery
recovery = OptimizationRecovery(max_retries=3, enable_diagnostics=True)
# Recovery provides automatic fallback strategies
popt, pcov = curve_fit(func, x, y, p0=p0_initial)
Input Validation & Error Handling
Input validation for robust operation:
from nlsq import InputValidator, curve_fit
import numpy as np
# Create validator
validator = InputValidator(fast_mode=True)
# Validate inputs before fitting
warnings, errors, clean_x, clean_y = validator.validate_curve_fit_inputs(
    f=func, xdata=x, ydata=y, p0=p0
)
if errors:
    print(f"Validation errors: {errors}")
else:
    # Use validated data
    popt, pcov = curve_fit(func, clean_x, clean_y, p0=p0)
    print(f"Fitted parameters: {popt}")
Performance Optimizations (v0.3.0-beta.2)
NLSQ v0.3.0-beta.2 introduces three major performance optimizations:
Adaptive Memory Reuse
12.5% peak memory reduction through intelligent memory pooling:
from nlsq import curve_fit
from nlsq.memory_manager import MemoryManager
# Automatic memory reuse with size-class bucketing
manager = MemoryManager(enable_pooling=True, enable_stats=True)
# Fit with memory pooling
popt, pcov = curve_fit(model, x, y, p0=[1.0, 0.5])
# Check memory statistics
stats = manager.get_stats()
print(f"Memory pool reuse rate: {stats['reuse_rate']:.1%}") # Typically 90%
print(f"Peak memory: {stats['peak_memory_mb']:.2f} MB")
Key Features:
- Size-class bucketing (1KB/10KB/100KB) for 5x better reuse
- Adaptive safety factor (1.2 → 1.05) based on problem characteristics
- 90% memory pool reuse rate achieved
- Zero-copy optimization for reduced malloc/free overhead
Sparse Jacobian Activation
Automatic sparse pattern detection for computational efficiency:
from nlsq.sparse_jacobian import SparseJacobianComputer
# Automatic sparsity detection
computer = SparseJacobianComputer(sparsity_threshold=0.01)
# Detect sparsity pattern
pattern, sparsity = computer.detect_sparsity_pattern(model, p0, x_sample)
if sparsity > 0.1:  # More than 10% sparse
    print(f"Jacobian is {sparsity:.1%} sparse")
    print("Sparse optimizations automatically enabled")
Benefits:
- Detects sparse patterns (>70% zeros) automatically
- Auto-enables sparse-aware optimizations when beneficial
- Phase 1 infrastructure complete; Phase 2 will deliver 5-50x speedup for sparse problems
Streaming Batch Padding
Zero JIT recompiles after warmup for streaming optimization:
from nlsq import StreamingOptimizer, StreamingConfig
# Enable batch padding for zero recompiles
config = StreamingConfig(
    batch_size=100,
    use_batch_padding=True,  # Default on GPU
    batch_padding_multiple=16,
)
optimizer = StreamingOptimizer(config)
# First few batches compile, then zero recompiles
result = optimizer.fit_streaming(data_generator, model, p0=[1.0, 0.5])
# Check diagnostics
print(f"Warmup batches: {result['warmup_batches']}")
print(f"Recompiles after warmup: {result['recompiles_after_warmup']}") # 0
Performance:
- Eliminates JIT thrashing between streaming batches
- Device-aware auto-selection (GPU default, dynamic on CPU)
- Expected 5-15% GPU throughput improvement
Host-Device Transfer Profiling (v0.3.0-beta.3)
Comprehensive profiling and validation infrastructure for monitoring GPU-CPU transfers:
from nlsq.profiling import profile_optimization, analyze_source_transfers
from nlsq import curve_fit
import jax.numpy as jnp
# Profile optimization performance
with profile_optimization() as metrics:
    popt, pcov = curve_fit(model, x, y, p0=[1.0, 0.5])
print(f"Total time: {metrics.total_time_sec:.3f}s")
print(f"Average iteration: {metrics.avg_iteration_time_ms:.2f}ms")
# Static analysis of transfer patterns
with open("mymodule.py") as f:
code = f.read()
analysis = analyze_source_transfers(code)
print(f"Potential transfers: {analysis['total_potential_transfers']}")
Key Features:
- Async Logging: JAX-aware logging eliminates GPU-CPU blocking (<5% overhead)
- JAX Profiler Integration: Runtime transfer measurement with `jax.profiler.trace()`
- Static Analysis: Detect `np.array()`, `np.asarray()`, `.block_until_ready()` patterns
- Performance Baselines: Automated baseline generation and CI/CD regression gates
- Input Validation: Type checking for all profiling functions
Performance Regression Protection:
from nlsq import curve_fit
# Automatic regression detection in CI
# Tests fail if performance degrades >10% vs baseline
# See tests/test_performance_regression.py
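A minimal sketch of such a gate (the baseline.json file and threshold here are hypothetical; the shipped tests live in tests/test_performance_regression.py):
import json
import time
import numpy as np
import jax.numpy as jnp
from nlsq import curve_fit

def model(x, a, b):
    return a * jnp.exp(-b * x)

def test_fit_speed_regression():
    x = np.linspace(0, 4, 100_000)
    y = 2.0 * np.exp(-0.5 * x)
    curve_fit(model, x, y, p0=[1.5, 0.4])  # warmup: triggers JIT compilation
    t0 = time.perf_counter()
    curve_fit(model, x, y, p0=[1.5, 0.4])
    elapsed = time.perf_counter() - t0
    with open("baseline.json") as f:  # hypothetical stored baseline
        baseline = json.load(f)["fit_time_sec"]
    assert elapsed < baseline * 1.10  # fail on >10% regression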
Async Logging Benefits:
- Zero host-device blocking during optimization
- Verbosity control (0=off, 1=every 10th, 2=all iterations)
- Automatic JAX array detection
- Non-blocking callbacks via `jax.debug.callback`
For detailed performance analysis, see the Performance Guide.
Current gotchas
Full disclosure: we've copied most of this from the JAX repo. NLSQ inherits JAX's idiosyncrasies, so the gotchas are mostly the same.
Automatic Precision Management (v0.2.0+)
NLSQ automatically manages numerical precision for optimal performance and memory usage:
- Default: Float32 (single precision) for memory efficiency
- Automatic upgrade: Float32 → Float64 when precision issues detected
- Memory savings: Up to 50% by starting in float32
- No manual configuration needed for most use cases
NLSQ starts with single precision (float32) for memory efficiency. The mixed precision system will automatically upgrade to float64 if convergence stalls or precision issues are detected.
Advanced users can manually control precision or disable automatic fallback:
from nlsq import curve_fit
from nlsq.mixed_precision import MixedPrecisionConfig
# Disable automatic fallback (strict float64)
config = MixedPrecisionConfig(enable_fallback=False)
popt, pcov = curve_fit(f, xdata, ydata, mixed_precision_config=config)
# Or manually enable x64 before importing NLSQ
import jax

jax.config.update("jax_enable_x64", True)
See the Mixed Precision guide for advanced configuration options.
Other caveats
Below are some more things to be careful of, but a full list can be found in JAX's Gotchas Notebook. Some standouts:
- JAX transformations only work on pure functions, which don't have side-effects and respect referential transparency (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
- In-place mutating updates of arrays, like `x[i] += y`, aren't supported, but there are functional alternatives (see the sketch after this list). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
- Some transformations, like `jit`, constrain how you can use Python control flow. You'll always get loud errors if something goes wrong. You might have to use `jit`'s `static_argnums` parameter or structured control flow primitives like `lax.scan`.
- Some of NumPy's dtype promotion semantics involving a mix of Python scalars and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`.
- If you're looking for convolution operators, they're in the `jax.lax` package.
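For example, the functional alternative to an in-place update uses JAX's index-update syntax:
import jax.numpy as jnp

x = jnp.zeros(5)
# x[2] += 3.0  # would raise an error on a JAX array
x = x.at[2].add(3.0)  # functional update: returns a new array
print(x)  # [0. 0. 3. 0. 0.]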
Installation
Requirements
- Python 3.12 or higher (3.13 also supported)
- JAX 0.8.0 (locked version)
- NumPy 2.0+ ⚠️ Breaking change from NumPy 1.x (tested with 2.3.4)
- SciPy 1.14.0+ (tested with 1.16.2)
Platform Support
| Platform | GPU Support | Performance | Notes |
|---|---|---|---|
| ✅ Linux + CUDA 12.1-12.9 | Full GPU | 150-270x speedup | Recommended for large datasets |
| ❌ macOS (Intel/Apple Silicon) | CPU only | Baseline | No NVIDIA GPU support |
| ❌ Windows | CPU only | Baseline | Use WSL2 for GPU support |
GPU Requirements (Linux only):
- System CUDA 12.1-12.9 installed
- NVIDIA driver >= 525
- Compatible NVIDIA GPU
Quick Install
Linux (CPU Only)
Using pip:
pip install nlsq "jax[cpu]==0.8.0"
Using uv (recommended - faster):
uv pip install nlsq "jax[cpu]==0.8.0"
Linux (GPU Acceleration - Recommended) ⚡
Option 1: Automated Install (Recommended)
From the NLSQ repository:
git clone https://github.com/imewei/NLSQ.git
cd NLSQ
make install-jax-gpu # Handles uninstall, install, and verification
This single command:
- Detects your package manager (uv, conda/mamba, or pip)
- Uninstalls CPU-only JAX
- Installs GPU-enabled JAX with CUDA 12 support
- Verifies GPU detection automatically
Option 2: Manual Install (pip)
# Step 1: Uninstall CPU-only version
pip uninstall -y jax jaxlib
# Step 2: Install JAX with CUDA support (best performance)
pip install "jax[cuda12-local]==0.8.0"
# Step 3: Verify GPU detection
python -c "import jax; print('Devices:', jax.devices())"
# Expected: [cuda(id=0)] instead of [CpuDevice(id=0)]
Option 3: Manual Install (uv)
# Step 1: Uninstall CPU-only version
uv pip uninstall jax jaxlib
# Step 2: Install JAX with CUDA support
uv pip install "jax[cuda12-local]==0.8.0"
# Step 3: Verify GPU detection
python -c "import jax; print('Devices:', jax.devices())"
Alternative: For systems without CUDA installed, use bundled CUDA (larger download):
pip install "jax[cuda12]==0.8.0"
# or with uv:
uv pip install "jax[cuda12]==0.8.0"
Windows & macOS
# CPU only (GPU not supported natively)
pip install nlsq "jax[cpu]==0.8.0"
# or with uv:
uv pip install nlsq "jax[cpu]==0.8.0"
Windows GPU Users: Use WSL2 (Windows Subsystem for Linux) and follow the Linux GPU installation instructions above.
Development Installation
Using pip:
git clone https://github.com/imewei/NLSQ.git
cd NLSQ
pip install -e ".[dev,test,docs]"
Using uv (recommended - faster):
git clone https://github.com/imewei/NLSQ.git
cd NLSQ
uv pip install -e ".[dev,test,docs]"
For GPU support in development:
make install-jax-gpu
GPU Troubleshooting
Diagnostic Tools
Check your environment configuration:
# From NLSQ repository
make env-info # Show platform, package manager, GPU hardware, CUDA version
make gpu-check # Test JAX GPU detection
Common Issues
Issue 1: Warning "CUDA-enabled jaxlib is not installed"
Symptoms:
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed.
Falling back to cpu.
Solution:
# Verify GPU hardware
nvidia-smi # Should show your GPU
# Verify CUDA version
nvcc --version # Should show CUDA 12.1-12.9
# Reinstall JAX with GPU support
pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]==0.8.0"
# Verify fix
python -c "import jax; print(jax.devices())"
# Expected: [cuda(id=0)]
Issue 2: ImportError or "CUDA library not found"
Symptoms:
ImportError: libcudart.so.12: cannot open shared object file
Solution:
# Set CUDA library path (add to ~/.bashrc for permanent fix)
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Verify CUDA installation
ls /usr/local/cuda/lib64/libcudart.so*
Issue 3: Out of memory errors during computation
Symptoms:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED
Solution: Reduce GPU memory usage:
# Option 1: Reduce chunk size for large datasets
from nlsq import curve_fit_large
popt, pcov = curve_fit_large(func, x, y, memory_limit_gb=4.0) # Reduce from default
# Option 2: Configure JAX memory fraction
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7" # Use 70% of GPU memory
Issue 4: Slow performance despite GPU
Symptoms: First run is slow, subsequent runs are fast
Explanation: This is normal! JAX uses JIT (Just-In-Time) compilation:
- First run: 450-650ms (includes compilation)
- Cached runs: 1.7-2.0ms (150-270x faster)
Solution: Use the CurveFit class to reuse compilation, matching the API shown in Basic Usage:
from nlsq import CurveFit
# ... (assumes model_func, xdata1, ydata1, xdata2, ydata2 defined)
fitter = CurveFit()
popt1, pcov1 = fitter.curve_fit(model_func, xdata1, ydata1)  # First run: JIT compiles
popt2, pcov2 = fitter.curve_fit(model_func, xdata2, ydata2)  # Second run: reuses compilation
Issue 5: Suppressing GPU Acceleration Warnings
Symptoms: You see this warning on import even though you intentionally use CPU-only JAX:
⚠️ GPU ACCELERATION AVAILABLE
═══════════════════════════════
NVIDIA GPU detected: Tesla V100-SXM2-16GB
JAX is currently using: CPU-only
When you might want to suppress this:
- Running tests in CI/CD pipelines
- Using CPU-only JAX intentionally (testing, debugging, etc.)
- Parsing stdout programmatically
- Reducing output clutter in Jupyter notebooks
Solution: Set the NLSQ_SKIP_GPU_CHECK environment variable:
# Option 1: Set before running Python
export NLSQ_SKIP_GPU_CHECK=1
python your_script.py
# Option 2: Inline with command
NLSQ_SKIP_GPU_CHECK=1 python your_script.py
# Option 3: Add to CI/CD environment variables
#   (GitHub Actions example)
#   env:
#     NLSQ_SKIP_GPU_CHECK: "1"
# Option 4: Set in Python before importing nlsq
import os
os.environ['NLSQ_SKIP_GPU_CHECK'] = '1'
import nlsq # No warning printed
Accepted values: "1", "true", "yes" (case-insensitive)
Note: This suppresses the warning but does not affect actual GPU usage. If you have GPU-enabled JAX installed, it will still use the GPU for computations.
Conda/Mamba Users
NLSQ works seamlessly in conda environments using pip:
conda create -n nlsq python=3.12
conda activate nlsq
pip install nlsq
# For GPU (Linux only)
git clone https://github.com/imewei/NLSQ.git
cd NLSQ
make install-jax-gpu # Automatically detects conda/mamba
Note: Conda extras syntax (conda install nlsq[gpu-cuda]) is not supported. Use the Makefile or manual pip installation method above.
Citing NLSQ
If you use NLSQ in your research, please cite both the NLSQ software and the original JAXFit paper:
NLSQ Software Citation
@software{nlsq2024,
  title={NLSQ: Nonlinear Least Squares Curve Fitting for GPU/TPU},
  author={Chen, Wei and Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  year={2024},
  url={https://github.com/imewei/NLSQ},
  note={Enhanced fork of JAXFit with advanced features for large datasets, memory management, and algorithm selection}
}
Original JAXFit Paper
@article{jaxfit2022,
  title={JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the {GPU}},
  author={Hofer, Lucas R and Krstaji{\'c}, Milan and Smith, Robert P},
  journal={arXiv preprint arXiv:2208.12187},
  year={2022},
  url={https://doi.org/10.48550/arXiv.2208.12187}
}
Reference documentation
For details about the NLSQ API, see the reference documentation.