# Πnet: Optimizing hard-constrained neural networks with orthogonal projection layers

A differentiable orthogonal projection layer for training hard-constrained neural networks.
This repository contains a JAX implementation of Πnet, an output layer for neural networks that ensures the satisfaction of specified convex constraints.
> [!NOTE]
> **TL;DR**: Πnet leverages operator splitting for rapid and reliable projections in the forward pass, and the implicit function theorem for backpropagation. It offers a feasible-by-design optimization proxy for parametric constrained optimization problems, obtaining modest-accuracy solutions faster than traditional solvers on a single problem, and significantly faster on a batch of problems.
## Getting started

To install Πnet, run:

- CPU-only (Linux/macOS/Windows)

  ```bash
  pip install pinet-hcnn
  ```

- GPU (NVIDIA, CUDA 12)

  ```bash
  pip install "pinet-hcnn[cuda12]"
  ```
> [!WARNING]
> **CUDA dependencies**: If you have issues with CUDA drivers, please follow the official instructions for `cuda12` and `cudnn` (note: wheels are only available on Linux). If you have issues with conflicting CUDA libraries, check also this issue... or use our Docker container 🤗.
We also provide a working Docker image to reproduce the results of the paper and to build on top of.

```bash
docker compose run --rm pinet-cpu  # Run the pytests on CPU
docker compose run --rm pinet-gpu  # Run the pytests on GPU
```
> [!WARNING]
> **CUDA dependencies**: Running the Docker container with GPU support requires the NVIDIA Container Toolkit on the host.
See also the section on reproducing the paper's results for more examples of commands.
## Supported platforms 💻

|  | Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 |
|---|---|---|---|---|---|
| CPU | ✅ | ✅ | ✅ | ✅ | ✅ |
| NVIDIA GPU | ✅ | ✅ | n/a | ❌ | ❌ |
## Examples

### Constraints & Projection Layer
All tensors are batched. Let `B` be the batch size (you may use `B=1` to broadcast across a batch).

- Vectors: shape `(B, n, 1)`
- Matrices: shape `(B, n, d)`
#### `EqualityConstraint` — enforce `A @ x == b`
```python
import jax.numpy as jnp
from pinet import EqualityConstraint

B, n_eq, d = 4, 3, 5
A = jnp.zeros((1, n_eq, d))  # (1, n_eq, d), broadcast across batch
b = jnp.zeros((B, n_eq, 1))  # (B, n_eq, 1)
eq = EqualityConstraint(
    A=A,
    b=b,
    method=None,  # let Project decide / lift later
    var_b=True,   # b provided per-batch at runtime
    var_A=False,  # A constant (broadcasted)
)
```
> [!WARNING]
> **`method=None`**: `eq.project()` is only available if `method="pinv"`. When you have multiple constraints and you plan on using the equality constraint only within the projection layer, you can leave `method=None` (as above).
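For intuition, the standalone equality projector computes the Euclidean projection onto the affine set $\{x : Ax = b\}$, which has a closed form via the pseudo-inverse. A minimal plain-JAX sketch of that formula (illustrative, not the library's internal code; assumes the system $Ax = b$ is consistent):

```python
import jax.numpy as jnp

def project_onto_equality(x, A, b):
    # Euclidean projection of x onto {x : A x = b}:
    #   x* = x - A^+ (A x - b), with A^+ the Moore-Penrose pseudo-inverse.
    # jnp.linalg.pinv broadcasts over leading (batch) dimensions.
    return x - jnp.linalg.pinv(A) @ (A @ x - b)
```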
#### `AffineInequalityConstraint` — enforce `lb ≤ C @ x ≤ ub`
```python
import jax.numpy as jnp
from pinet import AffineInequalityConstraint

n_ineq = 7  # B and d as above
C = jnp.zeros((1, n_ineq, d))        # (1, n_ineq, d)
lb = jnp.full((B, n_ineq, 1), -1.0)  # (B, n_ineq, 1)
ub = jnp.full((B, n_ineq, 1), 1.0)   # (B, n_ineq, 1)
ineq = AffineInequalityConstraint(C=C, lb=lb, ub=ub)
```
> [!WARNING]
> **`ineq.project()` is intentionally `NotImplemented`**: To improve the efficiency of the projection, we always "lift" the affine inequality constraints as described in the paper. For this reason, we did not even bother implementing the projection method for this type of constraint 🤗.
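For intuition, lifting replaces the two-sided inequality with an equality plus a box on an auxiliary slack variable (a sketch of the reformulation; see the paper for the exact construction used internally):

$$
lb \le C x \le ub
\quad\Longleftrightarrow\quad
C x - s = 0, \quad lb \le s \le ub,
$$

so the lifted variable $[x; s]$ has dimension $d + n_{\mathrm{ineq}}$, consistent with the note on dimensions after lifting below.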
#### `BoxConstraint` — clip selected dimensions
```python
import numpy as np
import jax.numpy as jnp
from pinet import BoxConstraint, BoxConstraintSpecification

lb_x = jnp.full((B, d, 1), -2.0)  # (B, d, 1)
ub_x = jnp.full((B, d, 1), 2.0)   # (B, d, 1)
mask = np.ones(d, dtype=bool)     # apply to all dims (use False to skip dims)
box = BoxConstraint(BoxConstraintSpecification(lb=lb_x, ub=ub_x, mask=mask))
# box.project(...) clips x[:, mask, :] into [lb_x, ub_x].
```
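In plain JAX, the clipping described in the comment above amounts to a masked clip (an illustrative sketch, not the library's implementation):

```python
import jax.numpy as jnp

def masked_clip(x, lb, ub, mask):
    # Clip only the masked dimensions into [lb, ub]; pass the rest through.
    # x, lb, ub: (B, d, 1); mask: (d,) boolean.
    return jnp.where(mask[None, :, None], jnp.clip(x, lb, ub), x)
```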
#### Combine constraints with `Project` (Douglas–Rachford)
`Project` handles:
- Lifting inequalities into equalities + auxiliary variables;
- Optional Ruiz equilibration;
- JIT-compiled forward;
- Optional custom VJP for backprop.
```python
from pinet.project import Project
from pinet.dataclasses import ProjectionInstance
import jax.numpy as jnp

proj = Project(
    eq_constraint=eq,      # can be None
    ineq_constraint=ineq,  # can be None
    box_constraint=box,    # can be None
    unroll=False,          # use custom VJP path by default
)

# Build a ProjectionInstance with the point to project and (optionally) runtime specs:
x0 = jnp.zeros((B, d, 1))
yraw = ProjectionInstance(x=x0)
# If var_b=True and you supply per-batch b at runtime, pass it via your dataclass, e.g.:
# yraw = yraw.update(eq=yraw.eq.update(b=b))

y, sK = proj.call(  # JIT-compiled projector
    yraw=yraw,
    n_iter=50,            # Douglas-Rachford iterations
    n_iter_backward=100,  # maximum number of iterations for the BiCGSTAB algorithm
    sigma=1.0, omega=1.7,
)
# If you want to resume the projection from the latest governing sequence sK,
# you can provide it to the call method via s0=sK.

cv = proj.cv(y)  # (B, 1, 1) max violation across constraints
# The CV can also be assessed for each constraint separately,
# e.g., eq.cv(y), if eq is a constraint for y
# (shapes need to match, so be careful of lifting!)
```
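For example, you can warm-start a refinement pass from the returned `sK` (a sketch reusing only the arguments shown above):

```python
# Resume Douglas-Rachford from the previous governing sequence for extra accuracy.
y_refined, sK = proj.call(
    yraw=yraw,
    n_iter=50,
    n_iter_backward=100,
    sigma=1.0, omega=1.7,
    s0=sK,  # warm start from the earlier call
)
```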
#### Notes

- **Batch rules**: For each pair of tensors `(X, Y)`, either the batch sizes match or one of them is `1` (broadcast).
- **Equality `method`**: Use `method="pinv"` when you rely on the equality projector standalone. When used inside `Project`, you can keep `method=None`; lifting will set up the pseudo-inverse internally.
- **Dimensions after lifting**: If inequalities are present, the internal lifted dimension is `d + n_ineq` (auxiliary variables).
### Minimal “Toy MPC” Application

The helper below wires the projector into a Πnet model; the loss is your batched objective.
```python
# benchmarks/toy_MPC/model.py
import jax.numpy as jnp
from flax import linen as nn

from pinet import BoxConstraint, BoxConstraintSpecification, EqualityConstraint
from src.benchmarks.model import build_model_and_train_step, setup_pinet


def setup_model(rng_key, hyperparameters, A, X, b, lb, ub, batched_objective):
    activation = getattr(nn, hyperparameters["activation"], None)
    if activation is None:
        raise ValueError(f"Unknown activation: {hyperparameters['activation']}")

    # Constraints (b varies at runtime; A is constant & broadcasted)
    eq = EqualityConstraint(A=A, b=b, method=None, var_b=True)
    box = BoxConstraint(BoxConstraintSpecification(lb=lb, ub=ub))
    project, project_test, _ = setup_pinet(
        eq_constraint=eq, box_constraint=box, hyperparameters=hyperparameters
    )

    model, params, train_step = build_model_and_train_step(
        rng_key=rng_key,
        dim=A.shape[2],
        features_list=hyperparameters["features_list"],
        activation=activation,
        project=project,            # projector in the training graph
        project_test=project_test,  # projector used at eval
        raw_train=hyperparameters.get("raw_train", False),
        raw_test=hyperparameters.get("raw_test", False),
        loss_fn=lambda preds, _b: batched_objective(preds),
        example_x=X[:1, :, 0],
        example_b=b[:1],
        jit=True,
    )
    return model, params, train_step
```
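The `batched_objective` passed in above only needs to map a batch of predictions to a scalar loss. A hypothetical example (the paper's actual MPC objective differs):

```python
import jax.numpy as jnp

def batched_objective(preds):
    # Mean quadratic cost over the batch; preds shown here with shape (B, d).
    return jnp.mean(jnp.sum(preds ** 2, axis=-1))
```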
#### Run the end-to-end script

To reproduce the results in the paper, you can run

```bash
python -m src.benchmarks.toy_MPC.run_toy_MPC --filename toy_MPC_seed42_examples10000.npz --config toy_MPC --seed 12
```

To generate the dataset, run

TODO
You’ll get:
- Training logs (loss, CV, timing),
- Validation/Test metrics incl. relative suboptimality & CV,
- Saved params & results ready to reload and plot trajectories.
> [!TIP]
> **Troubleshooting**: All the objects in `pinet.dataclasses` offer a `validate` method, which can be used to verify your inputs.
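For example (a hypothetical check; `validate`'s exact failure mode is not documented here):

```python
import jax.numpy as jnp
from pinet.dataclasses import ProjectionInstance

yraw = ProjectionInstance(x=jnp.zeros((4, 5, 1)))  # (B, n, 1) as required
yraw.validate()  # per the tip above, verify the inputs before projecting
```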
## Works using Πnet ⚙️

We collect here applications using Πnet. Please feel free to open a pull request to add yours! 🤗

| Link | Project |
|---|---|
| Multi-vehicle trajectory optimization with non-convex preferences | This project features context dimensions in the millions and tens of thousands of optimization variables. |
## Contributing ☕️

Contributions are more than welcome! 🙏 Please check out our contributing page, and feel free to open an issue for problems and feature requests ⚠️.
## Benchmarks 📈

Below, we summarize the performance gains of Πnet over state-of-the-art methods. We consider the following metrics:
- Relative Suboptimality ($\texttt{RS}$): The suboptimality of a candidate solution $\hat{y}$ compared to the optimal objective $J(y^{\star})$, computed by a high-accuracy solver.
- Constraint Violation ($\texttt{CV}$): Maximum violation ($\infty$-norm) of any constraint (equality and inequality). In practice, any solver achieving a $\texttt{CV}$ below $10^{-5}$ is considered high-accuracy, and there is little benefit in going lower. Once methods have a sufficiently low $\texttt{CV}$, a lower $\texttt{RS}$ is what distinguishes them.
- Learning curves: Progress on $\texttt{RS}$ and $\texttt{CV}$ over wall-clock time on the validation set.
- Single inference time: The time required to solve one instance at test time.
- Batch inference time: The time required to solve a batch of $1024$ instances at test time.
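For concreteness, one common formalization of the first two metrics reads as follows (a sketch; the paper's exact definitions may differ, e.g., in normalization):

$$
\texttt{RS}(\hat{y}) = \frac{J(\hat{y}) - J(y^{\star})}{\lvert J(y^{\star}) \rvert},
\qquad
\texttt{CV}(\hat{y}) = \max\left\{ \lVert A\hat{y} - b \rVert_{\infty},\; \lVert (C\hat{y} - ub)_{+} \rVert_{\infty},\; \lVert (lb - C\hat{y})_{+} \rVert_{\infty} \right\},
$$

where $(\cdot)_{+} = \max(\cdot, 0)$ element-wise.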
We report results for an optimization problem with an optimization variable of dimension $d$, $n_{\mathrm{eq}}$ equality and $n_{\mathrm{ineq}}$ inequality convex constraints, and a non-convex objective. Here, we use a small and a large (in the parametric optimization sense) dataset, $(d, n_{\mathrm{eq}}, n_{\mathrm{ineq}}) \in \{(100, 50, 50), (1000, 500, 500)\}$.
Overall, Πnet outperforms the state-of-the-art in accuracy and training times. For more comparisons and ablations, please check out our paper.
## Reproducing the paper's results

To reproduce our benchmarks and ablations, you can run

```bash
python -m src.benchmarks.QP.run_QP --seed 0 --id <ID> --config <CONFIG> --proj_method <METHOD>
```
To select ID, CONFIG, and METHOD, please refer to the table below.
| Experiment | ID | CONFIG | METHOD |
|---|---|---|---|
| Πnet results on the small convex benchmark | TODO | benchmark_small_autotune | pinet |
| Πnet results on the large convex benchmark | TODO | benchmark_large_autotune | pinet |
| Πnet results on the small non-convex benchmark | TODO | benchmark_small_autotune | pinet |
| Πnet results on the large non-convex benchmark | TODO | benchmark_large_autotune | pinet |
| Πnet results on the TODO benchmark with manual tuning | TODO | benchmark_config_manual | pinet |
| Πnet results on the TODO benchmark without equilibration | TODO | benchmark_noequil_autotune | pinet |
| cvxpy results on the small convex benchmark | TODO | benchmark_cvxpy | cvxpy |
| jaxopt results on the small convex benchmark | TODO | benchmark_jaxopt_small | jaxopt |
| jaxopt results on the large convex benchmark | TODO | benchmark_jaxopt_large | jaxopt |
| jaxopt results on the small non-convex benchmark | TODO | benchmark_jaxopt_small | jaxopt |
| jaxopt results on the large non-convex benchmark | TODO | benchmark_jaxopt_small | jaxopt |
> [!WARNING]
> **Generating the large dataset**: The repo contains only the data to run the small benchmark. For the large one, you first need to generate the data. To do so, please run
>
> TODO
>
> **NOTE**: This may take a while... In a future release, we plan to provide several datasets via Hugging Face 🤗 or similar providers, and this step will be less tedious.
For DC3, we used the open-source implementation.
> [!TIP]
> **With Docker 🐳**: To run the above commands within the Docker container, you can use
>
> ```bash
> docker compose run --rm pinet-cpu -m src.benchmarks.QP.run_QP --seed 0 --id <ID> --config <CONFIG> --proj_method <METHOD>  # run on CPU
> docker compose run --rm pinet-gpu -m src.benchmarks.QP.run_QP --seed 0 --id <ID> --config <CONFIG> --proj_method <METHOD>  # run on GPU
> ```
For the toy MPC, please refer to the examples section. For the second-order cone constraints, you can use this notebook.
## Citation 🙏

If you use this code in your research, please cite our paper:

```bibtex
@article{grontas2025pinet,
  title={Pinet: Optimizing hard-constrained neural networks with orthogonal projection layers},
  author={Grontas, Panagiotis and Terpin, Antonio and Balta C., Efe and D'Andrea, Raffaello and Lygeros, John},
  journal={arXiv preprint arXiv:TODO},
  year={2025}
}
```