
Compute minimal Winograd convolution algorithms for convolutional neural networks

Project description

wincnn


A simple Python module for computing minimal Winograd convolution algorithms for use with convolutional neural networks, as proposed in [1].

Installation

pip install wincnn

Requirements

  • Python >= 3.8
  • SymPy >= 1.9

Example: F(2,3)

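For F(m,r) you must select m+r-2 polynomial interpolation points. The point at infinity is added implicitly, so the algorithm uses m+r-1 points in total, which is also the number of multiplications it performs.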
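
The AT matrices printed in the examples on this page follow a simple pattern: row k holds p**k for each finite interpolation point p, and one extra column for the implicit point at infinity is 1 only in the last row. The sketch below rebuilds AT that way with exact fractions. `cook_toom_at` is a hypothetical helper, not part of the wincnn API, and it only reproduces the output transform; use wincnn itself for G and BT.

```python
from fractions import Fraction


def cook_toom_at(points, m, r):
    """Rebuild the output transform AT of F(m, r) from its finite points.

    Hypothetical helper, not part of wincnn: row k holds p**k for every
    finite point p, followed by one column for the implicit point at
    infinity, which is 1 only in the last row.
    """
    points = [Fraction(p) for p in points]
    if len(points) != m + r - 2:
        raise ValueError(f"F({m},{r}) needs {m + r - 2} points, got {len(points)}")
    at = []
    for k in range(m):
        row = [p ** k for p in points]                 # Vandermonde part (0**0 == 1)
        row.append(Fraction(1 if k == m - 1 else 0))   # column for the point at infinity
        at.append(row)
    return at


for row in cook_toom_at((0, 1, -1, 2, -2), 4, 3):
    print(" ".join(str(x) for x in row))
# 1 1 1 1 1 0
# 0 1 -1 2 -2 0
# 0 1 1 4 4 0
# 0 1 -1 8 -8 1
```

The printed rows match the AT shown for F(4,3) below; passing the F(6,3) points with `Fraction(1, 2)` and `Fraction(-1, 2)` likewise yields the powers of 1/2 seen in that example.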

In this example we compute the transforms for F(2,3), which can be nested to obtain F(2x2,3x3), using the polynomial interpolation points (0,1,-1).
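
F(2x2,3x3) is obtained by nesting F(2,3) with itself, as in [1]: Y = AT[(G g G^T) ⊙ (BT d B)]A, where d is a 4x4 input tile, g is a 3x3 filter, and ⊙ is the elementwise product. The sketch below hard-codes the F(2,3) matrices printed by the session that follows and checks the nested result against a direct 2D correlation. The helper names are illustrative, not wincnn functions. The elementwise step needs 16 multiplications for the 2x2 output tile, versus 4 x 9 = 36 for the direct method.

```python
from fractions import Fraction as Fr

# F(2,3) transforms copied from the session output (illustrative constants).
AT = [[1, 1, 1, 0],
      [0, 1, -1, 1]]
G = [[Fr(1), Fr(0), Fr(0)],
     [Fr(1, 2), Fr(1, 2), Fr(1, 2)],
     [Fr(1, 2), Fr(-1, 2), Fr(1, 2)],
     [Fr(0), Fr(0), Fr(1)]]
BT = [[1, 0, -1, 0],
      [0, 1, 1, 0],
      [0, -1, 1, 0],
      [0, -1, 0, 1]]


def matmul(a, b):
    """Plain list-of-lists matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def winograd_2x2_3x3(d, g):
    """Nested F(2x2,3x3): Y = AT [(G g G^T) elementwise-times (BT d B)] A."""
    u = matmul(matmul(G, g), transpose(G))       # 4x4 transformed filter
    v = matmul(matmul(BT, d), transpose(BT))     # 4x4 transformed input tile
    m = [[u[i][j] * v[i][j] for j in range(4)] for i in range(4)]  # 16 products
    return matmul(matmul(AT, m), transpose(AT))  # 2x2 output tile


def direct_2d(d, g):
    """Direct 2D correlation of a 4x4 tile with a 3x3 filter (36 products)."""
    return [[sum(d[i + a][j + b] * g[a][b] for a in range(3) for b in range(3))
             for j in range(2)] for i in range(2)]


d = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
g = [[1, 0, -1], [2, 1, 0], [0, 3, 1]]
print(winograd_2x2_3x3(d, g) == direct_2d(d, g))  # True
```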

$ python3
>>> import wincnn
>>> wincnn.showCookToomFilter((0,1,-1), 2, 3)
AT = 
⎡1  1  1   0⎤
⎢           ⎥
⎣0  1  -1  1⎦

G = 
⎡ 1    0     0 ⎤
⎢              ⎥
⎢1/2  1/2   1/2⎥
⎢              ⎥
⎢1/2  -1/2  1/2⎥
⎢              ⎥
⎣ 0    0     1 ⎦

BT = 
⎡1  0   -1  0⎤
⎢            ⎥
⎢0  1   1   0⎥
⎢            ⎥
⎢0  -1  1   0⎥
⎢            ⎥
⎣0  -1  0   1⎦

AT*((G*g)(BT*d)) =
⎡d[0]⋅g[0] + d[1]⋅g[1] + d[2]⋅g[2]⎤
⎢                                 ⎥
⎣d[1]⋅g[0] + d[2]⋅g[1] + d[3]⋅g[2]⎦

The last matrix is the 1D convolution F(2,3) computed with the transforms AT, G, and BT on the 4-element signal d[0..3] and the 3-element filter g[0..2]; (G*g)(BT*d) denotes the elementwise product of the two transformed vectors. The result serves to verify that the transforms are correct, and because the computation is symbolic, it is exact.
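
When the printed transforms are copied into another code base, a quick numerical check catches transcription errors. The sketch below evaluates AT*((G*g)(BT*d)) with exact fractions on random integer inputs and compares the result with direct correlation. `winograd_1d` and `check_transforms` are illustrative names, not part of wincnn. Matrix entries may be ints or strings such as "1/2", which `Fraction` parses exactly.

```python
import random
from fractions import Fraction as Fr


def winograd_1d(AT, G, BT, d, g):
    """Evaluate AT*((G*g)(BT*d)) exactly; the middle product is elementwise."""
    u = [sum(Fr(row[k]) * g[k] for k in range(len(g))) for row in G]
    v = [sum(Fr(row[k]) * d[k] for k in range(len(d))) for row in BT]
    m = [a * b for a, b in zip(u, v)]  # the m+r-1 data-dependent multiplications
    return [sum(Fr(row[k]) * m[k] for k in range(len(m))) for row in AT]


def check_transforms(AT, G, BT, trials=100, seed=0):
    """Compare with direct correlation y[i] = sum_k d[i+k]*g[k] on random ints."""
    rng = random.Random(seed)
    n_out, r = len(AT), len(G[0])
    for _ in range(trials):
        d = [rng.randint(-9, 9) for _ in range(n_out + r - 1)]
        g = [rng.randint(-9, 9) for _ in range(r)]
        direct = [sum(d[i + k] * g[k] for k in range(r)) for i in range(n_out)]
        if winograd_1d(AT, G, BT, d, g) != direct:
            return False
    return True


AT = [[1, 1, 1, 0], [0, 1, -1, 1]]
G = [[1, 0, 0], ["1/2", "1/2", "1/2"], ["1/2", "-1/2", "1/2"], [0, 0, 1]]
BT = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, -1, 0, 1]]
print(check_transforms(AT, G, BT))  # True
```

A random check is weaker than wincnn's symbolic verification, but it is cheap and works on matrices pasted from anywhere; changing a single entry, for example G[1][1] to "-1/2", makes it fail for practically any set of random inputs.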

Example: F(4,3)

The following example computes the transforms for F(4,3) using the interpolation points (0,1,-1,2,-2).

>>> wincnn.showCookToomFilter((0,1,-1,2,-2), 4, 3)
AT = 
⎡1  1  1   1  1   0⎤
⎢                  ⎥
⎢0  1  -1  2  -2  0⎥
⎢                  ⎥
⎢0  1  1   4  4   0⎥
⎢                  ⎥
⎣0  1  -1  8  -8  1⎦

G = 
⎡1/4     0     0  ⎤
⎢                 ⎥
⎢-1/6  -1/6   -1/6⎥
⎢                 ⎥
⎢-1/6   1/6   -1/6⎥
⎢                 ⎥
⎢1/24  1/12   1/6 ⎥
⎢                 ⎥
⎢1/24  -1/12  1/6 ⎥
⎢                 ⎥
⎣ 0      0     1  ⎦

BT = 
⎡4  0   -5  0   1  0⎤
⎢                   ⎥
⎢0  -4  -4  1   1  0⎥
⎢                   ⎥
⎢0  4   -4  -1  1  0⎥
⎢                   ⎥
⎢0  -2  -1  2   1  0⎥
⎢                   ⎥
⎢0  2   -1  -2  1  0⎥
⎢                   ⎥
⎣0  4   0   -5  0  1⎦

AT*((G*g)(BT*d)) =
⎡d[0]⋅g[0] + d[1]⋅g[1] + d[2]⋅g[2]⎤
⎢                                 ⎥
⎢d[1]⋅g[0] + d[2]⋅g[1] + d[3]⋅g[2]⎥
⎢                                 ⎥
⎢d[2]⋅g[0] + d[3]⋅g[1] + d[4]⋅g[2]⎥
⎢                                 ⎥
⎣d[3]⋅g[0] + d[4]⋅g[1] + d[5]⋅g[2]⎦
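
A longer signal can be processed with F(4,3) by splitting it into input tiles of m+r-1 = 6 samples that start m = 4 samples apart, so neighbouring tiles overlap by r-1 = 2 samples, while the filter transform G*g is computed only once. The sketch below uses the matrices printed above; `correlate_tiled` is a hypothetical helper that assumes the number of outputs is a multiple of 4 (a real implementation would pad the signal).

```python
from fractions import Fraction as Fr

# F(4,3) transforms as printed above (strings are parsed exactly by Fraction).
AT = [[1, 1, 1, 1, 1, 0],
      [0, 1, -1, 2, -2, 0],
      [0, 1, 1, 4, 4, 0],
      [0, 1, -1, 8, -8, 1]]
G = [["1/4", 0, 0],
     ["-1/6", "-1/6", "-1/6"],
     ["-1/6", "1/6", "-1/6"],
     ["1/24", "1/12", "1/6"],
     ["1/24", "-1/12", "1/6"],
     [0, 0, 1]]
BT = [[4, 0, -5, 0, 1, 0],
      [0, -4, -4, 1, 1, 0],
      [0, 4, -4, -1, 1, 0],
      [0, -2, -1, 2, 1, 0],
      [0, 2, -1, -2, 1, 0],
      [0, 4, 0, -5, 0, 1]]
M, R = 4, 3
T = M + R - 1  # input tile length


def matvec(mat, vec):
    return [sum(Fr(row[k]) * vec[k] for k in range(len(vec))) for row in mat]


def correlate_tiled(d, g):
    """Valid 1D correlation of a long signal, F(4,3) applied tile by tile."""
    n_out = len(d) - R + 1
    assert n_out % M == 0, "sketch assumes a whole number of output tiles"
    u = matvec(G, g)  # filter transform computed once, reused for every tile
    y = []
    for start in range(0, n_out, M):
        v = matvec(BT, d[start:start + T])  # tiles overlap by R-1 samples
        # 6 multiplications per 4 outputs instead of 12 for direct correlation
        y.extend(matvec(AT, [a * b for a, b in zip(u, v)]))
    return y


d = list(range(1, 15))  # 14 samples -> 12 outputs -> 3 tiles
g = [2, -1, 3]
direct = [sum(d[i + k] * g[k] for k in range(R)) for i in range(len(d) - R + 1)]
print(correlate_tiled(d, g) == direct)  # True
```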

Linear Convolution

If you want an algorithm for linear convolution instead of an FIR filter, exchange the roles of the data transform and the inverse transform and transpose both: A = AT^T becomes the data transform, B = BT^T becomes the inverse transform, and G is unchanged. This is referred to as the Transformation Principle.

>>> wincnn.showCookToomConvolution((0,1,-1),2,3)
A = 
⎡1  0 ⎤
⎢     ⎥
⎢1  1 ⎥
⎢     ⎥
⎢1  -1⎥
⎢     ⎥
⎣0  1 ⎦

G = 
⎡ 1    0     0 ⎤
⎢              ⎥
⎢1/2  1/2   1/2⎥
⎢              ⎥
⎢1/2  -1/2  1/2⎥
⎢              ⎥
⎣ 0    0     1 ⎦

B = 
⎡1   0  0   0 ⎤
⎢             ⎥
⎢0   1  -1  -1⎥
⎢             ⎥
⎢-1  1  1   0 ⎥
⎢             ⎥
⎣0   0  0   1 ⎦

Linear Convolution: B*((G*g)(A*d)) =
⎡      d[0]⋅g[0]      ⎤
⎢                     ⎥
⎢d[0]⋅g[1] + d[1]⋅g[0]⎥
⎢                     ⎥
⎢d[0]⋅g[2] + d[1]⋅g[1]⎥
⎢                     ⎥
⎣      d[1]⋅g[2]      ⎦
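
The transformation can be checked with a few lines of exact arithmetic: transposing the F(2,3) matrices AT and BT from the first example gives exactly the A and B printed above, and B*((G*g)(A*d)) then reproduces the full linear convolution of a 2-sample signal with a 3-tap filter. The helper names below are illustrative, not wincnn functions.

```python
from fractions import Fraction as Fr

# F(2,3) FIR-filter transforms from the first example.
AT = [[1, 1, 1, 0], [0, 1, -1, 1]]
G = [[1, 0, 0], ["1/2", "1/2", "1/2"], ["1/2", "-1/2", "1/2"], [0, 0, 1]]
BT = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, -1, 0, 1]]

# Exchange the roles of the data and inverse transforms and transpose them.
A = [list(col) for col in zip(*AT)]  # 4x2, now transforms the data
B = [list(col) for col in zip(*BT)]  # 4x4, now the inverse transform


def matvec(mat, vec):
    return [sum(Fr(row[k]) * vec[k] for k in range(len(vec))) for row in mat]


def linear_conv(d, g):
    """Full linear convolution of 2 samples with 3 taps: B*((G*g)(A*d))."""
    m = [a * b for a, b in zip(matvec(G, g), matvec(A, d))]  # 4 multiplications
    return matvec(B, m)


d, g = [3, -2], [1, 4, 5]
direct = [sum(d[i] * g[n - i] for i in range(2) if 0 <= n - i < 3) for n in range(4)]
print(linear_conv(d, g) == direct)  # True
```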

Example: F(6,3)

This example computes the transforms for F(6,3). Because it uses the fractional interpolation points 1/2 and -1/2, we pass them as sympy.Rational values to keep the symbolic computation exact; using floating-point values would make the derivation of the transforms subject to rounding error.

>>> from sympy import Rational
>>> wincnn.showCookToomFilter((0,1,-1,2,-2,Rational(1,2),-Rational(1,2)), 6, 3)
AT = 
⎡1  1  1   1    1    1      1    0⎤
⎢                                 ⎥
⎢0  1  -1  2   -2   1/2   -1/2   0⎥
⎢                                 ⎥
⎢0  1  1   4    4   1/4    1/4   0⎥
⎢                                 ⎥
⎢0  1  -1  8   -8   1/8   -1/8   0⎥
⎢                                 ⎥
⎢0  1  1   16  16   1/16  1/16   0⎥
⎢                                 ⎥
⎣0  1  -1  32  -32  1/32  -1/32  1⎦

G = 
⎡ 1      0     0  ⎤
⎢                 ⎥
⎢-2/9  -2/9   -2/9⎥
⎢                 ⎥
⎢-2/9   2/9   -2/9⎥
⎢                 ⎥
⎢1/90  1/45   2/45⎥
⎢                 ⎥
⎢1/90  -1/45  2/45⎥
⎢                 ⎥
⎢ 32    16        ⎥
⎢ ──    ──    8/45⎥
⎢ 45    45        ⎥
⎢                 ⎥
⎢ 32   -16        ⎥
⎢ ──   ────   8/45⎥
⎢ 45    45        ⎥
⎢                 ⎥
⎣ 0      0     1  ⎦

BT = 
⎡1   0    -21/4    0    21/4     0    -1  0⎤
⎢                                          ⎥
⎢0   1      1    -17/4  -17/4    1    1   0⎥
⎢                                          ⎥
⎢0   -1     1    17/4   -17/4   -1    1   0⎥
⎢                                          ⎥
⎢0  1/2    1/4   -5/2   -5/4     2    1   0⎥
⎢                                          ⎥
⎢0  -1/2   1/4    5/2   -5/4    -2    1   0⎥
⎢                                          ⎥
⎢0   2      4    -5/2    -5     1/2   1   0⎥
⎢                                          ⎥
⎢0   -2     4     5/2    -5    -1/2   1   0⎥
⎢                                          ⎥
⎣0   -1     0    21/4     0    -21/4  0   1⎦

AT*((G*g)(BT*d)) =
⎡d[0]⋅g[0] + d[1]⋅g[1] + d[2]⋅g[2]⎤
⎢                                 ⎥
⎢d[1]⋅g[0] + d[2]⋅g[1] + d[3]⋅g[2]⎥
⎢                                 ⎥
⎢d[2]⋅g[0] + d[3]⋅g[1] + d[4]⋅g[2]⎥
⎢                                 ⎥
⎢d[3]⋅g[0] + d[4]⋅g[1] + d[5]⋅g[2]⎥
⎢                                 ⎥
⎢d[4]⋅g[0] + d[5]⋅g[1] + d[6]⋅g[2]⎥
⎢                                 ⎥
⎣d[5]⋅g[0] + d[6]⋅g[1] + d[7]⋅g[2]⎦
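
Exactness also matters when the F(6,3) transforms are applied. The sketch below evaluates AT*((G*g)(BT*d)) once with exact `Fraction` entries and once with the same entries converted to `float`: the exact version matches direct correlation exactly, while the float version may carry a small rounding error. This only illustrates evaluating the transforms, not wincnn's derivation of them, and the input values are arbitrary.

```python
from fractions import Fraction

# F(6,3) transforms copied from the output above; strings parse exactly.
AT_ROWS = [
    "1 1 1 1 1 1 1 0",
    "0 1 -1 2 -2 1/2 -1/2 0",
    "0 1 1 4 4 1/4 1/4 0",
    "0 1 -1 8 -8 1/8 -1/8 0",
    "0 1 1 16 16 1/16 1/16 0",
    "0 1 -1 32 -32 1/32 -1/32 1",
]
G_ROWS = [
    "1 0 0", "-2/9 -2/9 -2/9", "-2/9 2/9 -2/9", "1/90 1/45 2/45",
    "1/90 -1/45 2/45", "32/45 16/45 8/45", "32/45 -16/45 8/45", "0 0 1",
]
BT_ROWS = [
    "1 0 -21/4 0 21/4 0 -1 0",
    "0 1 1 -17/4 -17/4 1 1 0",
    "0 -1 1 17/4 -17/4 -1 1 0",
    "0 1/2 1/4 -5/2 -5/4 2 1 0",
    "0 -1/2 1/4 5/2 -5/4 -2 1 0",
    "0 2 4 -5/2 -5 1/2 1 0",
    "0 -2 4 5/2 -5 -1/2 1 0",
    "0 -1 0 21/4 0 -21/4 0 1",
]


def parse(rows):
    return [[Fraction(x) for x in row.split()] for row in rows]


def winograd(AT, G, BT, d, g):
    """AT*((G*g)(BT*d)) with an elementwise middle product."""
    u = [sum(row[k] * g[k] for k in range(len(g))) for row in G]
    v = [sum(row[k] * d[k] for k in range(len(d))) for row in BT]
    return [sum(row[k] * u[k] * v[k] for k in range(len(u))) for row in AT]


exact = [parse(rows) for rows in (AT_ROWS, G_ROWS, BT_ROWS)]
rounded = [[[float(x) for x in row] for row in mat] for mat in exact]

d = [3, -1, 4, 1, -5, 9, 2, -6]
g = [2, 7, -1]
direct = [sum(d[i + k] * g[k] for k in range(3)) for i in range(6)]

y_exact = winograd(*exact, d, g)
y_float = winograd(*rounded, d, g)
print(y_exact == direct)  # True
print(max(abs(a - b) for a, b in zip(y_float, direct)))  # float rounding error, if any
```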

[1] A. Lavin and S. Gray, "Fast Algorithms for Convolutional Neural Networks," CVPR 2016. http://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Lavin_Fast_Algorithms_for_CVPR_2016_paper.pdf



Download files


Source Distribution

wincnn-2.0.1.tar.gz (10.0 kB)

Uploaded Source

Built Distribution


wincnn-2.0.1-py3-none-any.whl (8.9 kB)

Uploaded Python 3

File details

Details for the file wincnn-2.0.1.tar.gz.

File metadata

  • Download URL: wincnn-2.0.1.tar.gz
  • Upload date:
  • Size: 10.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for wincnn-2.0.1.tar.gz
Algorithm Hash digest
SHA256 72bf08def7b3b9deb8950f68789babab5c96b6a867b0162ceb541ed105e3820d
MD5 3c66115846250b149720abcafb9ed722
BLAKE2b-256 ae5ddd7969ad0ba15e7c9d7dc5de63d82aff7870bec09551861667188d5d01a5


File details

Details for the file wincnn-2.0.1-py3-none-any.whl.

File metadata

  • Download URL: wincnn-2.0.1-py3-none-any.whl
  • Upload date:
  • Size: 8.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for wincnn-2.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 9e30a1f630d60930cb63f8755b8ff54fa4c1c0e083ae0f25d5ef14d86f095245
MD5 f7588ad3ea256e9780a7e536e7e8b034
BLAKE2b-256 24f8ec9305fd30b370a6f49a4e8e8a160caf2c6e233ca3816cf47e026ada21d2

