
A differentiable orthogonal projection layer for training hard-constrained neural networks.


Πnet: Optimizing hard-constrained neural networks with orthogonal projection layers



Cover Image

This repository contains a JAX implementation of Πnet, an output layer for neural networks that ensures the satisfaction of specified convex constraints.

[!NOTE] TL;DR Πnet leverages operator splitting for rapid and reliable projections in the forward pass, and the implicit function theorem for backpropagation. It offers a feasible-by-design optimization proxy for parametric constrained optimization problems, obtaining modest-accuracy solutions faster than traditional solvers on a single problem and significantly faster on a batch of problems.
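The backward-pass idea can be illustrated in plain JAX: instead of backpropagating through every iteration of a fixed-point solver, a custom VJP applies the implicit function theorem at the solution. The sketch below uses a hypothetical scalar example (a Babylonian square-root iteration), not Πnet's actual implementation:

```python
import jax
import jax.numpy as jnp

def f(x, a):
    # Babylonian iteration; its fixed point is sqrt(a).
    return 0.5 * (x + a / x)

def _solve(a):
    x = jnp.ones_like(a)
    for _ in range(50):
        x = f(x, a)
    return x

@jax.custom_vjp
def fixed_point(a):
    return _solve(a)

def fixed_point_fwd(a):
    x = _solve(a)
    return x, (x, a)

def fixed_point_bwd(res, g):
    x, a = res
    # Implicit function theorem at x = f(x, a): dx/da = f_a / (1 - f_x).
    dfdx = jax.grad(f, argnums=0)(x, a)
    dfda = jax.grad(f, argnums=1)(x, a)
    return (g * dfda / (1.0 - dfdx),)

fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

print(fixed_point(4.0))            # ≈ 2.0
print(jax.grad(fixed_point)(4.0))  # ≈ 1/(2*sqrt(4)) = 0.25
```

Because the backward pass evaluates the implicit-function-theorem formula directly, its memory and compute do not depend on the number of forward iterations.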


Getting started

To install Πnet, run:

  • CPU-only (Linux/macOS/Windows)
    pip install pinet-hcnn
    
  • GPU (NVIDIA, CUDA 12)
    pip install "pinet-hcnn[cuda12]"
    

[!WARNING] CUDA dependencies If you have issues with CUDA drivers, please follow the official instructions for cuda12 and cudnn (Note: wheels only available on linux). If you have issues with conflicting CUDA libraries, check also this issue... or use our Docker container 🤗.

We also provide a working Docker image to reproduce the results of the paper and to build on top of.

docker compose run --rm pinet-cpu # Run the pytests on CPU
docker compose run --rm pinet-gpu # Run the pytests on GPU

[!WARNING] CUDA dependencies Running the Docker container with GPU support requires NVIDIA Container Toolkit on the host.

See also the section on reproducing the paper's results for more examples of commands.

Supported platforms 💻

Linux x86_64 Linux aarch64 Mac aarch64 Windows x86_64 Windows WSL2 x86_64
CPU
NVIDIA GPU n/a

Examples

Constraints & Projection Layer

All tensors are batched. Let B = batch size (you may use B=1 to broadcast across a batch).

  • Vectors: shape (B, n, 1)
  • Matrices: shape (B, n, d)
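These shape conventions follow standard `jnp.matmul` leading-axis broadcasting; a quick sanity check with illustrative dimensions:

```python
import jax.numpy as jnp

# A (1, n, d) matrix broadcasts against a (B, d, 1) batch of vectors.
B, n, d = 4, 3, 5
A = jnp.ones((1, n, d))   # shared across the batch (leading axis 1)
x = jnp.ones((B, d, 1))   # one vector per batch element
y = A @ x                 # batched matvec
print(y.shape)            # (4, 3, 1)
```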

EqualityConstraint — enforce A @ x == b

import jax.numpy as jnp
from pinet import EqualityConstraint

B, n_eq, d = 4, 3, 5
A = jnp.zeros((1, n_eq, d))         # (1, n_eq, d)  # broadcast across batch
b = jnp.zeros((B, n_eq, 1))         # (B, n_eq, 1)

eq = EqualityConstraint(
    A=A,
    b=b,
    method=None,                    # let Project decide / lift later
    var_b=True,                     # b provided per-batch at runtime
    var_A=False,                    # A constant (broadcasted)
)

[!WARNING] method=None eq.project() is only available when method="pinv". If you have multiple constraints and will use the equality constraint only inside the projection layer, you can leave method=None (as above).
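For intuition, projecting onto {x : A x = b} with a pseudo-inverse amounts to x − A⁺(A x − b). A plain-JAX sketch of that operation (illustrative only, not pinet's internals, and without the batch dimension):

```python
import jax.numpy as jnp

def project_eq(x, A, b):
    # Orthogonal projection onto {x : A x = b} via the pseudo-inverse.
    A_pinv = jnp.linalg.pinv(A)      # in practice, precompute and reuse
    return x - A_pinv @ (A @ x - b)

A = jnp.array([[1.0, 1.0, 0.0]])     # single constraint: x0 + x1 = 1
b = jnp.array([[1.0]])
x = jnp.array([[2.0], [2.0], [5.0]])
y = project_eq(x, A, b)
print(A @ y)                         # ≈ [[1.0]], constraint satisfied
```

Note that unconstrained coordinates (here x2) are left unchanged, as expected of an orthogonal projection.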

AffineInequalityConstraint — enforce lb ≤ C @ x ≤ ub

import jax.numpy as jnp
from pinet import AffineInequalityConstraint

n_ineq = 7
C  = jnp.zeros((1, n_ineq, d))      # (1, n_ineq, d)
lb = jnp.full((B, n_ineq, 1), -1.0) # (B, n_ineq, 1)
ub = jnp.full((B, n_ineq, 1),  1.0) # (B, n_ineq, 1)

ineq = AffineInequalityConstraint(C=C, lb=lb, ub=ub)

[!WARNING] ineq.project() intentionally NotImplemented To improve the efficiency of the projection, affine inequality constraints are always "lifted" as described in the paper, so no standalone projection method is provided for this constraint type.
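For intuition, the lifting rewrites lb ≤ C x ≤ ub as the equality C x − s = 0 plus a box constraint s ∈ [lb, ub] on an auxiliary variable s, so the lifted variable z = [x; s] has dimension d + n_ineq. A minimal unbatched sketch of this construction (assumed form, not pinet's code):

```python
import jax.numpy as jnp

d, n_ineq = 5, 7
C = jnp.ones((n_ineq, d))

# Lifted equality [C, -I] z = 0 on z = [x; s]; the box on s carries the bounds.
A_lift = jnp.concatenate([C, -jnp.eye(n_ineq)], axis=1)  # (n_ineq, d + n_ineq)
b_lift = jnp.zeros((n_ineq, 1))
print(A_lift.shape)  # (7, 12)
```

Any x then yields a feasible lift z = [x; C x] for the equality part, with feasibility of the original inequalities captured entirely by the box on s.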

BoxConstraint — clip selected dimensions

import numpy as np
import jax.numpy as jnp
from pinet import BoxConstraint, BoxConstraintSpecification

lb_x = jnp.full((B, d, 1), -2.0)    # (B, d, 1)
ub_x = jnp.full((B, d, 1),  2.0)    # (B, d, 1)
mask = np.ones(d, dtype=bool)       # apply to all dims (use False to skip dims)

box = BoxConstraint(BoxConstraintSpecification(lb=lb_x, ub=ub_x, mask=mask))
# box.project(...) clips x[:, mask, :] into [lb_x, ub_x].
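A masked box projection reduces to an elementwise clip on the selected dimensions; a plain-JAX sketch of that behavior (not pinet's code):

```python
import numpy as np
import jax.numpy as jnp

B, d = 4, 5
x = jnp.linspace(-5.0, 5.0, B * d).reshape(B, d, 1)
lb, ub = -2.0, 2.0
mask = np.array([True, True, False, True, True])  # dim 2 is left unconstrained

# Clip only where mask is True; pass other dimensions through unchanged.
clipped = jnp.where(mask[None, :, None], jnp.clip(x, lb, ub), x)
```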

Combine constraints with Project (Douglas–Rachford)

Project handles:

  • Lifting inequalities into equalities + auxiliary variables;
  • Optional Ruiz equilibration;
  • JIT-compiled forward;
  • Optional custom VJP for backprop.
from pinet.project import Project
from pinet.dataclasses import ProjectionInstance
import jax.numpy as jnp

proj = Project(
    eq_constraint=eq,              # can be None
    ineq_constraint=ineq,          # can be None
    box_constraint=box,            # can be None
    unroll=False,                  # use custom VJP path by default
)

# Build a ProjectionInstance with the point to project and (optionally) runtime specs:
x0 = jnp.zeros((B, d, 1))
yraw = ProjectionInstance(x=x0)
# If var_b=True and you supply per-batch b at runtime, pass it via your dataclass, e.g.:
# yraw = yraw.update(eq=yraw.eq.update(b=b))

y, sK = proj.call(       # JIT-compiled projector
    yraw=yraw,
    n_iter=50,                    # Douglas-Rachford iterations
    n_iter_backward=100,          # Maximum number of iterations for the bicgstab algorithm
    sigma=1.0, omega=1.7,
)

# If you want to resume the projection from the latest governing sequence sK,
# you can provide it to the call method via s0=sK.

cv = proj.cv(y)  # (B, 1, 1) max violation across constraints
                 # The CV can also be assessed for the different constraints separately,
                 # e.g., eq.cv(y), if eq is a constraint for y
                 # (shapes need to match, so be careful of lifting!)
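For intuition about what the projector iterates, here is a minimal unbatched Douglas–Rachford sketch (plain JAX, not pinet's solver) for the intersection of a hyperplane and a box; s plays the role of the governing sequence sK above:

```python
import jax.numpy as jnp

def proj_hyperplane(x, a, b):
    # Projection onto {x : a^T x = b}.
    return x - a * ((a @ x - b) / (a @ a))

def proj_box(x, lb, ub):
    return jnp.clip(x, lb, ub)

a, b = jnp.array([1.0, 1.0]), 1.0
lb, ub = -0.2, 0.8

s = jnp.zeros(2)                       # governing sequence
for _ in range(200):
    x = proj_box(s, lb, ub)
    y = proj_hyperplane(2 * x - s, a, b)
    s = s + y - x
x = proj_box(s, lb, ub)
print(x)   # in the box, with x[0] + x[1] ≈ 1
```

The final iterate lies in the box by construction and satisfies the hyperplane constraint to solver accuracy.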

Notes

  • Batch rules: For each pair of tensors (X, Y), either batch sizes match or one is 1 (broadcast).
  • Equality method: Use method="pinv" when you rely on the equality projector standalone. When used inside Project, you can keep method=None; lifting will set up the pseudo-inverse internally.
  • Dimensions after lifting: If inequalities are present, the internal lifted dimension is d + n_ineq (auxiliary variables).

Minimal “Toy MPC” Application

The helper below wires the projector into a Pinet model; the loss is your batched objective.

# benchmarks/toy_MPC/model.py
import jax.numpy as jnp
from flax import linen as nn
from pinet import BoxConstraint, BoxConstraintSpecification, EqualityConstraint
from src.benchmarks.model import build_model_and_train_step, setup_pinet

def setup_model(rng_key, hyperparameters, A, X, b, lb, ub, batched_objective):
    activation = getattr(nn, hyperparameters["activation"], None)
    if activation is None:
        raise ValueError(f"Unknown activation: {hyperparameters['activation']}")

    # Constraints (b varies at runtime; A is constant & broadcasted)
    eq  = EqualityConstraint(A=A, b=b, method=None, var_b=True)
    box = BoxConstraint(BoxConstraintSpecification(lb=lb, ub=ub))
    project, project_test, _ = setup_pinet(eq_constraint=eq, box_constraint=box,
                                           hyperparameters=hyperparameters)

    model, params, train_step = build_model_and_train_step(
        rng_key=rng_key,
        dim=A.shape[2],
        features_list=hyperparameters["features_list"],
        activation=activation,
        project=project,                # projector in the training graph
        project_test=project_test,      # projector used at eval
        raw_train=hyperparameters.get("raw_train", False),
        raw_test=hyperparameters.get("raw_test", False),
        loss_fn=lambda preds, _b: batched_objective(preds),
        example_x=X[:1, :, 0],
        example_b=b[:1],
        jit=True,
    )
    return model, params, train_step

Run the end-to-end script

To reproduce the results in the paper, you can run

python -m src.benchmarks.toy_MPC.run_toy_MPC --filename toy_MPC_seed42_examples10000.npz --config toy_MPC --seed 12

To generate the dataset, run

TODO

You’ll get:

  • Training logs (loss, CV, timing),
  • Validation/Test metrics incl. relative suboptimality & CV,
  • Saved params & results ready to reload and plot trajectories.

[!TIP] Troubleshooting All the objects in pinet.dataclasses offer a validate method, which can be used to verify your inputs.

Works using Πnet ⚙️

We collect here applications using Πnet. Please feel free to open a pull request to add yours! 🤗

| Link | Project |
| --- | --- |
| View Repo | Multi-vehicle trajectory optimization with non-convex preferences |

This project features context dimensions in the millions and tens of thousands of optimization variables.

Contributing ☕️

Contributions are more than welcome! 🙏 Please check out our contributing page, and feel free to open an issue for problems and feature requests⚠️.

Benchmarks 📈

Below, we summarize the performance gains of Πnet over state-of-the-art methods. We consider the following metrics:

  • Relative Suboptimality ($\texttt{RS}$): The suboptimality of the objective of a candidate solution $\hat{y}$ relative to the optimal objective $J(y^{\star})$, computed by a high-accuracy solver.
  • Constraint Violation ($\texttt{CV}$): Maximum violation ($\infty$-norm) of any constraint (equality and inequality). In practice, any solver achieving a $\texttt{CV}$ below $10^{-5}$ is considered to have high accuracy, and there is little benefit in going below that. Instead, among methods with sufficiently low $\texttt{CV}$, a lower $\texttt{RS}$ is better.
  • Learning curves: Progress on $\texttt{RS}$ and $\texttt{CV}$ over wall-clock time on the validation set.
  • Single inference time: The time required to solve one instance at test time.
  • Batch inference time: The time required to solve a batch of $1024$ instances at test time.
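Under the definitions above, the two metrics can be sketched as follows (assumed formulas consistent with the text, not the benchmark code):

```python
import jax.numpy as jnp

def relative_suboptimality(J_hat, J_star):
    # |J(y_hat) - J(y_star)| / |J(y_star)|
    return jnp.abs(J_hat - J_star) / jnp.abs(J_star)

def constraint_violation(x, A, b, C, lb, ub):
    # Infinity-norm over all equality and inequality residuals.
    eq_res = jnp.abs(A @ x - b)
    ineq_res = jnp.maximum(jnp.maximum(lb - C @ x, C @ x - ub), 0.0)
    return jnp.maximum(eq_res.max(), ineq_res.max())
```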

We report results for an optimization problem with an optimization variable of dimension $d$, $n_{\mathrm{eq}}$ equality and $n_{\mathrm{ineq}}$ inequality convex constraints, and a non-convex objective. Here, we use a small and a large (in the parametric optimization sense) dataset, $(d, n_{\mathrm{eq}}, n_{\mathrm{ineq}}) \in \{(100, 50, 50), (1000, 500, 500)\}$.

[Figure: Non-convex CV and RS] [Figure: Non-convex learning curves]

Overall, Πnet outperforms the state-of-the-art in accuracy and training times. For more comparisons and ablations, please check out our paper.

Reproducing the paper's results

To reproduce our benchmarks and ablations, you can run

python -m src.benchmarks.QP.run_QP --seed 0 --id <ID> --config <CONFIG>  --proj_method <METHOD>

To select ID, CONFIG, and METHOD, please refer to the table below.

| Experiment | ID | CONFIG | METHOD |
| --- | --- | --- | --- |
| Πnet results on the small convex benchmark | TODO | benchmark_small_autotune | pinet |
| Πnet results on the large convex benchmark | TODO | benchmark_large_autotune | pinet |
| Πnet results on the small non-convex benchmark | TODO | benchmark_small_autotune | pinet |
| Πnet results on the large non-convex benchmark | TODO | benchmark_large_autotune | pinet |
| Πnet results on the TODO benchmark with manual tuning | TODO | benchmark_config_manual | pinet |
| Πnet results on the TODO benchmark without equilibration | TODO | benchmark_noequil_autotune | pinet |
| cvxpy results on the small convex benchmark | TODO | benchmark_cvxpy | cvxpy |
| jaxopt results on the small convex benchmark | TODO | benchmark_jaxopt_small | jaxopt |
| jaxopt results on the large convex benchmark | TODO | benchmark_jaxopt_large | jaxopt |
| jaxopt results on the small non-convex benchmark | TODO | benchmark_jaxopt_small | jaxopt |
| jaxopt results on the large non-convex benchmark | TODO | benchmark_jaxopt_small | jaxopt |

[!WARNING] Generating the large dataset The repo contains only the data to run the small benchmark. For the large one, you need first to generate the data. For this, please run

TODO

NOTE: This may take a while... In a future release, we plan to provide several datasets via Hugging Face 🤗 or similar providers, and this step will be less tedious.

For DC3, we used the open-source implementation.

[!TIP] With Docker 🐳 To run the above commands within the Docker container, you can use

docker compose run --rm pinet-cpu -m src.benchmarks.QP.run_QP --seed 0 --id <ID> --config <CONFIG>  --proj_method <METHOD> # run on CPU
docker compose run --rm pinet-gpu -m src.benchmarks.QP.run_QP --seed 0 --id <ID> --config <CONFIG>  --proj_method <METHOD> # run on GPU

For the toy MPC, please refer to the examples section. For the second-order cone constraints, you can use this notebook.

Citation 🙏

If you use this code in your research, please cite our paper:

   @article{grontas2025pinet,
     title={Pinet: Optimizing hard-constrained neural networks with orthogonal projection layers},
     author={Grontas, Panagiotis and Terpin, Antonio and Balta C., Efe and D'Andrea, Raffaello and Lygeros, John},
     journal={arXiv preprint arXiv:TODO},
     year={2025}
   }
