Building blocks for Continual Inference Networks in PyTorch
Continual Inference
Install
pip install continual-inference
Usage
import torch
from torch import nn
import continual as co
# Input shape (B, C, T, H, W): batch, channels, time, height, width
example = torch.normal(mean=torch.zeros(5 * 3 * 3)).reshape((1, 1, 5, 3, 3))
# Acts as a drop-in replacement for torch.nn modules ✅
co_conv = co.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3))
nn_conv = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3))
co_conv.load_state_dict(nn_conv.state_dict()) # ensure identical weights
co_output = co_conv(example) # Same exact computation
nn_output = nn_conv(example) # Same exact computation
assert torch.equal(co_output, nn_output)
# But can also perform online inference efficiently 🚀
firsts = co_conv.forward_steps(example[:, :, :4])  # first four frames in one call
last = co_conv.forward_step(example[:, :, 4])  # fifth frame on its own
# The continual outputs match the regular outputs
assert torch.allclose(nn_output[:, :, : co_conv.delay], firsts)
assert torch.allclose(nn_output[:, :, co_conv.delay], last)
Continual Inference Networks (CINs)
Continual Inference Networks are neural networks that operate on a continual stream of input data and infer a new prediction for each new time-step.
All networks and network modules that do not utilise temporal information can be used in a CIN (e.g. nn.Conv1d and nn.Conv2d on spatial data such as an image).
Moreover, recurrent modules (e.g. LSTM and GRU) that summarise past events in an internal state are also usable in CINs.
CIN:
O O O (output)
↑ ↑ ↑
LSTM LSTM LSTM (temporal LSTM)
↑ ↑ ↑
Conv2D Conv2D Conv2D (spatial 2D conv)
↑ ↑ ↑
I I I (input frame)
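The frame-wise CIN pictured above can be sketched in plain PyTorch. This is a minimal illustration, not part of the library: the class FrameWiseCIN, its forward_step method name (borrowed from the library's naming) and the layer sizes are hypothetical choices.

```python
import torch
from torch import nn


class FrameWiseCIN(nn.Module):
    """Hypothetical CIN: a per-frame Conv2d followed by a temporal LSTM cell."""

    def __init__(self, in_channels: int = 3, features: int = 8, hidden: int = 16):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, features, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # collapse H, W to 1x1
        self.lstm = nn.LSTMCell(features, hidden)
        self.state = None  # (h, c) carried across frames

    def forward_step(self, frame: torch.Tensor) -> torch.Tensor:
        # frame: (B, C, H, W); the spatial conv only ever sees this one frame
        x = self.pool(torch.relu(self.conv(frame))).flatten(1)  # (B, features)
        # The LSTM cell summarises all past frames in its internal state
        self.state = self.lstm(x, self.state)
        return self.state[0]  # one prediction (B, hidden) per input frame


model = FrameWiseCIN()
video = torch.randn(1, 3, 10, 32, 32)  # (B, C, T, H, W)
with torch.no_grad():
    outputs = [model.forward_step(video[:, :, t]) for t in range(video.shape[2])]
```

Because neither the 2D convolution nor the LSTM cell needs future frames, every call yields a prediction immediately.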
However, modules that operate on temporal data under the assumption that more temporal context than the current frame is available (e.g. the spatio-temporal nn.Conv3d used by many state-of-the-art video recognition models) cannot be directly applied.
Not CIN:
Θ (output)
↑
Conv3D (spatio-temporal 3D conv)
↑
----------------- (concatenate frames to clip)
↑ ↑ ↑
I I I (input frame)
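To see why, consider a naive online use of a regular nn.Conv3d: the last T frames must be buffered, and every new frame triggers a convolution over the whole clip. The sketch below is illustrative only; naive_online_step and the window size are hypothetical.

```python
import collections

import torch
from torch import nn

conv3d = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3))
window = collections.deque(maxlen=3)  # the last T = 3 frames must be kept around


def naive_online_step(frame: torch.Tensor):
    """Hypothetical naive online use of a clip-based Conv3d. frame: (B, C, H, W)."""
    window.append(frame)
    if len(window) < window.maxlen:
        return None  # not enough temporal context yet
    clip = torch.stack(list(window), dim=2)  # (B, C, T, H, W)
    # Every step re-convolves all T frames, although T - 1 of them were seen before
    return conv3d(clip)


video = torch.randn(1, 1, 5, 3, 3)
with torch.no_grad():
    outputs = [naive_online_step(video[:, :, t]) for t in range(5)]
    reference = conv3d(video)  # (1, 1, 3, 1, 1)
```

The results are correct, but most of the work at each step repeats work from earlier steps.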
Sometimes, though, the computations in such modules can be cleverly restructured to work for online inference as well!
CIN:
O O Θ (output)
↑ ↑ ↑
ConvCo3D ConvCo3D ConvCo3D (continual spatio-temporal 3D conv)
↑ ↑ ↑
I I I (input frame)
Here, the Θ outputs of the Conv3D and the ConvCo3D are identical! ✨
Modules
This repository contains online inference-friendly versions of common network building blocks, including:
- (Temporal) convolutions: co.Conv1d, co.Conv2d, co.Conv3d
- (Temporal) batch normalisation: co.BatchNorm2d
- (Temporal) pooling: co.AvgPool1d, co.AvgPool2d, co.AvgPool3d, co.MaxPool1d, co.MaxPool2d, co.MaxPool3d, co.AdaptiveAvgPool1d, co.AdaptiveAvgPool2d, co.AdaptiveAvgPool3d, co.AdaptiveMaxPool1d, co.AdaptiveMaxPool2d, co.AdaptiveMaxPool3d
- Other:
  - co.Sequential - sequential wrapper for modules
  - co.Parallel - parallel wrapper for modules
  - co.Residual - residual wrapper for modules
  - co.Delay - pure delay module
  - co.unsqueezed - functional wrapper for non-continual modules
  - co.continual - conversion function from non-continual modules to continual modules
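A pure delay returns each input frame a fixed number of steps later. A minimal sketch, assuming a FIFO buffer of frames; FrameDelay is a hypothetical name and not the library's co.Delay implementation. One situation where such a module plausibly helps (an assumption, not stated above) is keeping two parallel branches with different delays aligned, e.g. a skip connection around a temporal convolution.

```python
import collections

import torch


class FrameDelay:
    """Hypothetical pure delay: returns the frame seen `delay` steps ago."""

    def __init__(self, delay: int):
        self.delay = delay
        self.buffer = collections.deque()  # FIFO of pending frames

    def forward_step(self, frame: torch.Tensor):
        self.buffer.append(frame)
        if len(self.buffer) <= self.delay:
            return None  # no delayed frame available yet (warm-up)
        return self.buffer.popleft()


delay = FrameDelay(delay=2)
stream = [torch.full((1,), float(t)) for t in range(5)]  # frames 0, 1, 2, 3, 4
delayed = [delay.forward_step(frame) for frame in stream]
```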
Continual Convolutions
Continual Convolutions can lead to major improvements in computational efficiency when online / frame-by-frame predictions are required.
Below, schematic sketches comparing regular and continual convolutions are shown:
Regular Convolution. A regular temporal convolutional layer leads to redundant computations during online processing of video clips, as illustrated by the repeated convolution of inputs (green b, c, d) with a kernel (blue α, β) in the temporal dimension. Moreover, prior inputs (b, c, d) must be stored between time-steps for online processing tasks.
Continual Convolution. An input (green d or e) is convolved with a kernel (blue α, β). The intermediary feature-maps corresponding to all but the last temporal position are stored, while the last feature map and prior memory are summed to produce the resulting output. For a continual stream of inputs, Continual Convolutions produce identical outputs to regular convolutions.
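The continual convolution described in the caption can be sketched for the 1D temporal case. This is a minimal from-scratch illustration under stated assumptions (stride 1, no padding, no dilation, one frame shaped (B, C_in) per step); ContinualConv1dSketch is a hypothetical name, not the library's co.Conv1d. It stores k - 1 partial sums instead of k - 1 past inputs, and each frame is multiplied with each kernel slice exactly once.

```python
import torch
from torch import nn


class ContinualConv1dSketch:
    """Hypothetical frame-by-frame temporal convolution (stride 1, no padding).

    Borrows the weights of an nn.Conv1d so that the outputs can be compared.
    """

    def __init__(self, conv: nn.Conv1d):
        self.weight = conv.weight.detach()  # (C_out, C_in, k)
        self.bias = None if conv.bias is None else conv.bias.detach()
        self.k = self.weight.shape[2]
        # memory[i]: partial sum for the output produced i + 1 steps from now
        self.memory = [0.0] * (self.k - 1)
        self.steps = 0

    def forward_step(self, x_t: torch.Tensor):
        """x_t: (B, C_in). Returns (B, C_out), or None during warm-up."""
        # Feature map of x_t for every temporal kernel position j: (B, C_out) each
        maps = [x_t @ self.weight[:, :, j].T for j in range(self.k)]
        # Sum the stored memory with the feature map of the last kernel position
        out = (self.memory[0] if self.memory else 0.0) + maps[-1]
        if self.bias is not None:
            out = out + self.bias
        # Shift the memory and store the remaining maps for the next k - 1 outputs
        self.memory = [self.memory[i + 1] + maps[self.k - 2 - i] for i in range(self.k - 2)]
        if self.k > 1:
            self.memory.append(maps[0])
        self.steps += 1
        return out if self.steps >= self.k else None


conv = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=3)
sketch = ContinualConv1dSketch(conv)
sequence = torch.randn(1, 2, 6)  # (B, C_in, T)
outputs = [sketch.forward_step(sequence[:, :, t]) for t in range(6)]
with torch.no_grad():
    reference = conv(sequence)  # (1, 3, 4)
```

After a warm-up of k - 1 steps, each step's output equals the corresponding output of the regular nn.Conv1d, up to floating-point rounding.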
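For more information, we refer to the original paper on Continual Convolutions, "Continual 3D Convolutional Neural Networks for Real-time Processing of Videos" (Hedegaard and Iosifidis, 2021).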
Forward modes
The library components feature three distinct forward modes, which are handy for different situations.
forward_step
Performs a forward computation for a single frame and updates the continual states accordingly. This is the mode to use for continual inference.
O+S O+S O+S O+S (O: output, S: updated internal state)
↑ ↑ ↑ ↑
N N N N (N: network module)
↑ ↑ ↑ ↑
I I I I (I: input frame)
forward_steps
Performs a layer-wise forward computation using the continual module.
The computation is performed frame-by-frame and continual states are updated accordingly.
The input-output mapping is exactly the same as that of a regular module.
This mode is handy for initialising the network on a whole clip (multiple frames) before forward_step is used for continual inference.
O (O: output)
↑
----------------- (-: aggregation)
O O+S O+S O+S O (O: output, S: updated internal state)
↑ ↑ ↑ ↑ ↑
N N N N N (N: network module)
↑ ↑ ↑ ↑ ↑
P I I I P (I: input frame, P: padding)
forward
Performs a full forward computation exactly as the regular layer would. This method is handy for efficient training on clip-based data.
O (O: output)
↑
N (N: network module)
↑
----------------- (-: aggregation)
P I I I P (I: input frame, P: padding)
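How the three modes relate can be illustrated with a toy continual moving sum over k frames. ToyContinualMovingSum is a hypothetical class, not a library module; unlike the library's forward_steps pictured above, it applies no padding, so the first k - 1 steps produce no output.

```python
import torch


class ToyContinualMovingSum:
    """Hypothetical toy module: temporal moving sum over k frames, no padding."""

    def __init__(self, k: int):
        self.k = k
        self.state = []  # the last k - 1 frames seen by forward_step

    def forward_step(self, frame: torch.Tensor):
        """One frame (B, C) in, one output (B, C) out; updates the internal state."""
        window = self.state + [frame]
        self.state = window[-(self.k - 1):] if self.k > 1 else []
        if len(window) < self.k:
            return None  # still warming up
        return torch.stack(window, dim=-1).sum(dim=-1)

    def forward_steps(self, clip: torch.Tensor):
        """Clip (B, C, T) in, processed frame by frame; updates the internal state."""
        outs = [self.forward_step(clip[..., t]) for t in range(clip.shape[-1])]
        outs = [o for o in outs if o is not None]
        return torch.stack(outs, dim=-1) if outs else None

    def forward(self, clip: torch.Tensor):
        """Clip (B, C, T) in, computed all at once like a regular layer; stateless."""
        return clip.unfold(-1, self.k, 1).sum(dim=-1)


clip = torch.arange(12.0).reshape(1, 2, 6)  # (B, C, T)
full = ToyContinualMovingSum(k=3).forward(clip)  # (1, 2, 4)

online = ToyContinualMovingSum(k=3)
firsts = online.forward_steps(clip[..., :4])  # initialise on a clip: (1, 2, 2)
last = online.forward_step(clip[..., 4])  # then continue frame by frame: (1, 2)
```

This mirrors the Usage example: forward_steps initialises the state on a clip, forward_step continues from there, and both agree with forward.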
Compatibility
The library modules are built to integrate seamlessly with other PyTorch projects. Specifically, extra care was taken to ensure out-of-the-box compatibility with:
Projects
For full-fledged examples of complex Continual Inference Networks, see:
Citations
This library
@article{hedegaard2021colib,
title={Continual Inference Library},
author={Lukas Hedegaard},
journal={GitHub. Note: https://github.com/LukasHedegaard/continual-inference},
year={2021}
}
@article{hedegaard2021co3d,
title={Continual 3D Convolutional Neural Networks for Real-time Processing of Videos},
author={Lukas Hedegaard and Alexandros Iosifidis},
journal={preprint, arXiv:2106.00050},
year={2021}
}