
qpax

Differentiable QP solver in JAX.

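Paper: On the Differentiability of the Primal-Dual Interior-Point Method (Kevin Tracy and Zachary Manchester, 2024, arXiv:2406.11749)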

This package can be used for solving convex quadratic programs of the following form:

$$
\begin{align*}
\underset{x}{\text{minimize}} & \quad \frac{1}{2}x^TQx + q^Tx \\
\text{subject to} & \quad Ax = b, \\
& \quad Gx \leq h
\end{align*}
$$

where $Q \succeq 0$. The solver can be combined with JAX's jit and vmap transformations, and it can be differentiated with reverse-mode grad.
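Constraints written in other shapes can often be rewritten into this form. For example, bounds lb ≤ x ≤ ub become Gx ≤ h with G = [I; -I] and h = [ub; -lb]. The NumPy sketch below builds these matrices; box_to_inequality is a hypothetical helper for illustration, not part of qpax.

```python
import numpy as np

def box_to_inequality(lb, ub):
    """Hypothetical helper: rewrite lb <= x <= ub as G x <= h."""
    n = lb.shape[0]
    # first n rows encode x <= ub, last n rows encode -x <= -lb
    G = np.vstack([np.eye(n), -np.eye(n)])
    h = np.concatenate([ub, -lb])
    return G, h

lb = np.array([-1.0, 0.0, 2.0])
ub = np.array([1.0, 5.0, 3.0])
G, h = box_to_inequality(lb, ub)

# a point inside the box satisfies every row of G x <= h; a point outside does not
x_inside = np.array([0.0, 1.0, 2.5])
x_outside = np.array([0.0, 6.0, 2.5])
print(np.all(G @ x_inside <= h), np.all(G @ x_outside <= h))
```

Inequalities of the form Gx ≥ h fit the same template after negating both G and h.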

The QP is solved with the primal-dual interior-point algorithm detailed in cvxgen, and the linear systems inside each iteration are solved with the reduction techniques from cvxopt. At an approximate primal-dual solution, the primal variable $x$ is differentiated with respect to the problem parameters using the implicit function theorem, as shown in OptNet and its PyTorch-based QP solver, qpth.
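To illustrate the implicit-function-theorem idea, here is a minimal NumPy sketch for the special case of an equality-constrained QP, where the optimality conditions reduce to one linear KKT system K [x; y] = [-q; b] with K = [[Q, Aᵀ], [A, 0]]. Differentiating that system shows that a single extra solve with K (the adjoint solve) gives the gradient of a loss with respect to both q and b. This is not qpax's implementation, which works with the full interior-point conditions including Gx ≤ h; solve_eq_qp and loss_and_grads are hypothetical names.

```python
import numpy as np

def solve_eq_qp(Q, q, A, b):
    """Solve min 0.5 x^T Q x + q^T x  s.t.  A x = b  through its KKT system.

    Returns the primal solution x and the symmetric KKT matrix K.
    """
    n, m = Q.shape[0], A.shape[0]
    K = np.block([[Q, A.T], [A, np.zeros((m, m))]])
    sol = np.linalg.solve(K, np.concatenate([-q, b]))
    return sol[:n], K

def loss_and_grads(Q, q, A, b, x_bar):
    """Loss ||x*(q, b) - x_bar||^2 and its gradients with respect to q and b.

    Differentiating K [x; y] = [-q; b] gives K [dx; dy] = [-dq; db], so one
    adjoint solve with K (symmetric, so K^T = K) yields both gradients.
    """
    n = Q.shape[0]
    x, K = solve_eq_qp(Q, q, A, b)
    dl_dx = 2.0 * (x - x_bar)
    lam = np.linalg.solve(K, np.concatenate([dl_dx, np.zeros(A.shape[0])]))
    dl_dq = -lam[:n]  # q enters the right-hand side with a minus sign
    dl_db = lam[n:]
    return float(np.dot(x - x_bar, x - x_bar)), dl_dq, dl_db

# random strictly convex QP with two equality constraints
rng = np.random.default_rng(0)
n, m = 4, 2
M = rng.standard_normal((n, n))
Q = M @ M.T + np.eye(n)
q = rng.standard_normal(n)
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)
x_bar = np.ones(n)

loss, dl_dq, dl_db = loss_and_grads(Q, q, A, b, x_bar)
print(loss, dl_dq, dl_db)
```

The gradients can be checked against central finite differences, which is a cheap sanity test for any hand-derived adjoint.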

Installation

To install from PyPI using pip:

$ pip install qpax

Alternatively, to install from source in editable mode, run the following from the root of a local clone of the repository:

$ pip install -e .

Usage

🚨 Float32 Warning 🚨

The solver tolerance (solver_tol) should be chosen with the available floating-point precision in mind. With 32-bit precision (the default in JAX), solver_tol should be at least 1e-5.

| Precision | Recommended tolerance |
| --- | --- |
| jnp.float32 | solver_tol $\in$ [1e-5, 1e-2] |
| jnp.float64 | solver_tol $\in$ [1e-12, 1e-2] |
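A quick NumPy experiment, independent of qpax, shows why the float32 range stops so much higher: even the exact solution of a well-conditioned linear system, once rounded to float32, leaves a relative residual many orders of magnitude larger than the float64 solution does. The exact numbers depend on the random matrix, so treat them as illustrative.

```python
import numpy as np

# symmetric positive-definite test system H x = r, well conditioned by construction
rng = np.random.default_rng(0)
n = 50
B = rng.standard_normal((n, n))
H = B @ B.T + n * np.eye(n)
r = rng.standard_normal(n)

def relative_residual(x):
    """||H x - r|| / ||r||, always evaluated in float64."""
    x = np.asarray(x, dtype=np.float64)
    return np.linalg.norm(H @ x - r) / np.linalg.norm(r)

x64 = np.linalg.solve(H, r)   # float64 solution
x32 = x64.astype(np.float32)  # the same solution rounded to float32

res64 = relative_residual(x64)
res32 = relative_residual(x32)
print(f"machine eps: float32 {np.finfo(np.float32).eps:.1e}, float64 {np.finfo(np.float64).eps:.1e}")
print(f"relative residual: float32 {res32:.1e}, float64 {res64:.1e}")
```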

To enable 64-bit precision, run the following at startup:

# note: this only works on startup!
import jax
jax.config.update("jax_enable_x64", True)

This snippet is taken from JAX - The Sharp Bits.

Solving a QP

qpax solves QPs in a way that works with JAX's jit and vmap:

import qpax

# solve QP (this can be combined with jit or vmap)
x, s, z, y, converged, iters = qpax.solve_qp(Q, q, A, b, G, h, solver_tol=1e-6)
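One way to sanity-check a returned solution is to evaluate the optimality (KKT) residuals. The NumPy sketch below assumes the cvxgen-style convention that s is the slack in Gx + s = h, z holds the inequality duals, and y holds the equality duals; that reading of the outputs, and the helper name kkt_residuals, are assumptions rather than documented qpax API. The example uses a two-variable QP whose primal-dual solution was derived by hand.

```python
import numpy as np

def kkt_residuals(Q, q, A, b, G, h, x, s, z, y):
    """Largest KKT residual of each type for
    min 0.5 x^T Q x + q^T x  s.t.  A x = b,  G x + s = h,  s >= 0,  z >= 0.

    Assumes z are the inequality duals and y the equality duals (cvxgen-style).
    """
    stationarity = Q @ x + q + A.T @ y + G.T @ z
    primal_eq = A @ x - b
    primal_ineq = G @ x + s - h
    complementarity = s * z
    return {
        "stationarity": np.max(np.abs(stationarity), initial=0.0),
        "equality": np.max(np.abs(primal_eq), initial=0.0),
        "inequality": np.max(np.abs(primal_ineq), initial=0.0),
        "complementarity": np.max(np.abs(complementarity), initial=0.0),
        # should be >= 0: slacks and inequality duals must be nonnegative
        "min_s_or_z": min(np.min(s, initial=np.inf), np.min(z, initial=np.inf)),
    }

# min 0.5||x||^2 - x1  s.t.  x1 + x2 = 0,  x1 <= 0.25
Q = np.eye(2)
q = np.array([-1.0, 0.0])
A = np.array([[1.0, 1.0]])
b = np.array([0.0])
G = np.array([[1.0, 0.0]])
h = np.array([0.25])

# hand-derived solution: the bound x1 <= 0.25 is active
x = np.array([0.25, -0.25])
s = h - G @ x
z = np.array([0.5])
y = np.array([0.25])

res = kkt_residuals(Q, q, A, b, G, h, x, s, z, y)
print(res)
```

To use it on qpax's outputs, one would pass np.asarray(x) and so on; residuals on the order of solver_tol are a reasonable expectation, although qpax's own stopping criterion may be defined differently.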

Linear System Solver

By default, qpax uses a Cholesky factorization to solve the internal linear systems. You can switch to QR factorization instead, which can be more numerically stable for ill-conditioned problems:

import qpax

# use QR factorization for the internal linear solves
x, s, z, y, converged, iters = qpax.solve_qp(
    Q, q, A, b, G, h,
    linear_solver=qpax.LinearSolver.QR,
)

Available options are qpax.LinearSolver.CHOLESKY (default) and qpax.LinearSolver.QR.
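The two options differ in how each linear system is factorized. As a generic NumPy illustration (qpax applies its factorizations to the reduced systems from cvxopt, which are not reproduced here), both functions below solve the same system H x = r. Cholesky requires H to be numerically positive definite, while QR only requires H to be nonsingular, which is one reason QR can be the more robust choice when conditioning is poor. The names solve_cholesky and solve_qr are hypothetical.

```python
import numpy as np

def solve_cholesky(H, r):
    """Solve H x = r via H = L L^T; requires H to be numerically positive definite."""
    L = np.linalg.cholesky(H)       # raises LinAlgError if H is not positive definite
    w = np.linalg.solve(L, r)       # L w = r   (L is lower triangular)
    return np.linalg.solve(L.T, w)  # L^T x = w (L^T is upper triangular)

def solve_qr(H, r):
    """Solve H x = r via H = Q R with Q orthogonal and R upper triangular."""
    Qf, R = np.linalg.qr(H)
    return np.linalg.solve(R, Qf.T @ r)  # R x = Q^T r

# well-conditioned SPD system: both factorizations agree
rng = np.random.default_rng(0)
B = rng.standard_normal((6, 6))
H = B @ B.T + np.eye(6)
r = rng.standard_normal(6)
x_chol = solve_cholesky(H, r)
x_qr = solve_qr(H, r)

# nearly singular symmetric matrix whose tiny eigenvalue is slightly negative
H_bad = np.array([[1.0, 1.0], [1.0, 1.0 - 1e-12]])
r_bad = np.array([1.0, 0.0])
try:
    solve_cholesky(H_bad, r_bad)
    cholesky_failed = False
except np.linalg.LinAlgError:
    cholesky_failed = True
x_bad_qr = solve_qr(H_bad, r_bad)  # QR still returns a (very large) finite solution
print(np.allclose(x_chol, x_qr), cholesky_failed, x_bad_qr)
```

Cholesky is cheaper, which is a sensible reason for it to be the default; a production code would also use a dedicated triangular solver (for example scipy.linalg.solve_triangular) instead of the general np.linalg.solve used here for brevity.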

Solving a batch of QPs

Here we solve a batch of non-negative least squares problems as QPs. This example highlights two features of qpax: solving QPs without any equality constraints (by passing an A with zero rows and an empty b), and using vmap to solve a batch of QPs.

import numpy as np
import jax.numpy as jnp
from jax import jit, vmap
import qpax

"""
solve batched non-negative least squares (nnls) problems
 
min_x    |Fx - g|^2 
st        x >= 0 
"""

n = 5   # size of x 
m = 10  # rows in F 

# create data for N_qps random nnls problems  
N_qps = 10000 
Fs = jnp.array(np.random.randn(N_qps, m, n))
gs = jnp.array(np.random.randn(N_qps, m))

@jit
def form_qp(F, g):
  # convert the least squares to qp form 
  n = F.shape[1]
  Q = F.T @ F 
  q = -F.T @ g 
  G = -jnp.eye(n)
  h = jnp.zeros(n)
  A = jnp.zeros((0, n))
  b = jnp.zeros(0)
  return Q, q, A, b, G, h

# create the QPs in a batched fashion 
Qs, qs, As, bs, Gs, hs = vmap(form_qp, in_axes = (0, 0))(Fs, gs)

# create function for solving a batch of QPs 
batch_qp = jit(vmap(qpax.solve_qp_primal, in_axes = (0, 0, 0, 0, 0, 0)))

xs = batch_qp(Qs, qs, As, bs, Gs, hs)
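form_qp relies on the identity ½xᵀ(FᵀF)x − (Fᵀg)ᵀx = ½‖Fx − g‖² − ½‖g‖², so the QP objective is half the squared residual shifted by a constant, and both problems share the same minimizer. The NumPy sketch below checks that identity and adds a simple projected-gradient NNLS solver as an independent reference; nnls_objective_identity and nnls_projected_gradient are hypothetical helpers, not part of qpax.

```python
import numpy as np

def nnls_objective_identity(F, g, x):
    """Return (QP objective as built by form_qp, shifted least-squares objective)."""
    Q = F.T @ F
    q = -F.T @ g
    qp_obj = 0.5 * x @ Q @ x + q @ x
    ls_obj = 0.5 * np.sum((F @ x - g) ** 2) - 0.5 * np.sum(g ** 2)
    return qp_obj, ls_obj

def nnls_projected_gradient(F, g, iters=20000):
    """Reference NNLS solver: projected gradient descent with step size 1/L."""
    L = np.linalg.norm(F, 2) ** 2  # Lipschitz constant of the gradient of 0.5||Fx - g||^2
    x = np.zeros(F.shape[1])
    for _ in range(iters):
        grad = F.T @ (F @ x - g)
        x = np.maximum(x - grad / L, 0.0)  # gradient step, then project onto x >= 0
    return x

rng = np.random.default_rng(0)
F = rng.standard_normal((10, 5))
g = rng.standard_normal(10)
x_ref = nnls_projected_gradient(F, g)
qp_obj, ls_obj = nnls_objective_identity(F, g, x_ref)
print(x_ref, qp_obj, ls_obj)
```

To cross-check qpax, one could compare nnls_projected_gradient(np.asarray(Fs[0]), np.asarray(gs[0])) with xs[0]; agreement should only be expected up to the solver tolerance and float32 precision.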

Differentiating a QP

If only the primal variable x is needed, we can use solve_qp_primal, which supports automatic differentiation:

import jax 
import jax.numpy as jnp 
import qpax 

def loss(Q, q, A, b, G, h):
    x = qpax.solve_qp_primal(Q, q, A, b, G, h, solver_tol=1e-4, target_kappa=1e-3) 
    x_bar = jnp.ones(len(q))
    return jnp.dot(x - x_bar, x - x_bar)
  
# gradient of loss function   
loss_grad = jax.grad(loss, argnums = (0, 1, 2, 3, 4, 5))

# compatible with jit 
loss_grad_jit = jax.jit(loss_grad)

# calculate derivatives 
derivs = loss_grad_jit(Q, q, A, b, G, h)
dl_dQ, dl_dq, dl_dA, dl_db, dl_dG, dl_dh = derivs 

Here, target_kappa determines how much smoothing is applied to the gradients computed through solve_qp_primal. For more detail on target_kappa, please refer to the paper.
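For intuition about why smoothing helps, consider the one-dimensional QP minimize ½x² − px subject to −x ≤ 0. Its exact solution is x* = max(p, 0), whose derivative jumps from 0 to 1 at p = 0. As a simplifying assumption, replace the complementarity condition s·z = 0 with s·z = κ (a relaxation in the spirit of target_kappa; qpax's exact mechanism may differ, see the paper). With s = x and z = x − p from the optimality conditions, this gives x(x − p) = κ, so x = (p + √(p² + 4κ))/2 and dx/dp = (1 + p/√(p² + 4κ))/2, a smooth approximation of the step. The NumPy sketch below evaluates both formulas; the function names are hypothetical.

```python
import numpy as np

def relaxed_solution(p, kappa):
    """Positive root of x * (x - p) = kappa: the relaxed solution of
    min 0.5 x^2 - p x  s.t.  -x <= 0  when s * z = kappa."""
    return 0.5 * (p + np.sqrt(p ** 2 + 4.0 * kappa))

def relaxed_derivative(p, kappa):
    """dx/dp of relaxed_solution: a smooth approximation of the 0/1 step."""
    return 0.5 * (1.0 + p / np.sqrt(p ** 2 + 4.0 * kappa))

ps = np.linspace(-1.0, 1.0, 5)
exact = np.maximum(ps, 0.0)  # exact solution x* = max(p, 0)
for kappa in (1e-1, 1e-3, 1e-6):
    # as kappa shrinks, the solution approaches max(p, 0) and the derivative sharpens
    print(kappa, relaxed_solution(ps, kappa) - exact, relaxed_derivative(ps, kappa))
```

Larger κ gives smoother but more biased derivatives, while κ → 0 recovers the exact, discontinuous ones.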

Citation

Paper

@misc{tracy2024differentiability,
    title={On the Differentiability of the Primal-Dual Interior-Point Method},
    author={Kevin Tracy and Zachary Manchester},
    year={2024},
    eprint={2406.11749},
    archivePrefix={arXiv},
    primaryClass={math.OC}
}



Download files


Source Distribution

qpax-0.0.11.tar.gz (71.7 kB)

Uploaded Source

Built Distribution


qpax-0.0.11-py3-none-any.whl (11.9 kB)

Uploaded Python 3

File details

Details for the file qpax-0.0.11.tar.gz.

File metadata

  • Download URL: qpax-0.0.11.tar.gz
  • Upload date:
  • Size: 71.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for qpax-0.0.11.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | d2a3f1955f8b89f2af4deb994ce919ce2314eba39fcb172c4d2241117b42b094 |
| MD5 | bd3167e22bbea62ed7e8898ac4a4cb10 |
| BLAKE2b-256 | 75481199aefd2fe8ebca89a20a1ce4cca487affc8caeba986147f07da5f3833e |


File details

Details for the file qpax-0.0.11-py3-none-any.whl.

File metadata

  • Download URL: qpax-0.0.11-py3-none-any.whl
  • Upload date:
  • Size: 11.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for qpax-0.0.11-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | ed51ea096a02369a0bed34a35db203238c0815f7c4e8b103ce64aab8d50551ee |
| MD5 | 252db653b656c92ecebc128c208e86a7 |
| BLAKE2b-256 | 7ee86eecc52f10775377c805cc5475aef2ac79c58693e4900d870ae83b5a3afc |

