Sequential Least Squares Programming (SLSQP) optimizer implemented in pure JAX

slsqp-jax

A pure-JAX implementation of the SLSQP (Sequential Least Squares Programming) algorithm for constrained nonlinear optimization, designed for moderate to large decision spaces (5,000–50,000 variables). All linear algebra is performed through JAX, so the solver runs natively on CPU, GPU, and TPU, and is fully compatible with jax.jit, jax.vmap, and jax.grad.

The SLSQP solver is built on top of Optimistix, a JAX library for nonlinear solvers. It implements the optimistix.AbstractMinimiser interface, so you run it through the standard optimistix.minimise entry point — no manual iteration loop required.

Installation

You can install the package from PyPI using any standard method:

pip

pip install slsqp-jax

uv

uv add "slsqp-jax

pixi

pixi add --pypi slsqp-jax

Usage

Basic: objective and constraints

Define an objective function with signature (x, args) -> (scalar, aux) and optional constraint functions with signature (x, args) -> array. Equality constraints must satisfy c_eq(x) = 0 and inequality constraints must satisfy c_ineq(x) >= 0.

Then create an SLSQP solver and pass it to optimistix.minimise:

import jax.numpy as jnp
import optimistix as optx
from slsqp_jax import SLSQP

# Objective: minimize x^2 + y^2
def objective(x, args):
    return jnp.sum(x**2), None

# Equality constraint: x + y = 1
def eq_constraint(x, args):
    return jnp.array([x[0] + x[1] - 1.0])

# Inequality constraint: x >= 0.2
def ineq_constraint(x, args):
    return jnp.array([x[0] - 0.2])

solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    ineq_constraint_fn=ineq_constraint,
    n_ineq_constraints=1,
    rtol=1e-8,
    atol=1e-8,
)

x0 = jnp.array([0.5, 0.5])
sol = optx.minimise(objective, solver, x0, has_aux=True, max_steps=100)

print(sol.value)   # [0.2, 0.8]
print(sol.result)  # RESULTS.successful

The returned optimistix.Solution object contains:

  • sol.value — the optimal point.
  • sol.result — a status code (RESULTS.successful or an error).
  • sol.aux — any auxiliary data returned by the objective.
  • sol.stats — solver statistics (e.g. number of steps taken).
  • sol.state — the final internal solver state.

Since SLSQP is a standard Optimistix minimiser, it composes with all Optimistix features: throw=False for non-raising error handling, custom adjoint methods for differentiating through the solve, and so on. See the Optimistix documentation for details.
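
For example, passing throw=False (a standard optimistix.minimise argument) returns a status code instead of raising on failure:

sol = optx.minimise(
    objective, solver, x0, has_aux=True, max_steps=100, throw=False
)
if sol.result != optx.RESULTS.successful:
    print("Solve failed with status:", sol.result)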

Supplying gradients and Jacobians

By default the solver computes the objective gradient via jax.grad and constraint Jacobians via jax.jacrev. You can supply your own functions to avoid redundant computation or to handle functions that are not reverse-mode differentiable:
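
For the running example, hand-written derivatives are simple: the gradient of $\sum_i x_i^2$ is $2x$, and the Jacobian of the single linear equality constraint is a row of ones. The names my_grad_fn and my_eq_jac_fn below are the ones referenced in the snippet that follows:

# Gradient of sum(x**2): (x, args) -> 2x
def my_grad_fn(x, args):
    return 2.0 * x

# Jacobian of the single linear constraint x[0] + x[1] - 1:
# shape (m, n) = (1, 2), a constant row of ones.
def my_eq_jac_fn(x, args):
    return jnp.ones((1, x.shape[0]))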

solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    # User-supplied gradient: (x, args) -> grad_f(x)
    obj_grad_fn=my_grad_fn,
    # User-supplied Jacobian: (x, args) -> J(x)  shape (m, n)
    eq_jac_fn=my_eq_jac_fn,
)
sol = optx.minimise(objective, solver, x0, has_aux=True, max_steps=100)

Supplying Hessian-vector products

For problems where you have access to exact second-order information, you can supply Hessian-vector product (HVP) functions. The solver uses these to produce high-quality secant pairs for the L-BFGS Hessian approximation, which typically improves convergence compared to using gradient differences alone:

# Objective HVP: (x, v, args) -> H_f(x) @ v
def obj_hvp(x, v, args):
    return 2.0 * v  # Hessian of sum(x^2) is 2*I

# Per-constraint HVP: (x, v, args) -> array of shape (m, n)
# Row i is H_{c_i}(x) @ v
def eq_hvp(x, v, args):
    return jnp.zeros((1, x.shape[0]))  # Linear constraint has zero Hessian

solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    obj_hvp_fn=obj_hvp,
    eq_hvp_fn=eq_hvp,  # Optional — AD fallback is used if omitted
)
sol = optx.minimise(objective, solver, x0, has_aux=True, max_steps=100)

Note that you supply HVPs for the objective and constraint functions separately, not for the Lagrangian. The solver composes the Lagrangian HVP internally using the current KKT multipliers:

$$ \nabla^2 L(x) v = \nabla^2 f(x) v - \sum_i \lambda_i^{\text{eq}} \nabla^2 c_i^{\text{eq}}(x) v - \sum_j \mu_j^{\text{ineq}} \nabla^2 c_j^{\text{ineq}}(x) v $$

If you provide obj_hvp_fn but omit the constraint HVP functions, the solver automatically computes the missing constraint HVPs via forward-over-reverse AD on the scalar function $\lambda^T c(x)$, which costs one reverse pass plus one forward pass regardless of the number of constraints.
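
In effect, the fallback produces the combined weighted term $\sum_i \lambda_i \nabla^2 c_i(x)\, v$ directly. A minimal sketch of the idea, reusing eq_constraint from the basic example (lam stands for the current multipliers; the internal routine may differ in details):

import jax

def eq_hvp_fallback(x, v, lam, args):
    # One reverse pass builds the gradient of the scalar lam^T c(x);
    # one forward pass (JVP) pushes v through it.
    weighted = lambda z: jnp.dot(lam, eq_constraint(z, args))
    _, hvp = jax.jvp(jax.grad(weighted), (x,), (v,))
    return hvp  # = sum_i lam_i * H_{c_i}(x) @ v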

Frozen Hessian in the QP subproblem: The QP inner loop always uses a frozen L-BFGS approximation to the Lagrangian Hessian, even when exact HVPs are available. The exact HVP is called only once per main iteration (to probe along the step direction and produce an exact secant pair for the L-BFGS update). This design ensures (1) the QP subproblem sees a truly constant quadratic model, and (2) expensive HVP evaluations are not repeated thousands of times inside the projected CG solver.

Algorithm

Overview

Each SLSQP iteration performs four steps:

  1. QP subproblem: Construct a quadratic approximation of the objective using the frozen L-BFGS Hessian and linearise the constraints around the current point. Solve the resulting QP to obtain a search direction.
  2. Line search: Use a Han-Powell L1 merit function $\phi(x;\rho) = f(x) + \rho (\lVert c_{\text{eq}}\rVert_1 + \lVert\max(0, -c_{\text{ineq}})\rVert_1)$ with a backtracking Armijo condition to determine the step size (sketched in code after this list).
  3. Accept step: Update the iterate $x_{k+1} = x_k + \alpha d_k$.
  4. Hessian update: Append the new curvature pair $(s, y)$ to the L-BFGS history, where $y$ is either an exact HVP probe $\nabla^2 L(x_k) s$ (if HVP functions are provided) or the gradient difference $\nabla L(x_{k+1}) - \nabla L(x_k)$.
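
The following sketch of steps 2 and 3 reuses objective, eq_constraint, and ineq_constraint from the basic example; the penalty rho, the Armijo constant c1, and the 0.5 shrink factor are illustrative choices rather than the solver's exact parameters:

def merit(x, rho, args):
    # Han-Powell L1 merit: objective plus weighted constraint violation.
    f, _ = objective(x, args)
    viol = jnp.sum(jnp.abs(eq_constraint(x, args)))
    viol += jnp.sum(jnp.maximum(0.0, -ineq_constraint(x, args)))
    return f + rho * viol

def line_search(x, d, rho, dphi0, args, c1=1e-4, max_backtracks=20):
    # dphi0 is the directional derivative of the merit function along d.
    alpha, phi0 = 1.0, merit(x, rho, args)
    for _ in range(max_backtracks):
        if merit(x + alpha * d, rho, args) <= phi0 + c1 * alpha * dphi0:
            break
        alpha *= 0.5
    return x + alpha * d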

Scaling considerations: why L-BFGS over BFGS

Classical SLSQP (e.g. SciPy's implementation) maintains a dense $n \times n$ BFGS approximation to the Hessian of the Lagrangian. This requires $O(n^2)$ memory and $O(n^2)$ work per iteration for the matrix update alone. For the target problem sizes:

| $n$    | Dense Hessian memory | L-BFGS memory ($k=10$) |
|--------|----------------------|------------------------|
| 1,000  | 8 MB                 | 160 KB                 |
| 10,000 | 800 MB               | 1.6 MB                 |
| 50,000 | 20 GB                | 8 MB                   |

L-BFGS stores only the last $k$ step/gradient-difference pairs $(s_i, y_i)$ and computes Hessian-vector products in $O(kn)$ time using the compact representation (Byrd, Nocedal & Schnabel, 1994):

$$ B_k = \gamma I - W N^{-1} W^T $$

where $W = (\gamma S, Y)$ is the horizontal concatenation of the matrices $\gamma S$ and $Y$, and $N$ is a small $2k \times 2k$ matrix built from inner products of the stored vectors. The $2k \times 2k$ system is solved directly, at negligible cost for $k \ll n$.
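
As an illustration, a product $B_k v$ with this compact form takes only a few lines. This is a sketch of the formula above following Byrd, Nocedal & Schnabel (1994); the solver's internal routine additionally handles damping and history management:

import jax.numpy as jnp

def lbfgs_matvec(S, Y, gamma, v):
    # Computes B_k @ v from the compact representation. S and Y are
    # (n, k) matrices whose columns are the stored pairs; gamma is the
    # initial scaling; v is an (n,) vector.
    SY = S.T @ Y                      # (k, k) inner products s_i^T y_j
    D = jnp.diag(jnp.diag(SY))        # diagonal part of S^T Y
    L = jnp.tril(SY, k=-1)            # strictly lower-triangular part
    N = jnp.block([[gamma * (S.T @ S), L],
                   [L.T, -D]])        # small (2k, 2k) middle matrix
    W = jnp.concatenate([gamma * S, Y], axis=1)          # (n, 2k)
    return gamma * v - W @ jnp.linalg.solve(N, W.T @ v)  # O(kn) + O(k^3)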

Powell's damping is applied to each curvature pair before storage to ensure positive definiteness, which is essential for constrained problems where the standard curvature condition $s^T y > 0$ can fail.
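
A sketch of Powell's damping with the classical threshold $\delta = 0.2$ (the solver's exact constant is an assumption here):

def powell_damp(s, y, Bs, delta=0.2):
    # Bs is the current approximation B applied to s. The damped pair
    # satisfies s^T y_damped >= delta * s^T B s > 0.
    sBs = jnp.dot(s, Bs)
    sy = jnp.dot(s, y)
    theta = jnp.where(sy >= delta * sBs,
                      1.0,
                      (1.0 - delta) * sBs / (sBs - sy))
    return theta * y + (1.0 - theta) * Bs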

Scaling considerations: why projected CG over dense KKT

Classical SLSQP solves the QP subproblem by forming and factorising the $(n + m) \times (n + m)$ dense KKT system at $O(n^3)$ cost. This implementation instead uses projected conjugate gradient (CG) inside an active-set loop:

  1. Projection: For active constraints with matrix $A$ ($m_{\text{active}} \times n$), define the null-space projector $P(v) = v - A^T (A A^T)^{-1} A v$ (sketched after this list). The $A A^T$ system is only $m_{\text{active}} \times m_{\text{active}}$ (tiny, since $m \ll n$) and is solved directly.
  2. CG in null space: Run conjugate gradient on the projected system, where each iteration requires one HVP ($O(kn)$ for L-BFGS) and one projection ($O(mn)$).
  3. Active-set outer loop: Add the most violated inequality or drop the most negative multiplier until KKT conditions are satisfied.
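
A minimal sketch of the projector from step 1 (a real implementation would typically factorise $A A^T$ once per active set rather than once per product):

def null_space_project(A, v):
    # P v = v - A^T (A A^T)^{-1} A v. The Gram matrix A A^T is only
    # (m_active, m_active), so the direct solve is cheap when m << n.
    return v - A.T @ jnp.linalg.solve(A @ A.T, A @ v)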

Total cost per QP solve: $O(n \cdot k \cdot t)$, where $t$ is the number of CG iterations (typically $t \ll n$). Compared to the $O(n^3)$ dense approach, this is orders of magnitude faster for $n > 1000$.

Reverse-mode AD and while_loop

By default, the solver computes:

  • Objective gradient via jax.grad (reverse-mode).
  • Constraint Jacobians via jax.jacrev (reverse-mode, $O(m)$ passes, faster than jax.jacfwd's $O(n)$ passes when $m \ll n$).
  • HVP fallback via jax.jvp(jax.grad(f), ...) (forward-over-reverse).

All of these require reverse-mode differentiation through the user's functions. JAX's reverse-mode AD does not support differentiating through jax.lax.while_loop or other variable-length control flow primitives. If your objective or constraint functions contain while_loop, scan with variable-length carries, or other non-reverse-differentiable operations, the AD fallback will fail.
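
For example, reverse-mode AD fails on a function whose iteration count depends on the data (a hypothetical illustration):

from jax import lax

def fixed_point_objective(x, args):
    # Babylonian square-root iteration with a data-dependent stop.
    def cond(carry):
        z, z_prev = carry
        return jnp.max(jnp.abs(z - z_prev)) > 1e-10

    def body(carry):
        z, _ = carry
        return 0.5 * (z + x / z), z

    z, _ = lax.while_loop(cond, body, (x, x + 1.0))
    return jnp.sum(z), None

# jax.grad(lambda x: fixed_point_objective(x, None)[0])(jnp.ones(3))
# raises an error: reverse-mode differentiation does not support
# lax.while_loop with a data-dependent trip count.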

How to handle this: supply your own derivative functions via the optional fields on SLSQP:

solver = SLSQP(
    obj_grad_fn=my_custom_grad,        # bypass jax.grad
    eq_jac_fn=my_custom_eq_jac,        # bypass jax.jacrev
    ineq_jac_fn=my_custom_ineq_jac,    # bypass jax.jacrev
    obj_hvp_fn=my_custom_obj_hvp,      # bypass forward-over-reverse
    eq_hvp_fn=my_custom_eq_hvp,        # bypass forward-over-reverse
    ineq_hvp_fn=my_custom_ineq_hvp,    # bypass forward-over-reverse
    ...
)

When all derivative functions are supplied, the solver never calls jax.grad, jax.jacrev, or jax.jvp on your functions, so while_loop and other control flow work without issue. Alternatively, if only the HVP functions are problematic, supply obj_grad_fn and the Jacobian functions (which require only first-order derivatives) and omit obj_hvp_fn; the solver then builds its L-BFGS curvature pairs from gradient differences, with no HVPs needed.

Convergence safeguards

The solver includes safeguards against premature termination, a known issue in SciPy's SLSQP where the optimizer can terminate after a single iteration when the initial point exactly satisfies equality constraints:

  • Minimum iterations (min_steps, default 1): The solver will not declare convergence before completing at least min_steps iterations. This ensures that KKT multipliers have been computed by at least one QP solve before checking KKT optimality conditions.

  • Initial multiplier estimation: When equality constraints are present, the initial Lagrange multipliers are estimated via least-squares rather than being set to zero. This prevents the Lagrangian gradient from collapsing to the objective gradient at the first convergence check.

solver = SLSQP(
    eq_constraint_fn=eq_constraint,
    n_eq_constraints=1,
    min_steps=1,  # Default; set to 0 to allow convergence at step 0
)
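
For intuition, the least-squares estimate in the second bullet solves $\min_\lambda \lVert \nabla f(x_0) - A(x_0)^T \lambda \rVert_2$, consistent with the Lagrangian sign convention above. A minimal sketch (the internal routine may differ in details):

def initial_multipliers(grad_f, A):
    # grad_f: objective gradient at x0, shape (n,).
    # A: equality-constraint Jacobian at x0, shape (m, n).
    lam, *_ = jnp.linalg.lstsq(A.T, grad_f)
    return lam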

License

MIT
