Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. Much faster than direct convolutions for large kernel sizes
FFT Conv PyTorch
This is a fork of the original fft-conv-pytorch.
I made some modifications to support dilated and strided convolution, so it can be a drop-in replacement for the original PyTorch Conv*d modules and conv*d functions, with the same function parameters and behavior.
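To illustrate how an FFT convolution can honor stride and dilation, here is a minimal pure-Python sketch (single channel, no padding, no batching). It is not this package's implementation, and the names `fft`, `fft_conv1d_sketch`, and `direct_conv1d` are hypothetical. It assumes one straightforward approach: dilation is handled by inserting zeros between kernel taps, and stride by subsampling the stride-1 output. Like PyTorch's conv1d, it computes a cross-correlation, which is a convolution with the reversed kernel.

```python
import cmath


def fft(a, invert=False):
    """Recursive radix-2 Cooley-Tukey FFT; len(a) must be a power of two."""
    n = len(a)
    if n == 1:
        return list(a)
    even = fft(a[0::2], invert)
    odd = fft(a[1::2], invert)
    sign = 1 if invert else -1  # +1 gives the (unscaled) inverse transform
    out = [0j] * n
    for k in range(n // 2):
        tw = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + tw
        out[k + n // 2] = even[k] - tw
    return out


def fft_conv1d_sketch(x, w, stride=1, dilation=1):
    """Hypothetical helper: 'valid' cross-correlation of x with w via the FFT."""
    # Dilation: place the taps of w every `dilation` samples, zeros in between.
    wd = [0.0] * (dilation * (len(w) - 1) + 1)
    wd[::dilation] = w
    kd = len(wd)
    # Zero-pad both inputs to a power of two >= the full linear-convolution length.
    m = 1
    while m < len(x) + kd - 1:
        m *= 2
    xf = fft([complex(v) for v in x] + [0j] * (m - len(x)))
    wf = fft([complex(v) for v in reversed(wd)] + [0j] * (m - kd))
    full = [v.real / m for v in fft([a * b for a, b in zip(xf, wf)], invert=True)]
    # Keep only positions where the dilated kernel fully overlaps x, then apply stride.
    return full[kd - 1 : len(x)][::stride]


def direct_conv1d(x, w, stride=1, dilation=1):
    """Hypothetical reference: direct strided, dilated 'valid' cross-correlation."""
    span = dilation * (len(w) - 1) + 1
    return [
        sum(w[i] * x[t + i * dilation] for i in range(len(w)))
        for t in range(0, len(x) - span + 1, stride)
    ]


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
w = [1.0, 0.0, -1.0]
print(direct_conv1d(x, w, stride=2, dilation=2))  # [-4.0, -4.0]
print(fft_conv1d_sketch(x, w, stride=2, dilation=2))  # equal up to float rounding
```

Subsampling after the fact is exact but wastes the outputs it discards; a real implementation may handle stride differently.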
Install
pip install git+https://github.com/yoyololicon/fft-conv-pytorch
Example Usage
import torch
from torch_fftconv import fft_conv1d, FFTConv1d
# Create dummy data.
# Data shape: (batch, channels, length)
# Kernel shape: (out_channels, in_channels, kernel_size)
# Bias shape: (out_channels,)
# For ordinary 1D convolution, simply set batch=1.
signal = torch.randn(3, 3, 1024 * 1024)
kernel = torch.randn(2, 3, 128)
bias = torch.randn(2)
# Functional execution. (Easiest for generic use cases.)
out = fft_conv1d(signal, kernel, bias=bias)
# Object-oriented execution. (Requires some extra work, since the
# defined classes were designed for use in neural networks.)
fft_conv = FFTConv1d(3, 2, 128, bias=True)
fft_conv.weight = torch.nn.Parameter(kernel)
fft_conv.bias = torch.nn.Parameter(bias)
out = fft_conv(signal)
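Because `fft_conv1d` is meant to behave like `torch.nn.functional.conv1d`, its output length should follow the standard `Conv1d` formula from the PyTorch documentation, floor((L + 2·padding - dilation·(kernel_size - 1) - 1) / stride) + 1. The helper below is a hypothetical name, not part of this package, and only covers integer padding. For the example above (no padding, stride 1, dilation 1) it predicts an output shape of (3, 2, 1048449).

```python
def conv1d_output_length(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1D convolution, per the formula in the PyTorch Conv1d docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Shapes from the example: signal (3, 3, 1024 * 1024), kernel (2, 3, 128).
batch, out_channels = 3, 2
print((batch, out_channels, conv1d_output_length(1024 * 1024, 128)))  # (3, 2, 1048449)
```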
Benchmarks
The best situation to use FFTConv is with large kernels. The following image shows that, for a fixed input size, the cost of the FFT method stays almost constant across all kernel sizes.
For details and benchmarks on other parameters, check this notebook.
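A rough operation-count model (an illustration, not a measurement) suggests why the FFT curve is flat: a direct "valid" convolution needs about (N - K + 1)·K multiply-accumulates, while an FFT convolution pays for transforms whose size depends on N + K - 1, which barely changes while K is much smaller than N. The sketch assumes power-of-two padding and a cost of about m·log2(m) per transform; both are simplifying assumptions, the helper names are hypothetical, and the constants are not meant to predict real timings.

```python
def next_pow2(n):
    """Smallest power of two >= n (n >= 1)."""
    m = 1
    while m < n:
        m *= 2
    return m


def direct_macs(n, k):
    """Multiply-accumulates for a single-channel 'valid' direct convolution."""
    return (n - k + 1) * k


def fft_cost(n, k):
    """Crude FFT-path cost: three size-m transforms (~m*log2(m) each) plus m products."""
    m = next_pow2(n + k - 1)
    log2_m = m.bit_length() - 1  # exact because m is a power of two
    return 3 * m * log2_m + m


n = 1024 * 1024  # fixed input length
for k in (8, 128, 1024, 8192):
    print(f"K={k:5d}  direct={direct_macs(n, k):>13,}  fft={fft_cost(n, k):>13,}")
```

Under these assumptions the FFT estimate is identical for every K in the loop, because N + K - 1 stays below the next power of two above N, while the direct count grows roughly linearly in K.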
TODO
- Jittability.
- Dilated Convolution.
- Transposed Convolution.
Download files
File details
Details for the file torch_fftconv-0.1.3.tar.gz
File metadata
- Download URL: torch_fftconv-0.1.3.tar.gz
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.17
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4f647cd363778cc80d382ff13dc1b33e2291fcfc0d3ca482fde8936f3b9e9959
MD5 | 5733d1aa8f4b9d720bd923b46f75d53a
BLAKE2b-256 | a59602582907088d2f448834e88325d8e3789fd135fbbb847cb5c896dbfd1ee1
File details
Details for the file torch_fftconv-0.1.3-py3-none-any.whl
File metadata
- Download URL: torch_fftconv-0.1.3-py3-none-any.whl
- Size: 7.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.17
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7425dcb55e1bfc5ee2e3143c94c9738dddf2bd7c2638b3bdbeb1a2a60e1bef37
MD5 | ec02c68233bad19e68a0a364194df55a
BLAKE2b-256 | 66dfa380b0ecbc03272262001b47007b2fefc3a9fbebd24c8046c1962ceaf38a