AlgoVision - A Framework for Differentiable Algorithms and Algorithmic Supervision
This repository includes the official implementation of our NeurIPS 2021 paper "Learning with Algorithmic Supervision via Continuous Relaxations" (Paper @ arXiv, Video @ YouTube).
`algovision` is a Python 3.6+ and PyTorch 1.9.0+ based library for making algorithms differentiable. It can be installed via:

```shell
pip install algovision
```
Applications include smoothly integrating algorithms into neural networks for algorithmic supervision, problem-specific optimization within an algorithm, and whatever your imagination allows.
Intro
Deriving a loss from a smooth algorithm can be as easy as
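```python
# `examples` refers to the examples module of the algovision repository;
# this import assumes that module is on the Python path.
from examples import get_bubble_sort
import torch

torch.manual_seed(0)
# Get an array (the first dimension is the batch dimension, which is always required)
array = torch.randn(1, 8, requires_grad=True)
# Relaxed bubble sort with inverse temperature beta=5
bubble_sort = get_bubble_sort(beta=5)
# Returns the relaxed sorted array and the sorting loss
result, loss = bubble_sort(array)
loss.backward()

print(array)       # input values
print(result)      # relaxed sorted values
print(array.grad)  # gradient of the loss w.r.t. the input
```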
Here, the loss is a sorting loss corresponding to the number of swaps in the bubble sort algorithm. But we can also define this algorithm from scratch:
```python
from algovision import (
    Algorithm, Input, Output, Var, VarInt,  # core
    GT, IsTrue,                             # conditions
    If, While, For,                         # control_structures
    Let, LetInt,                            # functions
)
import torch

bubble_sort = Algorithm(
    # Define the variables the input corresponds to
    Input('array'),
    # Declare and initialize all differentiable variables
    Var('a',       torch.tensor(0.)),
    Var('b',       torch.tensor(0.)),
    Var('swapped', torch.tensor(1.)),
    Var('loss',    torch.tensor(0.)),
    # Declare and initialize a hard integer variable (VarInt) for the control flow.
    # It can be defined in terms of a lambda expression. The required variables
    # are automatically inferred from the signature of the lambda expression.
    VarInt('n', lambda array: array.shape[1] - 1),
    # Start a relaxed While loop:
    While(IsTrue('swapped'),
        # Set `swapped` to 0 / False
        Let('swapped', 0),
        # Start an unrolled For loop. Corresponds to `for i in range(n):`
        For('i', 'n',
            # Set `a` to the `i`th element of `array`
            Let('a', 'array', ['i']),
            # Using an in-place lambda expression, we can include computations
            # based on variables to obtain the element at position i+1.
            Let('b', 'array', [lambda i: i + 1]),
            # An If-Else statement with the condition a > b
            If(GT('a', 'b'),
               if_true=[
                   # Set the i+1 th element of array to a
                   Let('array', [lambda i: i + 1], 'a'),
                   # Set the i th element of array to b
                   Let('array', ['i'], 'b'),
                   # Set swapped to 1 / True
                   Let('swapped', 1.),
                   # Increment the loss by 1 using a lambda expression
                   Let('loss', lambda loss: loss + 1.),
               ]
            ),
        ),
        # Decrement the hard integer variable n by 1
        LetInt('n', lambda n: n - 1),
    ),
    # Define what the algorithm should return
    Output('array'),
    Output('loss'),
    # Set the inverse temperature beta
    beta=5,
)
```
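To make the effect of the relaxation concrete, here is a plain-Python sketch of what a relaxed bubble sort of this kind computes for a single (unbatched) list. It assumes that `GT('a', 'b')` is relaxed with a logistic sigmoid, p = sigmoid(beta * (a - b)), and that a relaxed `If` evaluates both branches and blends them with weights p and 1 - p. For simplicity, the relaxed `While` is replaced by a fixed schedule of passes. The helper names are hypothetical, and this is not `algovision`'s implementation.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (assumed relaxation of a > b)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def relaxed_bubble_sort(values, beta=5.0):
    """Hypothetical sketch: bubble sort in which every swap is relaxed.

    Returns the relaxed sorted list and the expected number of swaps.
    """
    array = list(values)
    loss = 0.0
    n = len(array) - 1
    # Hard integer control flow: one pass per value of n, shrinking by 1
    # (stands in for the relaxed While loop of the algorithm above).
    while n > 0:
        for i in range(n):  # unrolled For loop over positions 0..n-1
            a, b = array[i], array[i + 1]
            p = sigmoid(beta * (a - b))  # relaxed condition a > b
            # Relaxed If: blend the "swap" branch (weight p)
            # with the "no swap" branch (weight 1 - p).
            array[i] = p * b + (1 - p) * a
            array[i + 1] = p * a + (1 - p) * b
            loss += p  # the hard version adds 1 only when it swaps
        n -= 1
    return array, loss
```

With a large `beta`, the output approaches the hard bubble sort and the loss approaches the number of swaps; with a small `beta`, every comparison contributes a fractional swap, which is what gives the loss useful gradients.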
Instruction Set
The full set of modules is:
```python
from algovision import (
    Algorithm, Input, Output, Var, VarInt,                                      # core
    Eq, NEq, LT, LEq, GT, GEq, CatProbEq, CosineSimilarity, IsTrue, IsFalse,    # conditions
    If, While, For,                                                             # control_structures
    Let, LetInt, Print, Min, ArgMin, Max, ArgMax,                               # functions
)
```
`Algorithm` is the main class, `Input` and `Output` define arguments and return values, `Var` defines differentiable variables, and `VarInt` defines non-differentiable integer variables.
`Eq`, `LT`, etc. are relaxed conditions for `If` and `While`, which are the corresponding control structures.
`For` defines bounded loops of fixed length that are unrolled.
`Let` sets a differentiable variable, and `LetInt` sets a hard integer variable.
Note that hard integer variables should only be used if they are independent of the input values, but they may depend on the input shape (e.g., for reducing the number of iterations after each traversal of a For loop).
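The idea behind relaxed conditions can be illustrated without the library. The sketch below assumes that `GT` and `LT` map a comparison to a probability via a logistic sigmoid scaled by the inverse temperature `beta`, and that a relaxed `If` returns the probability-weighted average of its two branches. All function names are hypothetical and not part of `algovision`.

```python
import math


def relaxed_gt(a: float, b: float, beta: float) -> float:
    """Assumed relaxation of a > b: a probability in (0, 1)."""
    x = beta * (a - b)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def relaxed_lt(a: float, b: float, beta: float) -> float:
    """a < b is the same test with the arguments swapped."""
    return relaxed_gt(b, a, beta)


def relaxed_if(p: float, if_true: float, if_false: float) -> float:
    """Blend both branches, weighting the true branch by p."""
    return p * if_true + (1.0 - p) * if_false


def soft_max2(a: float, b: float, beta: float) -> float:
    """Relaxed version of `m = a if a > b else b`."""
    return relaxed_if(relaxed_gt(a, b, beta), a, b)
```

Increasing `beta` pushes the probabilities toward 0 and 1, so the relaxed `If` approaches the hard one; when `a == b`, both branches receive equal weight.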
`Print` prints values for debugging purposes.
`Min`, `ArgMin`, `Max`, and `ArgMax` return the element-wise min/argmin/max/argmax of a list of tensors (of equal shape).
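To pin down what "element-wise over a list of tensors" means, here is a hard (non-relaxed) plain-Python sketch on equal-length lists; `Max` and `ArgMax` are analogous. Whether and how `algovision` relaxes these functions is not described here, and the helper names are hypothetical.

```python
def elementwise_min(tensors):
    """Position-wise minimum over a list of equal-length lists."""
    assert len({len(t) for t in tensors}) == 1, "all inputs must have equal shape"
    return [min(column) for column in zip(*tensors)]


def elementwise_argmin(tensors):
    """For each position, the index of the input list holding the minimum."""
    assert len({len(t) for t in tensors}) == 1, "all inputs must have equal shape"
    return [min(range(len(column)), key=column.__getitem__)
            for column in zip(*tensors)]
```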
Lambda Expressions
Key to defining an algorithm are `lambda` expressions (see here for a reference).
They define anonymous functions and therefore allow expressing computations in-place.
In most cases in `algovision`, a value can be written in terms of a lambda expression.
The names of the variables used are inferred from the signature of the expression.
For example, `lambda x: x**2` takes the variable named `x` and returns its square at the location where the expression is written.
`Let('z', lambda x, y: x**2 + y)` corresponds to the regular line of code `z = x**2 + y`.
This also allows inserting complex external functions, including neural networks, as part of the lambda expression.
Assuming `net` is a neural network, one can write `Let('y', lambda x: net(x))` (corresponding to `y = net(x)`).
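Signature-based inference of this kind can be sketched with Python's standard `inspect` module: read the parameter names of the lambda and look up the variables of the same names in the current state. The helper `call_with_state` and the dictionary-based state are hypothetical and only illustrate the mechanism, not how `algovision` implements it internally.

```python
import inspect


def call_with_state(fn, state):
    """Call `fn` with the state variables named by its parameters."""
    names = inspect.signature(fn).parameters  # ordered parameter names
    return fn(*(state[name] for name in names))


state = {'x': 3.0, 'y': 1.0, 'z': 0.0}
# Mirrors Let('z', lambda x, y: x**2 + y), i.e., `z = x**2 + y`
state['z'] = call_with_state(lambda x, y: x**2 + y, state)
```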
Let
`Let` is a very flexible instruction.
In its simplest form, `Let` takes two arguments: a string naming the variable where the result is written, and the value, which may be expressed via a `lambda` expression.
If the lambda expression returns multiple values, e.g., because it calls a complex function with two return values, the left argument can be a list of strings.
That is, `Let(['a', 'b'], lambda x, y: (x+y, x-y))` corresponds to `a, b = x+y, x-y`.
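A toy interpreter for this behavior, covering both a single target and a list of targets, might look as follows. It assumes a dictionary-based state, and `let` is a hypothetical helper, not the library's API.

```python
import inspect


def let(state, target, value):
    """Toy `Let`: assign a constant or the result of a lambda to the state."""
    if callable(value):
        # Infer the inputs from the lambda's parameter names
        names = inspect.signature(value).parameters
        value = value(*(state[name] for name in names))
    if isinstance(target, list):
        # Multiple targets: unpack the returned tuple, like `a, b = ...`
        if len(target) != len(value):
            raise ValueError("number of targets and returned values differ")
        for name, item in zip(target, value):
            state[name] = item
    else:
        state[target] = value
    return state


# Mirrors Let(['a', 'b'], lambda x, y: (x+y, x-y))
state = let({'x': 5, 'y': 2}, ['a', 'b'], lambda x, y: (x + y, x - y))
```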
`Let` also supports indexing. This is denoted by an additional list argument after the left and/or the right argument.
For example, `Let('a', 'array', ['i'])` corresponds to `a = array[i]`, while `Let('array', ['i'], 'b')` corresponds to `array[i] = b`.
`Let('array', ['i'], 'array', ['j'])`, corresponding to `array[i] = array[j]`, is also supported.
Note that indexing can also be expressed through `lambda` expressions.
For example, `Let('a', 'array', ['i'])` is equivalent to `Let('a', lambda array, i: array[:, i])`.
Note how in this case the batch dimension has to be taken into account explicitly (the leading `:` in `array[:, i]`).
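The role of the batch dimension can be sketched in plain Python by storing a batch as a list of rows, a hypothetical stand-in for a tensor of shape (batch, length). The index form applies `i` to every batch element implicitly, whereas the lambda form receives the whole batch and must select the column itself. The helper names are hypothetical, and the per-batch value `b` in `index_lhs` is an assumption for illustration.

```python
def index_rhs(array, i):
    """What Let('a', 'array', ['i']) computes: element i of every batch row."""
    return [row[i] for row in array]


def index_lhs(array, i, b):
    """What Let('array', ['i'], 'b') computes: row-wise array[i] = b."""
    return [row[:i] + [b_k] + row[i + 1:] for row, b_k in zip(array, b)]


# Plain-Python analogue of `lambda array, i: array[:, i]`: the lambda sees
# the whole batch and has to select index i in each row explicitly.
lambda_form = lambda array, i: [row[i] for row in array]

batch = [[4, 7, 1], [0, 2, 9]]  # shape (batch=2, length=3)
```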
Relaxed indexing on the right-hand side is only supported through `lambda` expressions due to its complexity.
Relaxed indexing on the left-hand side is supported if exactly one probability weight tensor is in the list (e.g., `Let('array', [lambda x: get_weights(x)], 'a')`).
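One natural reading of relaxed left-hand-side indexing, assumed here rather than taken from the library, is a convex combination: each position `j` keeps its old value with weight `1 - w[j]` and receives `a` with weight `w[j]`, where `w` is the probability weight vector (e.g., the output of `get_weights` in the example above). A plain-Python sketch for a single batch element, with the hypothetical helper `soft_assign`:

```python
def soft_assign(array, weights, a):
    """Relaxed `array[j] = a` with position probabilities `weights` (sum to 1)."""
    assert len(array) == len(weights)
    assert abs(sum(weights) - 1.0) < 1e-6, "weights should form a distribution"
    return [(1.0 - w) * x + w * a for x, w in zip(array, weights)]
```

With a one-hot weight vector, this reduces to an ordinary hard assignment.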
`LetInt` only supports setting the variable to an integer (Python `int`) or a list of integers (as well as the same types via lambda expressions).
Note that hard integer variables should only be used if they are independent of the input values, but they may depend on the input shape.
We will make experiments available soon.
Citing
If you used our library, please cite it as
```bibtex
@inproceedings{petersen2021learning,
  title={{Learning with Algorithmic Supervision via Continuous Relaxations}},
  author={Petersen, Felix and Borgelt, Christian and Kuehne, Hilde and Deussen, Oliver},
  booktitle={Proceedings of Neural Information Processing Systems (NeurIPS)},
  year={2021}
}
```
License
`algovision` is released under the MIT license. See `LICENSE` for additional details.
File details
Details for the file algovision-0.1.0.tar.gz.
File metadata
- Download URL: algovision-0.1.0.tar.gz
- Upload date:
- Size: 17.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.0 importlib_metadata/4.5.0 pkginfo/1.8.2 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 055eee6c57579c0906fcf83dce963ad5138a15288d1fad3d833ca48be2b4ae1c
MD5 | c96572f43f2fa1053de23dd001ec55dd
BLAKE2b-256 | f6033630b48856e2ff217941a25fc9a5d3f18fa30ab330fe6fc685be30c74ac1
File details
Details for the file algovision-0.1.0-py3-none-any.whl.
File metadata
- Download URL: algovision-0.1.0-py3-none-any.whl
- Upload date:
- Size: 16.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.0 importlib_metadata/4.5.0 pkginfo/1.8.2 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5cd871d4153a422efc89c2b30f32fca6b5e926a77d809efc87f6a550722e5611
MD5 | 2259c2485b265ca693d7fa021329de16
BLAKE2b-256 | 3fdbebda8554cd9dec02e558d3877be1ae2795a96004555046a35f4c2a9e21b9