MPAX: Mathematical Programming in JAX
MPAX is a hardware-accelerated, differentiable, batchable, and distributable solver for mathematical programming in JAX, designed to integrate with modern computational and deep learning workflows:
- Hardware accelerated: executes on multiple architectures including CPUs, GPUs and TPUs.
- Differentiable: computes derivatives of solutions with respect to inputs through implicit or unrolled differentiation.
- Batchable: solves multiple problem instances of the same shape simultaneously.
- Distributed: executes in parallel across multiple devices, such as several GPUs.
MPAX's primary motivation is to integrate mathematical programming with deep learning pipelines. To achieve this, MPAX aligns its algorithms and implementations with the requirements of deep learning hardware, ensuring compatibility with GPUs and TPUs. By being differentiable, MPAX can integrate directly into the backpropagation process of neural network training. Its batchability and distributability further enable scalable deployment in large-scale applications.
Currently, MPAX supports linear programming (LP) and quadratic programming (QP), the foundational problems in mathematical programming. Future releases will expand support to include other problem classes of mathematical programming.
Installation
You can install the latest released version of MPAX from PyPI via:
pip install mpax
or you can install the latest development version from GitHub:
pip install git+https://github.com/MIT-Lu-Lab/mpax.git
Quickstart
Currently, MPAX focuses on solving linear programming (LP) and quadratic programming (QP) problems of the following form:
\begin{equation}
\tag{LP}
\begin{aligned}
\min_{l \leq x \leq u}\ & c^\top x \\
\text{s.t.}\ & A x = b \\
& Gx \geq h
\end{aligned}
\end{equation}
\begin{equation}
\tag{QP}
\begin{aligned}
\min_{l \leq x \leq u}\ & \frac{1}{2} x^\top Q x + c^\top x \\
\text{s.t.}\ & A x = b \\
& Gx \geq h
\end{aligned}
\end{equation}
MPAX implements two state-of-the-art first-order methods:
- $\boldsymbol{\mathrm{ra}}$PDHG: restarted average Primal-Dual Hybrid Gradient, supporting both LP (paper) and QP (paper).
- $\boldsymbol{\mathrm{r^2}}$HPDHG: reflected restarted Halpern Primal-Dual Hybrid Gradient, supporting LP only (paper).
Solving a Single LP/QP Problem
MPAX supports both dense and sparse formats for the constraint matrix, controlled by the use_sparse_matrix parameter.
from mpax import create_lp, create_qp, raPDHG, r2HPDHG
# Create LP using sparse matrix format (default)
lp = create_lp(c, A, b, G, h, l, u) # use_sparse_matrix=True by default
# Create LP using dense matrix format
lp = create_lp(c, A, b, G, h, l, u, use_sparse_matrix=False)
solver = r2HPDHG(eps_abs=1e-4, eps_rel=1e-4, verbose=True)
result = solver.optimize(lp)
# Or create QP using sparse matrix format (default)
qp = create_qp(Q, c, A, b, G, h, l, u)
# Create QP using dense matrix format
qp = create_qp(Q, c, A, b, G, h, l, u, use_sparse_matrix=False)
solver = raPDHG(eps_abs=1e-4, eps_rel=1e-4, verbose=True)
result = solver.optimize(qp)
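For a self-contained illustration, the sketch below builds and solves a tiny LP with made-up data; the data values and the printed expectation are illustrative, not part of the MPAX API:
import jax.numpy as jnp
from mpax import create_lp, r2HPDHG

# Tiny illustrative LP: minimize x1 + 2*x2
# subject to x1 + x2 = 1, x1 - x2 >= -1, 0 <= x <= 1.
c = jnp.array([1.0, 2.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
G = jnp.array([[1.0, -1.0]])
h = jnp.array([-1.0])
l = jnp.zeros(2)
u = jnp.ones(2)

lp = create_lp(c, A, b, G, h, l, u, use_sparse_matrix=False)
solver = r2HPDHG(eps_abs=1e-4, eps_rel=1e-4)
result = solver.optimize(lp)
print(result.primal_solution)  # expected to be close to [1.0, 0.0]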
Batch solving
Batch solving allows you to solve multiple LP problems of the same shape simultaneously using jax.vmap:
import jax
import jax.numpy as jnp
from mpax import create_lp, r2HPDHG

# A, b, G, h, l, u are the fixed problem data; only the objective vector varies.
def single_optimize(c_vector):
    lp = create_lp(c_vector, A, b, G, h, l, u)
    solver = r2HPDHG(eps_abs=1e-4, eps_rel=1e-4, verbose=True)
    result = solver.optimize(lp)
    obj = jnp.dot(c_vector, result.primal_solution)
    return result.primal_solution, obj

batch_size = 100
batch_c = jnp.tile(c, (batch_size, 1))
batch_optimize = jax.vmap(single_optimize)
result = batch_optimize(batch_c)
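The batched solve can also be wrapped in jax.jit so the whole batch is compiled once; a minimal sketch building on the snippet above:
# Compile the batched solve once; solutions and objectives are stacked along
# the leading batch dimension.
batch_optimize = jax.jit(jax.vmap(single_optimize))
solutions, objectives = batch_optimize(batch_c)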
Device parallelism
Distribute computations across devices using JAX’s sharding capabilities:
import jax
from jax.sharding import PartitionSpec as P
from mpax import create_lp, r2HPDHG

# Data sharding: split the constraint matrix A across two devices.
mesh = jax.make_mesh((2,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
A_sharded = jax.device_put(A, sharding)

lp_sharded = create_lp(c, A_sharded, b, G, h, l, u)
solver = r2HPDHG(eps_abs=1e-4, eps_rel=1e-4, verbose=True)
jit_optimize = jax.jit(solver.optimize)
result = jit_optimize(lp_sharded)
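The mesh shape must match the number of devices you plan to use; a quick check with plain JAX (not MPAX-specific):
# A mesh of shape (2,) requires two visible devices.
print(jax.devices())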
Differentiation
An example of computing the forward and backward passes of the "Smart Predict-then-Optimize+" (SPO+) loss using MPAX and jax.custom_vjp(), building on the batch_optimize function defined above:
import jax
import jax.numpy as jnp
# batch_optimize is the vmapped solver defined in the batch-solving example above.
@jax.custom_vjp
def spo_fun(pred_cost, true_cost, true_sol, true_obj):
    sol, obj = batch_optimize(2 * pred_cost - true_cost)
    loss = -obj + 2 * jnp.sum(pred_cost * true_sol, axis=1) - true_obj
    return jnp.mean(loss)

def spo_fwd(pred_cost, true_cost, true_sol, true_obj):
    sol, obj = batch_optimize(2 * pred_cost - true_cost)
    loss = -obj + 2 * jnp.sum(pred_cost * true_sol, axis=1) - true_obj
    return jnp.mean(loss), (sol, true_sol)

def spo_bwd(res, g):
    sol, true_sol = res
    grad = 2 * (true_sol - sol)
    # No gradients needed for true_cost, true_sol, or true_obj
    return grad * g, None, None, None

spo_fun.defvjp(spo_fwd, spo_bwd)
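As a hedged usage sketch (pred_cost, true_cost, true_sol, and true_obj are placeholder batched arrays assumed to come from your prediction model and data pipeline), the loss can then be differentiated with respect to the predicted costs:
# Forward and backward pass of the SPO+ loss; the gradient has the same shape
# as pred_cost and flows through spo_bwd defined above.
loss, grad = jax.value_and_grad(spo_fun)(pred_cost, true_cost, true_sol, true_obj)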
Solver Options
General options
Parameter | Type | Default | Description |
---|---|---|---|
verbose | bool | False | Enables detailed logging of the solver's progress. |
debug | bool | False | Activates additional debugging information. |
display_frequency | int | 10 | How often, measured in termination checks, solver statistics are displayed. |
jit | bool | True | Enables JIT (Just-In-Time) compilation for faster execution. |
unroll | bool | False | Unrolls the iteration loops. |
warm_start | bool | False | Whether to perform warm starting. |
feasibility_polishing | bool | False | Whether to perform feasibility polishing. |
Termination
Parameter | Type | Default | Description |
---|---|---|---|
eps_abs | float | 1e-4 | Absolute tolerance for convergence. |
eps_rel | float | 1e-4 | Relative tolerance for convergence. |
eps_primal_infeasible | float | 1e-8 | Tolerance for detecting primal infeasibility. |
eps_dual_infeasible | float | 1e-8 | Tolerance for detecting dual infeasibility. |
eps_feas_polish | float | 1e-6 | Tolerance for feasibility polishing. |
iteration_limit | int | max_int | Maximum number of iterations allowed (unlimited by default). |
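As an illustrative sketch (parameter values below are arbitrary; the names all come from the tables above), options from both tables can be passed when constructing a solver:
from mpax import raPDHG

# Combine general and termination options; values are illustrative only.
solver = raPDHG(
    verbose=True,
    display_frequency=5,
    warm_start=True,
    eps_abs=1e-6,
    eps_rel=1e-6,
    iteration_limit=100_000,
)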
Precision
By default, MPAX uses single-precision (32-bit). To enable double-precision (64-bit), add the following at the start of your script:
import jax

jax.config.update("jax_enable_x64", True)
Determinism
Floating-point computations on GPUs in JAX may produce non-deterministic results. To ensure deterministic results, set:
import os

# Set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
Important: If you are using batch solving, do not enable --xla_gpu_deterministic_ops=true, as it can significantly degrade performance.
Citation
If MPAX is useful or relevant to your research, please recognize our contributions by citing our paper:
@article{lu2024mpax,
  title={MPAX: Mathematical Programming in JAX},
  author={Lu, Haihao and Peng, Zedong and Yang, Jinwen},
  journal={arXiv preprint arXiv:2412.09734},
  year={2024}
}