E(2)-Equivariant CNNs Library for PyTorch

General E(2)-Equivariant Steerable CNNs

Documentation | Experiments | Paper | Thesis

e2cnn is a PyTorch extension for equivariant deep learning.

Equivariant neural networks guarantee a specified transformation behavior of their feature spaces under transformations of their input. For instance, classical convolutional neural networks (CNNs) are by design equivariant to translations of their input. This means that a translation of an image leads to a corresponding translation of the network's feature maps. This package provides implementations of neural network modules which are equivariant under all isometries E(2) of the image plane ℝ², that is, under translations, rotations and reflections. In contrast to conventional CNNs, E(2)-equivariant models are guaranteed to generalize over such transformations, and are therefore more data efficient.
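The translation equivariance of a conventional convolution can be checked numerically. The following is a minimal sketch in plain Python that does not use e2cnn; the helper names circular_correlate and translate are hypothetical, and wrap-around boundaries are assumed so that the equality is exact.

```python
def circular_correlate(signal, kernel):
    """Cross-correlate a 1D signal with a kernel using wrap-around boundaries."""
    n = len(signal)
    return [
        sum(kernel[j] * signal[(i + j) % n] for j in range(len(kernel)))
        for i in range(n)
    ]


def translate(signal, shift):
    """Cyclically shift a 1D signal to the right by `shift` positions."""
    n = len(signal)
    return [signal[(i - shift) % n] for i in range(n)]


signal = [0.0, 1.0, 3.0, -2.0, 0.5, 4.0]
kernel = [1.0, -1.0, 2.0]
shift = 2

# Filtering the translated input ...
out_of_shifted = circular_correlate(translate(signal, shift), kernel)
# ... gives the same result as translating the filtered output.
shifted_out = translate(circular_correlate(signal, kernel), shift)

print(out_of_shifted == shifted_out)  # True
```

The same kind of identity, with rotations and reflections in place of translations, is what the modules in this package are built to satisfy.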

The feature spaces of E(2)-Equivariant Steerable CNNs are defined as spaces of feature fields, being characterized by their transformation law under rotations and reflections. Typical examples are scalar fields (e.g. gray-scale images or temperature fields) or vector fields (e.g. optical flow or electromagnetic fields).

[Figure: feature field examples]

Instead of a number of channels, the user specifies the field types and their multiplicities in order to define a feature space. Given an input and an output feature space, our R2Conv module instantiates the most general convolutional mapping between them. Our library provides many other equivariant operations to process feature fields, including nonlinearities, mappings to produce invariant features, batch normalization and dropout. Feature fields are represented by GeometricTensor objects, which wrap a torch.Tensor with the corresponding transformation law. All equivariant operations perform dynamic type-checking in order to guarantee a geometrically sound processing of the feature fields.
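To relate field types and multiplicities to the familiar channel count, the sketch below adds up the channels of a feature space by hand. It is plain Python and does not call e2cnn; the helper num_channels and the constants are hypothetical names. It assumes that a scalar field occupies one channel, a vector field in the plane two channels, and a regular field of the cyclic group C8 eight channels (one per group element).

```python
def num_channels(fields):
    """Total channel count of a feature space.

    `fields` is a list of (multiplicity, field_dimension) pairs.
    """
    return sum(multiplicity * dimension for multiplicity, dimension in fields)


SCALAR_DIM = 1    # scalar field: a single channel
VECTOR_DIM = 2    # planar vector field: two channels
REGULAR_C8_DIM = 8  # regular field of C8: one channel per rotation

rgb_input = [(3, SCALAR_DIM)]
regular_features = [(10, REGULAR_C8_DIM)]
mixed_features = [(4, SCALAR_DIM), (3, VECTOR_DIM)]

print(num_channels(rgb_input))         # 3
print(num_channels(regular_features))  # 80
print(num_channels(mixed_features))    # 10
```

Two feature spaces with the same channel count can therefore still differ in how those channels transform, which is why the field types have to be given explicitly.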

E(2)-Equivariant Steerable CNNs unify and generalize a wide range of isometry equivariant CNNs in one single framework. Examples include:

For more details we refer to our NeurIPS 2019 paper General E(2)-Equivariant Steerable CNNs.


The library is structured into four subpackages with different high-level features:

Component        Description
e2cnn.group      implements basic concepts of group and representation theory
e2cnn.kernels    solves for spaces of equivariant convolution kernels
e2cnn.gspaces    defines the image plane and its symmetries
e2cnn.nn         contains equivariant modules to build deep neural networks

Demo

Since E(2)-steerable CNNs are equivariant under rotations and reflections, their inference is independent of the choice of image orientation. The visualization below demonstrates this claim by feeding rotated images into a randomly initialized E(2)-steerable CNN (left). The middle plot shows the equivariant transformation of a feature space, consisting of one scalar field (color-coded) and one vector field (arrows), after a few layers. In the right plot we transform the feature space into a comoving reference frame by rotating the response fields back (stabilized view).

[Figure: Equivariant CNN output]

The invariance of the features in the comoving frame validates the rotational equivariance of E(2)-steerable CNNs empirically. Note that the fluctuations of responses are discretization artifacts due to the sampling of the image on a pixel grid, which does not allow for exact continuous rotations.

For comparison, we show a feature map response of a conventional CNN for different image orientations below.

[Figure: Conventional CNN output]

Since conventional CNNs are not equivariant under rotations, the response varies randomly with the image orientation. This prevents CNNs from automatically generalizing learned patterns between different reference frames.
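The contrast between the two demos can be reproduced in a small sketch. It is plain Python and does not use e2cnn; rot90 and correlate_same are hypothetical helpers, and the kernels are hand-picked rather than learned. Only rotations by 90 degrees are considered, since they map the pixel grid onto itself and therefore avoid the discretization artifacts mentioned above.

```python
def rot90(grid):
    """Rotate a square 2D grid by 90 degrees counter-clockwise."""
    n = len(grid)
    return [[grid[c][n - 1 - r] for c in range(n)] for r in range(n)]


def correlate_same(image, kernel):
    """Zero-padded cross-correlation of a square image with an odd-sized square kernel."""
    n, k = len(image), len(kernel)
    half = k // 2
    out = [[0] * n for _ in range(n)]
    for r in range(n):
        for c in range(n):
            total = 0
            for dr in range(k):
                for dc in range(k):
                    rr, cc = r + dr - half, c + dc - half
                    if 0 <= rr < n and 0 <= cc < n:
                        total += kernel[dr][dc] * image[rr][cc]
            out[r][c] = total
    return out


image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]

# A kernel that is unchanged by 90 degree rotations maps scalar fields equivariantly.
symmetric_kernel = [[1, 2, 1], [2, 5, 2], [1, 2, 1]]
# An unconstrained kernel, as a conventional CNN would learn, has no such symmetry.
generic_kernel = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

for name, kernel in [("symmetric", symmetric_kernel), ("generic", generic_kernel)]:
    response_of_rotated = correlate_same(rot90(image), kernel)
    rotated_response = rot90(correlate_same(image, kernel))
    print(name, response_of_rotated == rotated_response)
# symmetric True
# generic False

# Stabilized view: rotating the response back by 270 degrees recovers the original response.
stabilized = rot90(rot90(rot90(correlate_same(rot90(image), symmetric_kernel))))
print(stabilized == correlate_same(image, symmetric_kernel))  # True
```

Constraining the kernel is what makes the first case work; the library solves for such constraints in general rather than relying on hand-picked kernels.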

Experimental results

E(2)-steerable convolutions can be used as a drop-in replacement for the conventional convolutions used in CNNs. Keeping the same training setup and without performing hyperparameter tuning, this leads to significant performance boosts compared to CNN baselines (values are test errors in percent):

model           CIFAR-10       CIFAR-100       STL-10
CNN baseline    2.6  ± 0.1     17.1  ± 0.3     12.74 ± 0.23
E(2)-CNN *      2.39 ± 0.11    15.55 ± 0.13    10.57 ± 0.70
E(2)-CNN        2.05 ± 0.03    14.30 ± 0.09     9.80 ± 0.40

For a fair comparison, the models without * are designed such that the number of parameters of the baseline is approximately preserved, while the models with * preserve the number of channels, and hence compute. For more details we refer to our paper.

Getting Started

e2cnn is easy to use since it provides a high-level user interface which abstracts away most intricacies of group and representation theory. The following code snippet shows how to perform an equivariant convolution from an RGB image to 10 regular feature fields (corresponding to a group convolution).

from e2cnn import gspaces                                          #  1
from e2cnn import nn                                               #  2
import torch                                                       #  3
                                                                   #  4
r2_act = gspaces.Rot2dOnR2(N=8)                                    #  5
feat_type_in  = nn.FieldType(r2_act,  3*[r2_act.trivial_repr])     #  6
feat_type_out = nn.FieldType(r2_act, 10*[r2_act.regular_repr])     #  7
                                                                   #  8
conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5)       #  9
relu = nn.ReLU(feat_type_out)                                      # 10
                                                                   # 11
x = torch.randn(16, 3, 32, 32)                                     # 12
x = nn.GeometricTensor(x, feat_type_in)                            # 13
                                                                   # 14
y = relu(conv(x))                                                  # 15

Line 5 specifies the symmetry group action on the image plane ℝ² under which the network should be equivariant. We choose the cyclic group C8, which describes discrete rotations by multiples of 2π/8. Line 6 specifies the input feature field types. The three color channels of an RGB image are identified as three independent scalar fields, which transform under the trivial representation of C8. Similarly, line 7 specifies the output feature space to consist of 10 feature fields which transform under the regular representation of C8. The C8-equivariant convolution is then instantiated by passing the input and output type as well as the kernel size to the constructor (line 9). Line 10 instantiates an equivariant ReLU nonlinearity which will operate on the output field and is therefore passed the output field type.

Lines 12 and 13 generate a random minibatch of RGB images and wrap them into a nn.GeometricTensor to associate them with their correct field type. The equivariant modules process the geometric tensor in line 15. Each module thereby checks whether the geometric tensor passed to it satisfies the expected transformation law.
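To make the trivial and regular representations of C8 more concrete, the sketch below writes them out as matrices acting on the channels of a single pixel. It is plain Python and does not call e2cnn; the helper names are hypothetical, and the channel ordering is an assumption that may differ from the library's internal convention. It shows only the action on the channels: a rotation additionally moves each feature vector to its rotated position in the image plane.

```python
def regular_repr(n, k):
    """Matrix of the rotation by k * 2*pi/n in the regular representation of C_n.

    It is the n x n permutation matrix that sends channel i to channel (i + k) % n.
    """
    return [[1 if (row - col) % n == k % n else 0 for col in range(n)]
            for row in range(n)]


def trivial_repr(n, k):
    """Matrix of any element of C_n in the trivial representation: the 1 x 1 identity."""
    return [[1]]


def matmul(a, b):
    """Product of two matrices given as nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def apply(matrix, vector):
    """Matrix-vector product for nested-list matrices."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


N = 8
feature = list(range(N))  # the 8 channels of one regular field at a single pixel

# A rotation by 2*pi/8 cyclically shifts the channels of a regular field.
print(apply(regular_repr(N, 1), feature))  # [7, 0, 1, 2, 3, 4, 5, 6]

# A scalar field is left unchanged by every rotation.
print(apply(trivial_repr(N, 1), [42]))     # [42]

# Group law: rotating by 3 steps and then by 6 steps equals rotating by (3 + 6) % 8 = 1 step.
print(matmul(regular_repr(N, 6), regular_repr(N, 3)) == regular_repr(N, 1))  # True
```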

Since the parameters no longer need to be updated after training, any equivariant network can be converted at test time into a pure PyTorch model, with no additional computational overhead in comparison to conventional CNNs. The code currently supports the automatic conversion of a few commonly used modules through the .export() method; check the documentation for more details.
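The idea behind such a conversion can be illustrated with a conceptual sketch. It is not e2cnn's export code: it assumes that an equivariant filter is a linear combination of fixed basis kernels weighted by the trained parameters, and the basis, the weights and the helper expand_filter are all hypothetical. Once the parameters are fixed, the combination can be computed a single time and stored as an ordinary filter.

```python
def expand_filter(weights, basis):
    """Linear combination sum_i weights[i] * basis[i] of equally sized 2D kernels."""
    rows, cols = len(basis[0]), len(basis[0][0])
    return [
        [sum(w * b[r][c] for w, b in zip(weights, basis)) for c in range(cols)]
        for r in range(rows)
    ]


# Hypothetical fixed basis of two kernels that are symmetric under 90 degree rotations.
CENTER = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
RING = [[1, 1, 1], [1, 0, 1], [1, 1, 1]]
basis = [CENTER, RING]

# Stand-ins for trained parameters: one coefficient per basis kernel.
weights = [5, 2]

# Computed once after training; from here on it is a plain 3 x 3 filter
# that any conventional convolution routine can use.
frozen_filter = expand_filter(weights, basis)
print(frozen_filter)  # [[2, 2, 2], [2, 5, 2], [2, 2, 2]]
```

During training the expansion has to be repeated whenever the parameters change; at test time only the stored filter is needed, so no extra cost remains.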

A hands-on tutorial, introducing the basic functionality of e2cnn, is provided in introduction.ipynb. Code for training and evaluating a simple model on the rotated MNIST dataset is given in model.ipynb.

More complex equivariant Wide ResNet models are implemented in e2wrn.py. To try a model which is equivariant under reflections, call:

cd examples
python e2wrn.py

A version of the same model which is simultaneously equivariant under reflections and rotations by multiples of 90 degrees can be run via:

python e2wrn.py --rot90

Dependencies

The library is based on Python 3.7 and requires:

torch>=1.1
numpy
scipy

Optional:

pymanopt
autograd

Check the branch legacy_py3.6 for a Python 3.6 compatible version of the library.

Installation

You can install the latest release with

pip install e2cnn

or you can install the current version directly from this repository with

pip install git+https://github.com/QUVA-Lab/e2cnn

Cite

The development of this library was part of the work done for our paper General E(2)-Equivariant Steerable CNNs. Please cite this work if you use our code:

@inproceedings{e2cnn,
    title={{General E(2)-Equivariant Steerable CNNs}},
    author={Weiler, Maurice and Cesa, Gabriele},
    booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
    year={2019},
}

Feel free to contact us.

License

e2cnn is distributed under the BSD Clear license. See the LICENSE file.
