Bound propagation in PyTorch
Bound propagation
Linear and interval bound propagation in PyTorch, with an easy-to-use API, GPU support, and heavy parallelization. It was initially made as an alternative to the original CROWN implementation, which used only NumPy, relied on many for-loops, and had a cumbersome API.
To install:
```
pip install bound-propagation
```
Supported bound propagation methods:
- Interval bound propagation (IBP)
- CROWN
- CROWN-IBP
For the examples below, assume the following network definition:
```python
from torch import nn

from bound_propagation import BoundModelFactory, HyperRectangle


class Network(nn.Sequential):
    def __init__(self, *args):
        if args:
            # Support slice indexing through nn.Sequential.__getitem__, which
            # rebuilds the container as self.__class__(OrderedDict(...)).
            # CROWN (and implicitly CROWN-IBP) relies on this internally.
            super().__init__(*args)
        else:
            in_size = 30
            classes = 10

            super().__init__(
                nn.Linear(in_size, 16),
                nn.Tanh(),
                nn.Linear(16, 16),
                nn.Tanh(),
                nn.Linear(16, classes),
            )


net = Network()

# Wrap the PyTorch model in a module that supports bound propagation
factory = BoundModelFactory()
net = factory.build(net)
```
Bound propagation also works with `nn.Sigmoid` and `nn.ReLU` activations, and with the three custom layers `Residual`, `Cat`, and `Parallel`.
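As a hedged illustration of why these activations are straightforward for interval bounds, the pure-Python sketch below relies on one fact: Sigmoid, Tanh, and ReLU are monotonically non-decreasing, so an input interval [l, u] maps to [f(l), f(u)]. It is not the library's implementation, and `sigmoid`, `relu`, and `activation_ibp` are hypothetical helpers.

```python
import math
import random

# Minimal sketch (hypothetical helpers, not the library's code): IBP through
# elementwise activations. Sigmoid, Tanh and ReLU are non-decreasing, so each
# input interval [l, u] maps to the output interval [f(l), f(u)].

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def relu(z):
    return max(0.0, z)

def activation_ibp(f, lower, upper):
    """Elementwise bounds on f(z); valid only for non-decreasing f."""
    return [f(l) for l in lower], [f(u) for u in upper]


lower, upper = [-1.0, 0.5], [2.0, 3.0]
for f in (sigmoid, math.tanh, relu):
    f_lower, f_upper = activation_ibp(f, lower, upper)
    # Sanity check: random points from the input box map inside the bounds.
    for _ in range(1000):
        z = [random.uniform(l, u) for l, u in zip(lower, upper)]
        assert all(fl - 1e-12 <= f(zi) <= fu + 1e-12
                   for fl, zi, fu in zip(f_lower, z, f_upper))
```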
Interval bounds
To get interval bounds for either IBP, CROWN, or CROWN-IBP:
```python
import torch

x = torch.rand(100, 30)  # batch of 100 inputs with 30 features each
epsilon = 0.1
input_bounds = HyperRectangle.from_eps(x, epsilon)

ibp_bounds = net.ibp(input_bounds)
crown_bounds = net.crown(input_bounds).concretize()
crown_ibp_bounds = net.crown_ibp(input_bounds).concretize()
```
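To make the interval numbers concrete, here is a minimal pure-Python sketch of the arithmetic IBP performs for a small Linear → Tanh → Linear network. It assumes `HyperRectangle.from_eps(x, epsilon)` describes the box [x - epsilon, x + epsilon]; the weights and the helpers `from_eps`, `linear_ibp`, `tanh_ibp`, and `forward` are hypothetical and do not mirror the library's internals.

```python
import math
import random

# Minimal IBP sketch for Linear -> Tanh -> Linear in pure Python.
# Weights and helper names are hypothetical, not the library's API.

def from_eps(x, eps):
    """Axis-aligned box [x - eps, x + eps] (assumed meaning of from_eps)."""
    return [xi - eps for xi in x], [xi + eps for xi in x]

def linear_ibp(W, b, lower, upper):
    """Bounds on W z + b over the box: each weight picks its worst-case endpoint."""
    out_l, out_u = [], []
    for row, bias in zip(W, b):
        out_l.append(bias + sum(w * (l if w >= 0 else u) for w, l, u in zip(row, lower, upper)))
        out_u.append(bias + sum(w * (u if w >= 0 else l) for w, l, u in zip(row, lower, upper)))
    return out_l, out_u

def tanh_ibp(lower, upper):
    """tanh is increasing, so interval endpoints map to interval endpoints."""
    return [math.tanh(l) for l in lower], [math.tanh(u) for u in upper]

def forward(W1, b1, W2, b2, x):
    """Plain (non-bounded) forward pass, used only for the soundness check."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(W2, b2)]


W1, b1 = [[1.0, -0.5], [0.25, 2.0]], [0.0, 0.1]
W2, b2 = [[1.0, -1.0]], [0.0]
x, eps = [0.3, -0.2], 0.1

lower, upper = from_eps(x, eps)
lower, upper = tanh_ibp(*linear_ibp(W1, b1, lower, upper))
lower, upper = linear_ibp(W2, b2, lower, upper)

# Soundness check: outputs of points sampled from the input box stay in bounds.
box_l, box_u = from_eps(x, eps)
for _ in range(1000):
    z = [random.uniform(l, u) for l, u in zip(box_l, box_u)]
    y = forward(W1, b1, W2, b2, z)
    assert all(lo - 1e-12 <= yi <= hi + 1e-12 for lo, yi, hi in zip(lower, y, upper))
```

Each affine bound takes the lower input endpoint where a weight is positive and the upper endpoint where it is negative, which keeps IBP cheap but lets the intervals grow loose over several layers.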
Linear bounds
To get linear bounds for either CROWN or CROWN-IBP:
```python
import torch

x = torch.rand(100, 30)
epsilon = 0.1
input_bounds = HyperRectangle.from_eps(x, epsilon)

crown_bounds = net.crown(input_bounds)
crown_ibp_bounds = net.crown_ibp(input_bounds)
```
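Linear bounds keep the dependence on the input instead of collapsing it to an interval at every layer. The sketch below is a simplified, CROWN-style backward pass for a single hidden ReLU layer with a scalar output, written in pure Python; it is not the library's implementation. It uses the chord as the ReLU upper relaxation and an adaptive 0/1 lower slope, then concretizes the resulting linear bounds over the input box, which is conceptually what `.concretize()` yields (an assumption about that method's semantics). The names `relu_relaxation`, `interval_of`, and `crown_one_hidden` are hypothetical.

```python
import random

# Simplified CROWN-style sketch (not the library's implementation) for a scalar
# network f(x) = w2 . relu(W1 x + b1) + b2 on the box lower <= x <= upper.

def relu_relaxation(l, u):
    """Linear bounds a*z + d on relu(z) for z in [l, u]: ((a_lo, d_lo), (a_up, d_up))."""
    if u <= 0:                       # neuron always inactive
        return (0.0, 0.0), (0.0, 0.0)
    if l >= 0:                       # neuron always active
        return (1.0, 0.0), (1.0, 0.0)
    s = u / (u - l)                  # chord through (l, 0) and (u, u)
    alpha = 1.0 if u >= -l else 0.0  # adaptive lower slope, 0 or 1
    return (alpha, 0.0), (s, -s * l)

def interval_of(row, bias, lower, upper):
    """Exact min and max of row . x + bias over the box."""
    lo = bias + sum(w * (l if w >= 0 else u) for w, l, u in zip(row, lower, upper))
    hi = bias + sum(w * (u if w >= 0 else l) for w, l, u in zip(row, lower, upper))
    return lo, hi

def crown_one_hidden(W1, b1, w2, b2, lower, upper):
    """Linear bounds with A_l.x + c_l <= f(x) <= A_u.x + c_u on the box."""
    n_in = len(lower)
    A_l, A_u = [0.0] * n_in, [0.0] * n_in
    c_l = c_u = b2
    for j, c in enumerate(w2):
        low_rel, up_rel = relu_relaxation(*interval_of(W1[j], b1[j], lower, upper))
        # A negative output weight swaps which relaxation bounds f from below.
        al, dl = low_rel if c >= 0 else up_rel
        au, du = up_rel if c >= 0 else low_rel
        for i in range(n_in):
            A_l[i] += c * al * W1[j][i]
            A_u[i] += c * au * W1[j][i]
        c_l += c * (al * b1[j] + dl)
        c_u += c * (au * b1[j] + du)
    return (A_l, c_l), (A_u, c_u)


W1, b1 = [[1.0, -1.0], [0.5, 1.0]], [0.0, -0.25]
w2, b2 = [1.0, -2.0], 0.5
lower, upper = [-1.0, -1.0], [1.0, 1.0]

(A_l, c_l), (A_u, c_u) = crown_one_hidden(W1, b1, w2, b2, lower, upper)

# Concretize: minimise the lower linear bound and maximise the upper one over the box.
f_min, _ = interval_of(A_l, c_l, lower, upper)
_, f_max = interval_of(A_u, c_u, lower, upper)

def f(x):
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
    return sum(c * h for c, h in zip(w2, hidden)) + b2

for _ in range(1000):
    x = [random.uniform(l, u) for l, u in zip(lower, upper)]
    assert f_min - 1e-9 <= f(x) <= f_max + 1e-9
```

With a single affine layer before the ReLU, the pre-activation intervals are exact; in deeper networks CROWN computes these intermediate bounds with further backward passes, while CROWN-IBP obtains them with IBP.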
Authors
- Frederik Baymler Mathiesen - PhD student @ TU Delft