
LinSATNet offers a neural network layer that enforces the satisfiability of positive linear constraints on the output of neural networks. The gradient through the layer is computed exactly. This package currently works with PyTorch.


LinSATNet

This is the official implementation of our ICML 2023 paper "LinSATNet: The Positive Linear Satisfiability Neural Networks".

With LinSATNet, you can enforce the satisfiability of general positive linear constraints on the output of neural networks.


The LinSAT layer is fully differentiable, and the gradients are exactly computed. Our implementation now supports PyTorch.

You can install it by

pip install linsatnet

And get started by

from LinSATNet import linsat_layer

A Quick Example

Running LinSATNet/linsat.py directly executes a quick example, in which the doubly-stochastic constraint is enforced on a 3x3 matrix of variables.

To run the example, first clone the repo:

git clone https://github.com/Thinklab-SJTU/LinSATNet.git

Go into the repo, and run the example code:

cd LinSATNet
python LinSATNet/linsat.py

In this example, we enforce the doubly-stochastic constraint on a 3x3 matrix. The doubly-stochastic constraint means that every row and every column of the matrix sums to 1.

The 3x3 matrix is flattened into a vector, and the following positive linear constraints are considered (for $\mathbf{E}\mathbf{x}=\mathbf{f}$):

import torch

E = torch.tensor(
    [[1, 1, 1, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 1, 1, 1, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 1, 1, 1],
     [1, 0, 0, 1, 0, 0, 1, 0, 0],
     [0, 1, 0, 0, 1, 0, 0, 1, 0],
     [0, 0, 1, 0, 0, 1, 0, 0, 1]], dtype=torch.float32
)
f = torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)
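For larger matrices, writing E by hand becomes error-prone. The following is a minimal sketch of how the same constraints could be generated for any n, assuming the matrix is flattened row by row as above. The helper name doubly_stochastic_constraints is hypothetical and not part of LinSATNet.

```python
import torch

def doubly_stochastic_constraints(n):
    """Return (E, f) such that E @ x == f encodes unit row and column sums
    for an n x n matrix flattened row by row into x.
    Hypothetical helper for illustration; not part of LinSATNet."""
    # Row i sums the n consecutive entries x[i*n : (i+1)*n].
    row_sums = torch.kron(torch.eye(n), torch.ones(1, n))
    # Column j sums the entries x[j], x[n+j], x[2n+j], ...
    col_sums = torch.kron(torch.ones(1, n), torch.eye(n))
    E = torch.cat([row_sums, col_sums], dim=0)  # shape (2n, n*n)
    f = torch.ones(2 * n)                       # every sum must equal 1
    return E, f

E3, f3 = doubly_stochastic_constraints(3)
print(E3)
```

For n = 3 this reproduces the 6x9 matrix written out above: the first three rows are the row-sum constraints and the last three are the column-sum constraints.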

We randomly initialize w and treat it as the output of some neural network:

w = torch.rand(9) # w could be the output of neural network
w = w.requires_grad_(True)

We also have a "ground-truth target" for the output of linsat_layer, which is the flattened 3x3 identity matrix in this example:

x_gt = torch.tensor(
    [1, 0, 0,
     0, 1, 0,
     0, 0, 1], dtype=torch.float32
)
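Because the variables are stored as a flat vector, it can help to reshape them back into a matrix when checking results. The sketch below is an illustration only (is_doubly_stochastic is a hypothetical helper, not part of LinSATNet): it reshapes a flattened vector, assuming row-major order, and tests the row sums, column sums, and non-negativity up to a tolerance.

```python
import torch

def is_doubly_stochastic(x_flat, n, atol=1e-6):
    """Check that a row-major flattened n x n matrix has non-negative
    entries and unit row and column sums, up to a tolerance.
    Hypothetical helper for illustration; not part of LinSATNet."""
    X = x_flat.reshape(n, n)
    ones = torch.ones(n, dtype=X.dtype)
    rows_ok = torch.allclose(X.sum(dim=1), ones, atol=atol)
    cols_ok = torch.allclose(X.sum(dim=0), ones, atol=atol)
    nonneg = bool((X >= -atol).all())
    return rows_ok and cols_ok and nonneg

x_gt = torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=torch.float32)
print(is_doubly_stochastic(x_gt, 3))           # the identity target
print(is_doubly_stochastic(torch.rand(9), 3))  # an unprojected random input
```

The same check can be applied to the output of linsat_layer, with a looser atol if the projection has not fully converged.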

The forward/backward passes of LinSAT follow the standard PyTorch style and are readily integrated into existing deep learning pipelines.

The forward pass:

linsat_outp = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0)

The backward pass:

loss = ((linsat_outp - x_gt) ** 2).sum()
loss.backward()

We can also do gradient-based optimization over w to make the output of linsat_layer closer to x_gt. This is what's happening when you train a neural network.

niters = 10
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)
for i in range(niters):
    x = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0)
    cv = torch.matmul(E, x.t()).t() - f.unsqueeze(0)
    loss = ((x - x_gt) ** 2).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    print(f'{i}/{niters}\n'
          f'  underlying obj={torch.sum(w * x)},\n'
          f'  loss={loss},\n'
          f'  sum(constraint violation)={torch.sum(cv[cv > 0])},\n'
          f'  x={x},\n'
          f'  constraint violation={cv}')

You are likely to see the loss decrease over the gradient steps.

API Reference

To use LinSATNet in your own project, make sure you have the package installed:

pip install linsatnet

and import the package at the beginning of your code:

from LinSATNet import linsat_layer

The linsat_layer function

LinSATNet.linsat_layer(x, A=None, b=None, C=None, d=None, E=None, f=None, tau=0.05, max_iter=100, dummy_val=0, mode='v1')

The LinSAT layer enforces positive linear constraints on the input x by projecting it so that it satisfies $$\mathbf{A} \mathbf{x} \leq \mathbf{b}, \quad \mathbf{C} \mathbf{x} \geq \mathbf{d}, \quad \mathbf{E} \mathbf{x} = \mathbf{f},$$ where all elements of $\mathbf{A}, \mathbf{b}, \mathbf{C}, \mathbf{d}, \mathbf{E}, \mathbf{f}$ must be non-negative.

Parameters:

  • x: PyTorch tensor of size ($n_v$), it can optionally have a batch dimension ($b \times n_v$)
  • A, C, E: PyTorch tensor of size ($n_c \times n_v$), constraint matrix on the left hand side
  • b, d, f: PyTorch tensor of size ($n_c$), constraint vector on the right hand side
  • tau: (default=0.05) parameter to control the discreteness of the projection. Smaller value leads to more discrete (harder) results, larger value leads to more continuous (softer) results.
  • max_iter: (default=100) max number of iterations
  • dummy_val: (default=0) the value of dummy variables appended to the input vector
  • mode: (default='v1') EXPERIMENTAL the mode of LinSAT kernel. v2 is sometimes faster than v1.
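To build intuition for tau, it may help to compare it with the temperature of a softmax. The sketch below is an analogy only: it uses torch.softmax, not the LinSAT kernel, and the name tempered_softmax is made up for this illustration. It shows the qualitative effect described above, where a smaller value gives a harder (more discrete) result and a larger value a softer one.

```python
import torch

scores = torch.tensor([2.0, 1.0, 0.5])

def tempered_softmax(scores, tau):
    """Softmax with temperature tau.
    An analogy for the role of tau only; this is NOT the LinSAT kernel."""
    return torch.softmax(scores / tau, dim=-1)

# Smaller tau concentrates the mass on the largest score;
# larger tau spreads it almost uniformly.
for tau in (0.05, 1.0, 100.0):
    print(tau, tempered_softmax(scores, tau))
```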

return: PyTorch tensor of size ($n_v$) or ($b \times n_v$), the projected variables

Notations:

  • $b$ means the batch size.
  • $n_c$ means the number of constraints ($\mathbf{A}$, $\mathbf{C}$, $\mathbf{E}$ may have different $n_c$)
  • $n_v$ means the number of variables
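The shapes above can be illustrated with a small sketch, assuming a single equality constraint that asks the entries of each input vector to sum to k. The helpers cardinality_constraint and project_scores are hypothetical names introduced here, not part of LinSATNet; only linsat_layer and its keyword arguments come from the signature documented above.

```python
import torch

def cardinality_constraint(n_v, k):
    """Return (E, f) for the single equality constraint sum(x) == k.
    Hypothetical helper for illustration; not part of LinSATNet."""
    E = torch.ones(1, n_v)        # shape (n_c, n_v) with n_c = 1
    f = torch.tensor([float(k)])  # shape (n_c,)
    return E, f

def project_scores(scores, k, tau=0.05, max_iter=100):
    """Project a (b, n_v) batch of scores so that each row sums to k.
    The import is done lazily so the constraint builder above can be
    used without the package installed."""
    from LinSATNet import linsat_layer
    E, f = cardinality_constraint(scores.shape[-1], k)
    # The same E and f are shared by every vector in the batch.
    return linsat_layer(scores, E=E, f=f, tau=tau, max_iter=max_iter)

E, f = cardinality_constraint(n_v=5, k=2)
print(E.shape, f.shape)
```

With the package installed, a call such as project_scores(torch.rand(4, 5), k=2) would be expected to return a tensor of size (4, 5), following the return description above.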

Some practical notes

  1. You must ensure that your input constraints have a non-empty feasible space. Otherwise, linsat_layer will not converge.
  2. You may tune the value of tau for your specific tasks. Monitor the output of LinSAT so that the "smoothness" of the output meets your task. Reasonable choices of tau may range from 1e-4 to 100 in our experience.
  3. Be careful of potential numerical issues. Sometimes A x <= 1 does not work, but A x <= 0.999 works.
  4. The input vector x may have a batch dimension, but the constraints cannot have a batch dimension. The constraints should be consistent for all data in one batch.
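When debugging convergence or numerical issues, it can be useful to measure how far a vector is from satisfying the constraints. The following is a minimal sketch of such a check; max_violation is a hypothetical helper, not part of LinSATNet, and it mirrors the argument names of linsat_layer only for convenience. Evaluating it on a point you know should be feasible is also a quick way to confirm that the feasible space is non-empty.

```python
import torch

def max_violation(x, A=None, b=None, C=None, d=None, E=None, f=None):
    """Largest violation of A x <= b, C x >= d, E x = f over a batch.
    x has shape (n_v,) or (b, n_v); returns 0.0 when every constraint holds.
    Hypothetical helper for illustration; not part of LinSATNet."""
    x = x.reshape(-1, x.shape[-1])  # treat a single vector as a batch of one
    worst = x.new_zeros(())
    if A is not None:
        worst = torch.maximum(worst, (x @ A.t() - b).clamp(min=0).max())
    if C is not None:
        worst = torch.maximum(worst, (d - x @ C.t()).clamp(min=0).max())
    if E is not None:
        worst = torch.maximum(worst, (x @ E.t() - f).abs().max())
    return worst.item()

A = torch.tensor([[1.0, 1.0, 1.0]])
b = torch.tensor([1.0])
print(max_violation(torch.tensor([0.2, 0.3, 0.4]), A=A, b=b))  # inside A x <= b
print(max_violation(torch.tensor([0.5, 0.5, 0.5]), A=A, b=b))  # exceeds b
```

In practice you would compare the returned value against a small tolerance rather than against exactly zero, since the projection is iterative and runs in floating point.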

Citation

If you find our paper/code useful in your research, please cite

@inproceedings{WangICML23,
  title={LinSATNet: The Positive Linear Satisfiability Neural Networks},
  author={Wang, Runzhong and Zhang, Yunhao and Guo, Ziao and Chen, Tianyi and Yang, Xiaokang and Yan, Junchi},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2023}
}
