Sequential Least Squares Programming (SLSQP) optimizer implemented in pure JAX
slsqp-jax
A pure-JAX implementation of the SLSQP (Sequential Least Squares Programming) algorithm for constrained nonlinear optimization, designed for moderate to large decision spaces (5,000–50,000 variables). All linear algebra is performed through JAX, so the solver runs natively on CPU, GPU, and TPU, and is fully compatible with jax.jit, jax.vmap, and jax.grad.
The SLSQP solver is built on top of Optimistix, a JAX library for nonlinear solvers. It implements the optimistix.AbstractMinimiser interface, so you run it through the standard optimistix.minimise entry point — no manual iteration loop required.
Installation
You can install the package from PyPI using any standard method:
pip
pip install slsqp-jax
uv
uv add "slsqp-jax"
pixi
pixi add --pypi slsqp-jax
Usage
Basic: objective and constraints
Define an objective function with signature (x, args) -> (scalar, aux) and optional constraint functions with signature (x, args) -> array. Equality constraints must satisfy c_eq(x) = 0 and inequality constraints must satisfy c_ineq(x) >= 0.
Then create an SLSQP solver and pass it to optimistix.minimise:
import jax.numpy as jnp
import optimistix as optx
from slsqp_jax import SLSQP
# Objective: minimize x^2 + y^2
def objective(x, args):
    return jnp.sum(x**2), None
# Equality constraint: x + y = 1
def eq_constraint(x, args):
    return jnp.array([x[0] + x[1] - 1.0])
# Inequality constraint: x >= 0.2
def ineq_constraint(x, args):
    return jnp.array([x[0] - 0.2])
solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    ineq_constraint_fn=ineq_constraint,
    n_ineq_constraints=1,
    rtol=1e-8,
    atol=1e-8,
)
x0 = jnp.array([0.5, 0.5])
sol = optx.minimise(objective, solver, x0, has_aux=True, max_steps=100)
print(sol.value)   # approx. [0.5, 0.5]; the inequality x >= 0.2 is inactive at the optimum
print(sol.result)  # RESULTS.successful
The returned optimistix.Solution object contains:
- sol.value: the optimal point.
- sol.result: a status code (RESULTS.successful or an error).
- sol.aux: any auxiliary data returned by the objective.
- sol.stats: solver statistics (e.g. number of steps taken).
- sol.state: the final internal solver state.
Since SLSQP is a standard Optimistix minimiser, it composes with all Optimistix features: throw=False for non-raising error handling, custom adjoint methods for differentiating through the solve, and so on. See the Optimistix documentation for details.
Supplying gradients and Jacobians
By default the solver computes the objective gradient via jax.grad and constraint Jacobians via jax.jacrev. You can supply your own functions to avoid redundant computation or to handle functions that are not reverse-mode differentiable:
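# User-supplied gradient: (x, args) -> grad_f(x), shape (n,)
def my_grad_fn(x, args):
    return 2.0 * x  # gradient of sum(x^2)
# User-supplied Jacobian: (x, args) -> J(x), shape (m, n)
def my_eq_jac_fn(x, args):
    return jnp.ones((1, x.shape[0]))  # Jacobian of x[0] + x[1] - 1 for n = 2
solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    obj_grad_fn=my_grad_fn,
    eq_jac_fn=my_eq_jac_fn,
)
sol = optx.minimise(objective, solver, x0, has_aux=True, max_steps=100)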
Supplying Hessian-vector products
For problems where you have access to exact second-order information, you can supply Hessian-vector product (HVP) functions. The solver uses these to produce high-quality secant pairs for the L-BFGS Hessian approximation, which typically improves convergence compared to using gradient differences alone:
# Objective HVP: (x, v, args) -> H_f(x) @ v
def obj_hvp(x, v, args):
    return 2.0 * v  # Hessian of sum(x^2) is 2*I
# Per-constraint HVP: (x, v, args) -> array of shape (m, n)
# Row i is H_{c_i}(x) @ v
def eq_hvp(x, v, args):
    return jnp.zeros((1, x.shape[0]))  # Linear constraint has zero Hessian
solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    obj_hvp_fn=obj_hvp,
    eq_hvp_fn=eq_hvp,  # Optional; the AD fallback is used if omitted
)
sol = optx.minimise(objective, solver, x0, has_aux=True, max_steps=100)
Note that you supply HVPs for the objective and constraint functions separately, not for the Lagrangian. The solver composes the Lagrangian HVP internally using the current KKT multipliers:
$$ \nabla^2 L(x) v = \nabla^2 f(x) v - \sum_i \lambda_i^{\text{eq}} \nabla^2 c_i^{\text{eq}}(x) v - \sum_j \mu_j^{\text{ineq}} \nabla^2 c_j^{\text{ineq}}(x) v $$
If you provide obj_hvp_fn but omit the constraint HVP functions, the solver automatically computes the missing multiplier-weighted constraint term $\sum_i \lambda_i \nabla^2 c_i(x) v$ via forward-over-reverse AD on the scalar function $\lambda^T c(x)$. Because $\lambda^T c(x)$ is a scalar, this costs one reverse pass plus one forward pass regardless of the number of constraints.
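The forward-over-reverse fallback can be sketched in plain JAX as below. It assumes a constraint function with the (x, args) -> array signature used above; weighted_constraint_hvp is a hypothetical helper for illustration, not part of slsqp-jax's API. It returns the combined term $\sum_i \lambda_i \nabla^2 c_i(x) v$ from a single scalar function, which is why the cost does not grow with the number of constraints.

```python
import jax
import jax.numpy as jnp


def weighted_constraint_hvp(c_fn, x, v, lam, args=None):
    """Return sum_i lam[i] * (Hessian of c_i at x) @ v without forming any Hessian.

    c_fn follows the (x, args) -> array of shape (m,) signature (hypothetical helper).
    """

    def weighted_sum(z):
        # lambda^T c(z) is a single scalar, so one reverse pass gives its gradient.
        return jnp.dot(lam, c_fn(z, args))

    # Forward-over-reverse: the JVP of the gradient along v is the HVP.
    _, hvp = jax.jvp(jax.grad(weighted_sum), (x,), (v,))
    return hvp


# Two nonlinear constraints: c_1 = x0^2 + x1^2 - 1 (Hessian 2I), c_2 = x0 * x1.
def c_fn(x, args):
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 1.0, x[0] * x[1]])


x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])
lam = jnp.array([3.0, -2.0])
hv = weighted_constraint_hvp(c_fn, x, v, lam)  # [[6, -2], [-2, 6]] @ v = [5, -7]
```

Subtracting this vector from the objective HVP gives the equality part of the Lagrangian HVP formula above; the inequality group would be handled the same way with the multipliers $\mu$.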
Frozen Hessian in the QP subproblem: The QP inner loop always uses a frozen L-BFGS approximation to the Lagrangian Hessian, even when exact HVPs are available. The exact HVP is called only once per main iteration (to probe along the step direction and produce an exact secant pair for the L-BFGS update). This design ensures that (1) the QP subproblem sees a truly constant quadratic model, and (2) expensive HVP evaluations are not repeated at every iteration of the projected CG solver.
Box constraints (bounds)
You can specify simple lower and upper bounds on decision variables using the bounds parameter:
import jax.numpy as jnp
from slsqp_jax import SLSQP
bounds = jnp.array([
    [0.0, 1.0],       # 0 <= x_0 <= 1
    [-jnp.inf, 5.0],  # x_1 <= 5 (no lower bound)
    [0.0, jnp.inf],   # x_2 >= 0 (no upper bound)
])
solver = SLSQP(bounds=bounds)
Bounds play a dual role in the solver, following the projected-SQP methodology (Heinkenschloss & Ridzal, Projected Sequential Quadratic Programming Methods, SIAM J. Optim., 1996):
- QP inequality constraints — inside the QP subproblem, bounds are linearised as ordinary inequality constraints so the search direction is aware of the feasible box.
- Hard projection — after every line search step (and at initialisation), the iterate is projected (clipped) onto the feasible box. This guarantees that the objective and constraint functions are never evaluated outside the bounds, which is critical when those functions are undefined or ill-conditioned outside the box (e.g. a log-likelihood with positivity constraints on its parameters).
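The hard-projection step amounts to elementwise clipping onto the box. A minimal sketch, reusing the (n, 2) bounds layout from the example above; project_onto_box is a hypothetical helper name, and the solver's own implementation may differ in detail:

```python
import jax.numpy as jnp


def project_onto_box(x, bounds):
    """Clip x onto the box given by bounds of shape (n, 2): column 0 lower, column 1 upper."""
    lower, upper = bounds[:, 0], bounds[:, 1]
    # Clipping against -inf / +inf leaves that side unconstrained.
    return jnp.clip(x, lower, upper)


bounds = jnp.array([
    [0.0, 1.0],
    [-jnp.inf, 5.0],
    [0.0, jnp.inf],
])
x_trial = jnp.array([1.3, 7.0, -0.4])  # e.g. x_k + alpha * d_k overshooting the box
x_next = project_onto_box(x_trial, bounds)  # [1.0, 5.0, 0.0]
```

Projection is idempotent, so applying it to a point already inside the box leaves the point unchanged.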
Algorithm
Overview
Each SLSQP iteration performs four steps:
- QP subproblem: Build a quadratic model from the objective gradient and the frozen L-BFGS approximation of the Lagrangian Hessian, and linearise the constraints around the current point. Solve the resulting QP to obtain a search direction $d_k$.
- Line search: Use the Han-Powell L1 merit function $\phi(x;\rho) = f(x) + \rho (\lVert c_{\text{eq}}\rVert_1 + \lVert\max(0, -c_{\text{ineq}})\rVert_1)$ with a backtracking Armijo condition to determine the step size $\alpha$.
- Accept step: Update the iterate $x_{k+1} = x_k + \alpha d_k$.
- Hessian update: Append the new curvature pair $(s, y)$ to the L-BFGS history, where $y$ is either an exact HVP probe $\nabla^2 L(x_k) s$ (if HVP functions are provided) or the gradient difference $\nabla L(x_{k+1}) - \nabla L(x_k)$.
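The merit function and a backtracking Armijo search can be sketched as follows, in plain JAX with a Python loop (not jit-compiled). The helper names, the fixed penalty rho = 1, the sufficient-decrease constant eta = 1e-4 and the halving factor are illustrative assumptions; this README does not state the solver's actual values or how it updates rho. The directional derivative uses the standard identity for an L1 merit function when the step satisfies the linearised equality constraints.

```python
import jax
import jax.numpy as jnp


def l1_merit(f_val, c_eq, c_ineq, rho):
    """phi(x; rho) = f(x) + rho * (||c_eq||_1 + ||max(0, -c_ineq)||_1)."""
    infeasibility = jnp.sum(jnp.abs(c_eq)) + jnp.sum(jnp.maximum(0.0, -c_ineq))
    return f_val + rho * infeasibility


def backtracking_armijo(phi, x, d, dphi0, eta=1e-4, shrink=0.5, max_iter=30):
    """First alpha in 1, shrink, shrink**2, ... with phi(x + alpha d) <= phi(x) + eta alpha dphi0."""
    phi0 = phi(x)
    alpha = 1.0
    for _ in range(max_iter):
        if phi(x + alpha * d) <= phi0 + eta * alpha * dphi0:
            return alpha
        alpha *= shrink
    return alpha  # give up: return the step after max_iter reductions


# README problem: min x^2 + y^2 s.t. x + y = 1 (equality only), penalty rho = 1.
rho = 1.0
f = lambda z: jnp.sum(z**2)
c_eq = lambda z: jnp.array([z[0] + z[1] - 1.0])
phi = lambda z: l1_merit(f(z), c_eq(z), jnp.zeros(0), rho)

x = jnp.array([2.0, 0.0])
d = jnp.array([-1.5, 0.5])  # QP step to [0.5, 0.5]; satisfies c_eq(x) + J d = 0
# When c_eq(x) + J d = 0, the merit's directional derivative is grad_f . d - rho * ||c_eq(x)||_1.
dphi0 = jnp.dot(jax.grad(f)(x), d) - rho * jnp.sum(jnp.abs(c_eq(x)))
alpha = backtracking_armijo(phi, x, d, dphi0)  # full step accepted: 1.0
```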
Scaling considerations: why L-BFGS over BFGS
Classical SLSQP (e.g. SciPy's implementation) maintains a dense $n \times n$ BFGS approximation to the Hessian of the Lagrangian. This requires $O(n^2)$ memory and $O(n^2)$ work per iteration for the matrix update alone. For the target problem sizes:
| n | Dense Hessian memory (float64) | L-BFGS memory (k = 10 pairs, float64) |
|---|---|---|
| 1,000 | 8 MB | 160 KB |
| 10,000 | 800 MB | 1.6 MB |
| 50,000 | 20 GB | 8 MB |
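The table entries follow from simple byte counts. This sketch assumes 8-byte float64 entries and decimal units (1 MB = 10^6 bytes), which reproduces the numbers above:

```python
def dense_hessian_bytes(n, bytes_per_float=8):
    # One dense n x n float64 matrix.
    return n * n * bytes_per_float


def lbfgs_bytes(n, k=10, bytes_per_float=8):
    # k stored pairs (s_i, y_i), each a length-n float64 vector.
    return 2 * k * n * bytes_per_float


for n in (1_000, 10_000, 50_000):
    print(f"n={n}: dense {dense_hessian_bytes(n) / 1e6:,.0f} MB, L-BFGS {lbfgs_bytes(n) / 1e6:.2f} MB")
```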
L-BFGS stores only the last $k$ step/gradient-difference pairs $(s_i, y_i)$ and computes Hessian-vector products in $O(kn)$ time using the compact representation (Byrd, Nocedal & Schnabel, 1994):
$$ B_k = \gamma I - W N^{-1} W^T $$
where $W = (\gamma S, Y)$ is the horizontal concatenation of the matrices $\gamma S$ and $Y$, and $N$ is a small $2k \times 2k$ matrix built from inner products of the stored vectors. The $2k \times 2k$ system is solved directly, at negligible cost for $k \ll n$.
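A sketch of the compact-form product $B v$ in JAX, following the Byrd-Nocedal-Schnabel formula with $N = \begin{pmatrix} \gamma S^T S & L \\ L^T & -D \end{pmatrix}$, where $L$ is the strictly lower triangular part of $S^T Y$ and $D$ its diagonal (here $S$ and $Y$ are $n \times k$ with the pairs as columns). compact_lbfgs_matvec is a hypothetical helper: it stores the pairs as rows, oldest first, and the choice of $\gamma$ is left to the caller. Float64 is enabled only to keep comparisons against an explicit BFGS matrix tight.

```python
import jax
import jax.numpy as jnp

# Float64 keeps the comparison against a dense BFGS matrix tight (demo choice).
jax.config.update("jax_enable_x64", True)


def compact_lbfgs_matvec(S, Y, gamma, v):
    """Return B v for B = gamma I - W N^{-1} W^T (Byrd, Nocedal & Schnabel, 1994).

    S, Y: shape (k, n); row i holds s_i or y_i, oldest pair first.
    """
    SY = S @ Y.T                     # (k, k), entry (i, j) = s_i^T y_j
    L = jnp.tril(SY, k=-1)           # strictly lower triangular part
    D = jnp.diag(jnp.diag(SY))       # diag(s_i^T y_i), must be positive
    N = jnp.block([[gamma * (S @ S.T), L], [L.T, -D]])   # (2k, 2k)
    W = jnp.concatenate([gamma * S.T, Y.T], axis=1)      # (n, 2k) = (gamma S, Y)
    # Solve the small 2k x 2k system directly rather than forming N^{-1}.
    return gamma * v - W @ jnp.linalg.solve(N, W.T @ v)


# Two pairs along coordinate axes with curvatures 2 and 3.
S = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
Y = jnp.array([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
Bv = compact_lbfgs_matvec(S, Y, 1.0, jnp.array([1.0, 1.0, 1.0]))
print(Bv)  # [2. 3. 1.]
```

In the small example the pairs are coordinate directions, so the result can be checked by hand: two BFGS updates starting from $I$ produce $\mathrm{diag}(2, 3, 1)$. Each product costs $O(kn)$ plus an $O(k^3)$ solve.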
Powell's damping is applied to each curvature pair before storage to ensure positive definiteness, which is essential for constrained problems where the standard curvature condition $s^T y > 0$ can fail.
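Powell's damping can be sketched as follows, assuming the common textbook threshold of 0.2 (this README does not state the value slsqp-jax uses). powell_damping is a hypothetical helper; in an L-BFGS setting, Bs would come from a compact-form product $B_k s$.

```python
import jax.numpy as jnp


def powell_damping(s, y, Bs, threshold=0.2):
    """Damped secant vector r = theta * y + (1 - theta) * B s with s^T r >= threshold * s^T B s."""
    sBs = jnp.dot(s, Bs)
    sy = jnp.dot(s, y)
    # theta = 1 keeps y when the curvature condition already holds comfortably;
    # otherwise theta is chosen so that s^T r equals threshold * s^T B s exactly.
    # jnp.where evaluates both branches; the unused one may produce inf/nan
    # without affecting the selected value.
    theta = jnp.where(sy >= threshold * sBs, 1.0, (1.0 - threshold) * sBs / (sBs - sy))
    return theta * y + (1.0 - theta) * Bs


s = jnp.array([1.0, 0.0])
Bs = jnp.array([1.0, 0.0])   # B = I
y = jnp.array([-1.0, 0.0])   # negative curvature: s^T y = -1 < 0
y_damped = powell_damping(s, y, Bs)  # s^T y_damped == 0.2 * s^T B s (up to rounding)
```

After damping, $s^T r \ge 0.2\, s^T B s > 0$ whenever $B$ is positive definite, so the stored pair keeps the L-BFGS approximation positive definite.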
Scaling considerations: why projected CG over dense KKT
Classical SLSQP solves the QP subproblem by forming and factorising the $(n + m) \times (n + m)$ dense KKT system at $O(n^3)$ cost. This implementation instead uses projected conjugate gradient (CG) inside an active-set loop:
- Projection: For active constraints with matrix $A$ ($m_{\text{active}} \times n$), define the null-space projector $P(v) = v - A^T (A A^T)^{-1} A v$. The $A A^T$ system is only $m_{\text{active}} \times m_{\text{active}}$ (tiny, since $m \ll n$) and is solved directly.
- CG in null space: Run conjugate gradient on the projected system, where each iteration requires one HVP ($O(kn)$ for L-BFGS) and one projection ($O(mn)$).
- Active-set outer loop: Add the most violated inequality or drop the most negative multiplier until KKT conditions are satisfied.
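Total cost per QP solve: $O(n (k + m_{\text{active}}) t)$, where $t$ is the total number of CG iterations across the active-set loop (typically $t \ll n$), since each CG iteration costs one $O(kn)$ HVP plus one $O(m_{\text{active}} n)$ projection. Compared to $O(n^3)$ for the dense approach, this is orders of magnitude cheaper for $n > 1000$.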
Reverse-mode AD and while_loop
By default, the solver computes:
- Objective gradient via jax.grad (reverse mode).
- Constraint Jacobians via jax.jacrev (reverse mode, $O(m)$ passes; faster than jax.jacfwd's $O(n)$ passes when $m \ll n$).
- HVP fallback via jax.jvp(jax.grad(f), ...) (forward-over-reverse).
All of these require reverse-mode differentiation through the user's functions. JAX's reverse-mode AD does not support differentiating through jax.lax.while_loop or other variable-length control flow primitives. If your objective or constraint functions contain while_loop, scan with variable-length carries, or other non-reverse-differentiable operations, the AD fallback will fail.
How to handle this: supply your own derivative functions via the optional fields on SLSQP:
solver = SLSQP(
    obj_grad_fn=my_custom_grad,          # bypass jax.grad
    eq_jac_fn=my_custom_eq_jac,          # bypass jax.jacrev
    ineq_jac_fn=my_custom_ineq_jac,      # bypass jax.jacrev
    obj_hvp_fn=my_custom_obj_hvp,        # bypass forward-over-reverse
    eq_hvp_fn=my_custom_eq_hvp,          # bypass forward-over-reverse
    ineq_hvp_fn=my_custom_ineq_hvp,      # bypass forward-over-reverse
    # ... plus the constraint functions, constraint counts, and other options
)
When all derivative functions are supplied, the solver never calls jax.grad, jax.jacrev, or jax.jvp on your functions, so while_loop and other control flow work without issue. Alternatively, if only the HVP functions are problematic, you can supply obj_grad_fn and the Jacobian functions (which only require first-order derivatives) and let the solver fall back to L-BFGS mode (no HVPs needed) by omitting obj_hvp_fn.
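One way to obtain a gradient for an objective that contains lax.while_loop is forward mode, which JAX does support through while_loop. The sketch below is illustrative (the loop-based objective is a contrived example, and obj_grad_fn is simply built with jax.jacfwd); it costs $n$ forward passes, which may be acceptable for moderate $n$ but is no substitute for an analytic gradient at the upper end of the 5,000–50,000 range.

```python
import jax
import jax.numpy as jnp
from jax import lax


def objective(x, args):
    """sum_i x_i^2, accumulated with lax.while_loop (not reverse-differentiable)."""

    def cond(state):
        i, _ = state
        return i < x.shape[0]

    def body(state):
        i, acc = state
        return i + 1, acc + x[i] ** 2

    _, total = lax.while_loop(cond, body, (0, jnp.zeros((), dtype=x.dtype)))
    return total, None


def obj_grad_fn(x, args):
    # jax.jacfwd differentiates through while_loop in forward mode,
    # using one JVP per input dimension (batched with vmap).
    return jax.jacfwd(lambda z: objective(z, args)[0])(x)


x = jnp.array([1.0, -2.0, 3.0])
g_x = obj_grad_fn(x, None)  # equals 2 * x
```

The same pattern gives constraint Jacobians of shape (m, n), e.g. jax.jacfwd(lambda z: c(z, args))(x), which match the (x, args) -> J(x) signature expected by eq_jac_fn and ineq_jac_fn.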
Convergence safeguards
The solver includes safeguards against premature termination, a known issue in SciPy's SLSQP where the optimizer can terminate after a single iteration when the initial point exactly satisfies equality constraints:
- Minimum iterations (min_steps, default 1): The solver will not declare convergence before completing at least min_steps iterations. This ensures that KKT multipliers have been computed by at least one QP solve before the KKT optimality conditions are checked.
- Initial multiplier estimation: When equality constraints are present, the initial Lagrange multipliers are estimated via least squares rather than being set to zero. This prevents the Lagrangian gradient from collapsing to the objective gradient at the first convergence check.
solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    min_steps=1,  # Default; set to 0 to allow convergence at step 0
)
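A least-squares multiplier estimate can be sketched as follows, assuming the sign convention $L = f - \lambda^T c_{\text{eq}}$ used in the HVP formula above. initial_eq_multipliers is a hypothetical helper that shows the idea, not the solver's exact routine. Applied to the README example at x0 = [0.5, 0.5], it recovers $\lambda = 1$, so the Lagrangian gradient vanishes at x0, whereas zero multipliers would leave it equal to $\nabla f(x_0) = (1, 1)$.

```python
import jax
import jax.numpy as jnp


def initial_eq_multipliers(grad_f, jac_eq):
    """Least-squares lambda minimising ||grad_f - jac_eq^T lambda||_2."""
    lam, *_ = jnp.linalg.lstsq(jac_eq.T, grad_f)
    return lam


x0 = jnp.array([0.5, 0.5])
grad_f = jax.grad(lambda x: jnp.sum(x**2))(x0)                      # (1, 1)
jac_eq = jax.jacrev(lambda x: jnp.array([x[0] + x[1] - 1.0]))(x0)   # [[1, 1]]
lam0 = initial_eq_multipliers(grad_f, jac_eq)
grad_lagrangian = grad_f - jac_eq.T @ lam0                          # approx. (0, 0)
```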
QP anti-cycling: the EXPAND procedure
The QP subproblem is solved by a primal active-set method that adds or removes inequality constraints one at a time. When the problem is degenerate — multiple constraints pass through the same vertex or have tied violation/multiplier values — the active-set loop can cycle: iteration $i$ activates a constraint, iteration $i+1$ drops it, iteration $i+2$ re-activates it, and so on. The QP then exhausts its iteration budget without converging, producing a poor search direction.
This implementation uses the EXPAND procedure (Gill, Murray, Saunders & Wright, Mathematical Programming 45, 1989) to break such cycles. The idea is simple: instead of a fixed feasibility tolerance, the active-set loop maintains a working tolerance that grows monotonically:
$$ \delta_k = \texttt{tol} + k \cdot \tau, \qquad \tau = \frac{\texttt{tol} \cdot \texttt{expand\_factor}}{\texttt{max\_iter}} $$
At each active-set iteration $k$:
- A constraint is considered violated only if its residual is below $-\delta_k$ (progressively stricter threshold for activation).
- A multiplier is considered negative only if it is below $-\delta_k$ (progressively stricter threshold for deactivation).
Because $\delta_k$ increases at every step, marginally active or marginally infeasible constraints that cause cycling are gradually excluded, breaking the degeneracy. With the default settings (expand_factor=1.0, tol=1e-8, max_iter=100), the tolerance doubles from tol to 2·tol over the full iteration budget — conservative enough to preserve solution quality while reliably eliminating cycles.
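The tolerance schedule and the activation test can be sketched as follows. expand_tolerances and most_violated are hypothetical helpers, and the sketch covers only the activation side (the deactivation test on multipliers would compare against the same $\delta_k$). With the default values it reproduces the doubling from tol to 2·tol.

```python
import jax.numpy as jnp


def expand_tolerances(tol=1e-8, expand_factor=1.0, max_iter=100):
    """delta_k = tol + k * tau for k = 0..max_iter, with tau = tol * expand_factor / max_iter."""
    tau = tol * expand_factor / max_iter
    return tol + jnp.arange(max_iter + 1) * tau


def most_violated(residuals, delta):
    """Index of the most violated inequality, or -1 if no residual is below -delta."""
    idx = jnp.argmin(residuals)
    return jnp.where(residuals[idx] < -delta, idx, -1)


deltas = expand_tolerances()
residuals = jnp.array([0.3, -1.5e-8, 0.1])   # one marginally violated constraint
early = most_violated(residuals, deltas[0])   # delta = 1e-8: index 1 is activated
late = most_violated(residuals, deltas[-1])   # delta = 2e-8: nothing is activated
```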
EXPAND is the standard anti-cycling technique used in production solvers (MINOS, SNOPT, SQOPT) and is backed by a convergence guarantee: strict objective decrease within each expanding sequence. The expand_factor parameter on solve_qp controls the growth rate; set it to 0.0 to disable expansion entirely.
Outer-loop stagnation detection
Even with anti-cycling in the QP, the outer SLSQP loop can fail to make progress — for example, when the problem is infeasible, highly degenerate, or the QP solution is of poor quality. The solver detects this by tracking the L1 merit function across iterations:
$$ \phi(x; \rho) = f(x) + \rho \bigl(\lVert c_{\text{eq}}(x) \rVert_1 + \lVert \max(0, -c_{\text{ineq}}(x)) \rVert_1\bigr) $$
After each step, the relative improvement in the merit value is computed:
$$ \text{rel\_improvement} = \frac{|\phi_{k-1} - \phi_k|}{\max(|\phi_{k-1}|, 1)} $$
If this falls below stagnation_tol for stagnation_patience consecutive iterations, the solver terminates early with a nonlinear_divergence result code rather than running until max_steps. This avoids wasting computation on a problem the solver cannot solve.
solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    stagnation_tol=1e-12,    # Minimum relative merit improvement
    stagnation_patience=5,   # Consecutive stagnant steps before failure
)
The stagnation count and last step size are included in sol.stats for diagnostics:
sol = optx.minimise(objective, solver, x0, has_aux=True, max_steps=100, throw=False)
print(sol.stats["stagnation_count"]) # 0 if converged normally
print(sol.stats["last_step_size"]) # Step size from final iteration
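The stagnation bookkeeping reduces to a counter that resets whenever the merit improves enough. A pure-Python sketch with hypothetical helper names, assuming (as "consecutive" implies) that any sufficiently large improvement resets the count to zero:

```python
def update_stagnation_count(phi_prev, phi_new, count, stagnation_tol=1e-12):
    """Increment count if the relative merit improvement is below stagnation_tol, else reset."""
    rel_improvement = abs(phi_prev - phi_new) / max(abs(phi_prev), 1.0)
    return count + 1 if rel_improvement < stagnation_tol else 0


def first_stagnant_step(merits, stagnation_tol=1e-12, stagnation_patience=5):
    """Return the iteration at which the solver would stop for stagnation, or None."""
    count = 0
    for k in range(1, len(merits)):
        count = update_stagnation_count(merits[k - 1], merits[k], count, stagnation_tol)
        if count >= stagnation_patience:
            return k
    return None


# Merit stalls at 4.0 from iteration 3 onward.
stop = first_stagnant_step([10.0, 5.0, 4.0] + [4.0] * 6)  # 7 (fifth consecutive stagnant step)
```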
License
MIT
File details
Details for the file slsqp_jax-0.3.0.tar.gz.
File metadata
- Download URL: slsqp_jax-0.3.0.tar.gz
- Size: 633.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b813cd19eb3eeaf12cf95e03e917d80eac73b0961ae7e2341352fd6c8458225b |
| MD5 | e352b20039ea5b37509a6fa253ba651d |
| BLAKE2b-256 | 4982d3778af4c321d0c45fb2b3fe1e62c7cdda0df851bf2b4a440699189eb741 |
Provenance
The following attestation bundles were made for slsqp_jax-0.3.0.tar.gz:
- Publisher: release.yml on lucianopaz/slsqp-jax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: slsqp_jax-0.3.0.tar.gz
- Subject digest: b813cd19eb3eeaf12cf95e03e917d80eac73b0961ae7e2341352fd6c8458225b
- Sigstore transparency entry: 995337629
- Permalink: lucianopaz/slsqp-jax@80bab92afd073ca2e4005688458759adfcb1d5cb
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/lucianopaz
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@80bab92afd073ca2e4005688458759adfcb1d5cb
- Trigger Event: release
Details for the file slsqp_jax-0.3.0-py3-none-any.whl.
File metadata
- Download URL: slsqp_jax-0.3.0-py3-none-any.whl
- Size: 38.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 43674a79fb454db06645903cdb54fcdc3e2d6d7698a1f56d86259fc069720be7 |
| MD5 | 6aa09d260d637751f023efeb84381b4b |
| BLAKE2b-256 | 55ce4ada41af603eff84c46253869c6a693652f66d79c780dee2916fae8ffe90 |
Provenance
The following attestation bundles were made for slsqp_jax-0.3.0-py3-none-any.whl:
- Publisher: release.yml on lucianopaz/slsqp-jax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: slsqp_jax-0.3.0-py3-none-any.whl
- Subject digest: 43674a79fb454db06645903cdb54fcdc3e2d6d7698a1f56d86259fc069720be7
- Sigstore transparency entry: 995337659
- Permalink: lucianopaz/slsqp-jax@80bab92afd073ca2e4005688458759adfcb1d5cb
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/lucianopaz
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@80bab92afd073ca2e4005688458759adfcb1d5cb
- Trigger Event: release