# PiShield: A NeSy Framework for Learning with Requirements
- :sparkles: Description
- :pushpin: Dependencies
- :hammer_and_wrench: Installation
- :bulb: Usage
- :arrow_forward: Demo video
- :fire: Performance
- Team
- :memo: References
## :sparkles: Description
Update: PiShield's website is now available here.
PiShield is the first framework allowing for the integration of requirements into the topology of neural networks.

:white_check_mark: The integration is straightforward and efficient, and yields deep learning models that are guaranteed to comply with the given requirements, no matter the input.

:pencil2: The requirements can be integrated at inference time, at training time, or both, depending on the practitioners' needs.
## :pushpin: Dependencies

PiShield requires Python 3.8 or later and PyTorch.

Optional step: set up a conda environment (CPU-only PyTorch is used below; different PyTorch versions can be specified following the instructions here):

```bash
conda create -n "pishield" python=3.11 ipython
conda activate pishield
conda install pytorch cpuonly -c pytorch
pip install -r requirements.txt
```
## :hammer_and_wrench: Installation

From the root of this repository, which contains `setup.py`, run:

```bash
pip install .
```
## :bulb: Usage
Assume we have the following constraints and ordering of the variables in a file `example_constraints_tabular.txt`:

```
ordering y_0 y_1 y_2
-y_0 >= -3
y_0 >= 3
y_0 - y_1 >= 0
-y_0 - y_2 >= 0
```

Here, the first two constraints together require y_0 = 3, while the last two require y_1 <= y_0 and y_2 <= -y_0.
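As a quick sanity check (plain PyTorch, not part of PiShield's API; `satisfies_constraints` is a hypothetical helper written only for illustration), we can verify a candidate output against these constraints:

```python
import torch

def satisfies_constraints(y: torch.Tensor) -> bool:
    # y = [y_0, y_1, y_2]; checks the four constraints listed above
    return bool((-y[0] >= -3) and (y[0] >= 3)
                and (y[0] - y[1] >= 0) and (-y[0] - y[2] >= 0))

print(satisfies_constraints(torch.tensor([-5., -2., -1.])))  # False: y_0 >= 3 is violated
print(satisfies_constraints(torch.tensor([ 3., -2., -3.])))  # True
```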
### Inference time
To correct predictions at inference time such that they satisfy the constraints, we can use PiShield as follows:
```python
import torch
from pishield.constraint_layer import build_pishield_layer

predictions = torch.tensor([[-5., -2., -1.]])
constraints_path = 'example_constraints_tabular.txt'

num_variables = predictions.shape[-1]

# build a constraint layer CL using PiShield
CL = build_pishield_layer(num_variables, constraints_path)

# apply CL from PiShield to get the corrected predictions;
# returns tensor([[ 3., -2., -3.]]), which satisfies the constraints
corrected_predictions = CL(predictions.clone())
```
Equivalently, the correction logic can be wrapped in a reusable function:

```python
import torch
from pishield.constraint_layer import build_pishield_layer

def correct_predictions(predictions: torch.Tensor, constraints_path: str):
    num_variables = predictions.shape[-1]
    # build a constraint layer CL using PiShield
    CL = build_pishield_layer(num_variables, constraints_path)
    # apply PiShield to get corrected predictions, which satisfy the constraints
    corrected_predictions = CL(predictions)
    return corrected_predictions
```
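For instance, applying this function to the predictions from the first snippet yields the same corrected tensor:

```python
preds = torch.tensor([[-5., -2., -1.]])
corrected = correct_predictions(preds, 'example_constraints_tabular.txt')
print(corrected)  # tensor([[ 3., -2., -3.]])
```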
### Training time
Assume a Deep Generative Model (DGM) is used to obtain synthetic tabular data. Using PiShield at training time is easy, as it requires only two steps (see the sketch after this list):

- Instantiate the ConstraintLayer class from PiShield in the DGM's constructor.
- Apply the ConstraintLayer to the data generated by the DGM before computing the DGM's loss.
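A minimal sketch of these two steps, assuming a toy MLP generator (the architecture, `latent_dim`, and the class name `ConstrainedDGM` are illustrative placeholders, not part of PiShield's API):

```python
import torch
from pishield.constraint_layer import build_pishield_layer

class ConstrainedDGM(torch.nn.Module):
    def __init__(self, latent_dim: int, num_variables: int, constraints_path: str):
        super().__init__()
        # toy generator network; any DGM architecture works here
        self.net = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, num_variables),
        )
        # Step 1: instantiate the constraint layer in the DGM's constructor
        self.CL = build_pishield_layer(num_variables, constraints_path)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        generated = self.net(z)
        # Step 2: apply the constraint layer to the generated data,
        # so the loss is computed on constraint-compliant samples
        return self.CL(generated)

model = ConstrainedDGM(latent_dim=8, num_variables=3,
                       constraints_path='example_constraints_tabular.txt')
samples = model(torch.randn(16, 8))  # satisfies the constraints by construction
```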
## :arrow_forward: Demo video
A demo video is available for download here.
## :fire: Performance

### 1. Autonomous Driving
In [2], we considered standard 3D-RetinaNet models with different temporal learning architectures (I3D, C2D, RCN, RCGRU, RCLSTM, and SlowFast) and compared each of them with its constrained version. The constrained versions inject propositional background knowledge into the models via a constrained layer (as we call it in [2]), which is equivalent to a Shield layer when using PiShield.

Below we report the aggregated performance from Table 2 of our paper [2], measured by f-mAP (frame-wise mean Average Precision) at an IoU (Intersection-over-Union) threshold of 0.5. The best results are in bold.
As we can see, the models incorporating background knowledge through Shield layers outperform their standard counterparts.
| Model | Baseline | Shielded |
|---|---|---|
| I3D | 29.30 | **30.98** |
| C2D | 26.34 | **27.93** |
| RCN | 29.26 | **30.02** |
| RCGRU | 29.24 | **30.50** |
| RCLSTM | 28.93 | **30.42** |
| SlowFast | 29.73 | **31.88** |
| AVERAGE | 28.80 | **30.29** |
### 2. Tabular Data Generation
In [1], we compared standard deep generative models with their respective constrained versions, which enforce linear inequality constraints. The latter are the models to which we added a constraint layer (as we call it in [1]), which is equivalent to a Shield layer when using PiShield.

Below we reproduce Table 2 of our paper [1], reporting the results according to two standard measures for tabular data generation benchmarks: utility and detection. For each measure, we report the performance using three metrics: F1-score (F1), weighted F1-score (wF1), and Area Under the ROC Curve (AUC). The best results are in bold.
As we can see, in 28/30 cases, the models incorporating background knowledge through Shield layers outperform their standard counterparts.
| Model | Utility F1 (↑) | Utility wF1 (↑) | Utility AUC (↑) | Detection F1 (↓) | Detection wF1 (↓) | Detection AUC (↓) |
|---|---|---|---|---|---|---|
| WGAN | 0.463 | 0.488 | 0.730 | 0.945 | 0.943 | 0.954 |
| Shielded-WGAN | **0.483** | **0.502** | **0.745** | **0.915** | **0.912** | **0.934** |
| TableGAN | 0.330 | 0.400 | 0.704 | 0.908 | 0.907 | 0.926 |
| Shielded-TableGAN | **0.375** | **0.432** | **0.714** | **0.898** | **0.895** | **0.917** |
| CTGAN | **0.517** | 0.532 | 0.771 | 0.902 | 0.901 | 0.920 |
| Shielded-CTGAN | 0.516 | **0.537** | **0.773** | **0.894** | **0.891** | **0.919** |
| TVAE | 0.497 | 0.527 | 0.767 | 0.869 | 0.868 | **0.892** |
| Shielded-TVAE | **0.507** | **0.537** | **0.773** | **0.868** | **0.867** | 0.898 |
| GOGGLE | 0.344 | 0.373 | 0.624 | 0.926 | 0.926 | 0.943 |
| Shielded-GOGGLE | **0.409** | **0.427** | **0.667** | **0.925** | **0.916** | **0.937** |
| AVERAGE Baseline | 0.430 | 0.464 | 0.719 | 0.910 | 0.909 | 0.927 |
| AVERAGE Shielded | **0.458** | **0.487** | **0.734** | **0.900** | **0.896** | **0.921** |
### 3. Functional Genomics
In our paper [3], we compared the performance of baseline models with their constrained counterparts. The latter are the models to which we added a constraint layer (as we call it in [3]), which is equivalent to a Shield layer when using PiShield.

Below we aggregate the results from Table 3 of [3] and report the performance in terms of the area under the average precision-recall curve (AU(PRC)), a standard metric in functional genomics. As we can see, the models with Shield layers outperform their standard counterparts.
| Dataset | Baseline* | PiShield |
|---|---|---|
| CELLCYCLE | 0.220 | 0.232 |
| DERISI | 0.179 | 0.182 |
| EISEN | 0.262 | 0.285 |
| EXPRE | 0.246 | 0.270 |
| GASCH1 | 0.239 | 0.261 |
| GASCH2 | 0.221 | 0.235 |
| SEQ | 0.245 | 0.274 |
| SPO | 0.186 | 0.190 |
| AVERAGE | 0.225 | 0.241 |
*Note: All baselines for the functional genomics scenario have a postprocessing step included, as functional genomics tasks always require that the constraints are satisfied.
## Team
- Mihaela Catalina Stoian [GitHub][LinkedIn][Google Scholar]
- Alex Tatomir [GitHub][LinkedIn][Google Scholar]
- Thomas Lukasiewicz [Google Scholar]
- Eleonora Giunchiglia [GitHub][LinkedIn][Google Scholar]
## :memo: References
[1] Mihaela Catalina Stoian, Salijona Dyrmishi, Maxime Cordy, Thomas Lukasiewicz, Eleonora Giunchiglia. How Realistic Is Your Synthetic Data? Constraining Deep Generative Models for Tabular Data. In Proceedings of the International Conference on Learning Representations (ICLR), 2024. arXiv:2402.04823.

[2] Eleonora Giunchiglia, Alex Tatomir, Mihaela Catalina Stoian, Thomas Lukasiewicz. CCN+: A Neuro-Symbolic Framework for Deep Learning with Requirements. International Journal of Approximate Reasoning, 2024.

[3] Eleonora Giunchiglia, Thomas Lukasiewicz. Coherent Hierarchical Multi-Label Classification Networks. In Advances in Neural Information Processing Systems (NeurIPS), 2020.