LinSATNet offers a neural network layer that enforces the satisfiability of positive linear constraints on the output of neural networks. The gradient through the layer is computed exactly. The package works with PyTorch.

Project description

LinSATNet

This is the official implementation of our ICML 2023 paper "LinSATNet: The Positive Linear Satisfiability Neural Networks".

With LinSATNet, you can enforce the satisfiability of general positive linear constraints to the output of neural networks.

[Figure: LinSATNet use cases]

News: LinSATNet supports sparse constraints starting from v0.1.1! The forward and backward results should be identical to the dense version, but the sparse version is expected to be more efficient in time and GPU memory. Upgrade via pip install --upgrade linsatnet


The LinSAT layer is fully differentiable, and the gradients are exactly computed. Our implementation now supports PyTorch.

You can install it by

pip install linsatnet

And get started by

from LinSATNet import linsat_layer

A Quick Example

A quick example runs if you execute LinSATNet/linsat.py directly. In this example, the doubly-stochastic constraint is enforced on 3x3 variables.

To run the example, first clone the repo:

git clone https://github.com/Thinklab-SJTU/LinSATNet.git

Go into the repo, and run the example code:

cd LinSATNet
python LinSATNet/linsat.py

In this example, we enforce the doubly-stochastic constraint on a 3x3 matrix. The doubly-stochastic constraint means that all rows and columns of the matrix must sum to 1.

The 3x3 matrix is flattened into a vector, and the following positive linear constraints are considered (for $\mathbf{E}\mathbf{x}=\mathbf{f}$):

E = torch.tensor(
    [[1, 1, 1, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 1, 1, 1, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 1, 1, 1],
     [1, 0, 0, 1, 0, 0, 1, 0, 0],
     [0, 1, 0, 0, 1, 0, 0, 1, 0],
     [0, 0, 1, 0, 0, 1, 0, 0, 1]], dtype=torch.float32
)
f = torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32)
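
As a quick sanity check (not part of the example script), any x satisfying these constraints has unit row and column sums once reshaped to 3x3. For instance, the flattened uniform doubly-stochastic matrix is feasible:

x_feasible = torch.full((9,), 1 / 3)  # uniform 3x3 doubly-stochastic matrix, flattened
print(torch.allclose(E @ x_feasible, f))    # True: all six constraints hold
print(x_feasible.reshape(3, 3).sum(dim=0))  # tensor([1., 1., 1.]) -- column sums
print(x_feasible.reshape(3, 3).sum(dim=1))  # tensor([1., 1., 1.]) -- row sums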

We randomly initialize w and regard it as the output of some neural network:

w = torch.rand(9) # w could be the output of a neural network
w = w.requires_grad_(True)

We also have a "ground-truth target" for the output of linsat_layer, which is the identity matrix in this example:

x_gt = torch.tensor(
    [1, 0, 0,
     0, 1, 0,
     0, 0, 1], dtype=torch.float32
)

The forward/backward passes of LinSAT follow the standard PyTorch style and are readily integrated into existing deep learning pipelines.

The forward pass:

linsat_outp = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0)

The backward pass:

loss = ((linsat_outp - x_gt) ** 2).sum()
loss.backward()

You can also pass E as a sparse matrix to improve time and memory efficiency (especially for large inputs):

linsat_outp = linsat_layer(w, E=E.to_sparse(), f=f, tau=0.1, max_iter=10, dummy_val=0)
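
For larger problems you may prefer to build E directly in sparse COO format rather than densifying first. The construction below is our own sketch for the doubly-stochastic constraints above (the index layout is not part of the package API):

n = 3
rows, cols = [], []
for i in range(n):
    for j in range(n):
        v = i * n + j                       # flattened index of matrix entry (i, j)
        rows.append(i); cols.append(v)      # row-sum constraint i covers variable v
        rows.append(n + j); cols.append(v)  # column-sum constraint j covers variable v
E_sparse = torch.sparse_coo_tensor(
    torch.tensor([rows, cols]), torch.ones(len(rows)), size=(2 * n, n * n)
)
linsat_outp = linsat_layer(w, E=E_sparse, f=f, tau=0.1, max_iter=10, dummy_val=0)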

We can also perform gradient-based optimization over w to bring the output of linsat_layer closer to x_gt. This mirrors what happens when you train a neural network.

niters = 10
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)
for i in range(niters):
    x = linsat_layer(w, E=E, f=f, tau=0.1, max_iter=10, dummy_val=0)
    cv = torch.matmul(E, x.t()).t() - f.unsqueeze(0)
    loss = ((x - x_gt) ** 2).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    print(f'{i}/{niters}\n'
          f'  underlying obj={torch.sum(w * x)},\n'
          f'  loss={loss},\n'
          f'  sum(constraint violation)={torch.sum(cv[cv > 0])},\n'
          f'  x={x},\n'
          f'  constraint violation={cv}')

And you are likely to see the loss decreasing during the gradient steps.
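
As noted above, LinSAT integrates into ordinary training pipelines. Here is a minimal sketch (our own wrapper, not an official API) that projects the output of a small network, reusing E and f from the quick example:

import torch
import torch.nn as nn
from LinSATNet import linsat_layer

class ConstrainedNet(nn.Module):
    def __init__(self, in_dim, n_var, E, f):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, n_var)
        )
        self.E, self.f = E, f

    def forward(self, feat):
        w = self.backbone(feat)  # unconstrained scores, shape (batch, n_var)
        # project each row so that E x = f (approximately) holds
        return linsat_layer(w, E=self.E, f=self.f, tau=0.1, max_iter=10, dummy_val=0)

net = ConstrainedNet(in_dim=16, n_var=9, E=E, f=f)
out = net(torch.rand(4, 16))  # batched input; constraints are shared across the batch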

API Reference

To use LinSATNet in your own project, make sure you have the package installed:

pip install linsatnet

and import the package at the beginning of your code:

from LinSATNet import linsat_layer, init_constraints

The linsat_layer function

LinSATNet.linsat_layer(x, A=None, b=None, C=None, d=None, E=None, f=None, constr_dict=None, tau=0.05, max_iter=100, dummy_val=0, mode='v2', grouped=True, no_warning=False) [source]

The LinSAT layer enforces positive linear constraints on the input x by projecting it so that $$\mathbf{A}\mathbf{x} \leq \mathbf{b}, \quad \mathbf{C}\mathbf{x} \geq \mathbf{d}, \quad \mathbf{E}\mathbf{x} = \mathbf{f},$$ where all elements of $\mathbf{A}, \mathbf{b}, \mathbf{C}, \mathbf{d}, \mathbf{E}, \mathbf{f}$ must be non-negative.

Parameters:

  • x: PyTorch tensor of size ($n_v$); it can optionally have a batch dimension ($b \times n_v$)
  • A, C, E: PyTorch tensors of size ($n_c \times n_v$), constraint matrices on the left-hand side
  • b, d, f: PyTorch tensors of size ($n_c$), constraint vectors on the right-hand side
  • constr_dict: a dictionary of pre-initialized constraint information, i.e., the output of LinSATNet.init_constraints. Passing it avoids re-initializing the same constraints on every call and improves efficiency (see the sketch after the notations below)
  • tau: (default=0.05) parameter controlling the discreteness of the projection. Smaller values lead to more discrete (harder) results; larger values lead to more continuous (softer) results
  • max_iter: (default=100) maximum number of iterations
  • dummy_val: (default=0) the value of dummy variables appended to the input vector
  • mode: (default='v2') LinSAT kernel implementation. v1 is the version that came with the ICML paper; v2 is an improved version with (usually) better efficiency
  • grouped: (default=True) group non-overlapping constraints into one operation for better efficiency
  • no_warning: (default=False) turn off warning messages

return: PyTorch tensor of size ($n_v$) or ($b \times n_v$), the projected variables

Notations:

  • $b$ means the batch size.
  • $n_c$ means the number of constraints ($\mathbf{A}$, $\mathbf{C}$, $\mathbf{E}$ may have different $n_c$)
  • $n_v$ means the number of variables
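
If you call linsat_layer repeatedly with the same constraints, you can pre-initialize them once via constr_dict, as mentioned above. A hedged sketch, assuming init_constraints takes the number of variables plus the same constraint keyword arguments as linsat_layer (check the package source for the exact signature), reusing E and f from the quick example:

import torch
from LinSATNet import linsat_layer, init_constraints

constr_dict = init_constraints(9, E=E, f=f)  # initialize once (assumed signature)
w1 = torch.rand(9, requires_grad=True)
w2 = torch.rand(9, requires_grad=True)
# both calls reuse the initialized constraints, skipping re-initialization
x1 = linsat_layer(w1, constr_dict=constr_dict, tau=0.1, max_iter=10)
x2 = linsat_layer(w2, constr_dict=constr_dict, tau=0.1, max_iter=10)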

Some practical notes

  1. You must ensure that your input constraints define a non-empty feasible space; otherwise, linsat_layer will not converge. Also note that the output x lies in the range [0, 1], so you may add a multiplier to scale it.
  2. You may tune the value of tau for your specific task. Monitor the output of LinSAT so that the "smoothness" of the output fits your task; see the sketch after this list. Reasonable choices of tau range from 1e-4 to 100 in our experience.
  3. Be careful of potential numerical issues. Sometimes $\mathbf{A}\mathbf{x} \leq 1$ does not work, but $\mathbf{A}\mathbf{x} \leq 0.999$ does.
  4. The input vector x may have a batch dimension, but the constraints cannot: the same constraints are applied to every instance in a batch.
  5. Passing constraints as sparse tensors can usually save GPU memory. When working with sparse constraints, A, C, E should be torch.sparse_coo_tensor, and b, d, f should be dense tensors.
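
To make note 2 concrete, here is a small sketch that sweeps tau and inspects how discrete the projection becomes (reusing w, E, f from the quick example); smaller tau pushes entries closer to 0/1:

for tau in (1e-2, 1e-1, 1.0):
    x = linsat_layer(w, E=E, f=f, tau=tau, max_iter=100, dummy_val=0)
    print(f'tau={tau}: min={x.min().item():.3f}, max={x.max().item():.3f}')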

Citation

If you find our paper/code useful in your research, please cite

@inproceedings{WangICML23,
  title={{LinSATNet}: The Positive Linear Satisfiability Neural Networks},
  author={Wang, Runzhong and Zhang, Yunhao and Guo, Ziao and Chen, Tianyi and Yang, Xiaokang and Yan, Junchi},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2023}
}
