fft-conv-pytorch
Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch.
- Faster than direct convolution for large kernels.
- Much slower than direct convolution for small kernels.
- In my local tests, FFT convolution is faster once the kernel has more than roughly 100 elements, though the exact crossover depends on your machine and PyTorch version.
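The speedup comes from the convolution theorem: convolution in the signal domain is pointwise multiplication in the frequency domain, and the FFT costs O(n log n) regardless of kernel size. A minimal sketch of the idea, using NumPy here purely for illustration (the library itself operates on PyTorch tensors and additionally handles strides, padding, groups, and the cross-correlation convention of PyTorch's conv layers):

```python
import numpy as np

def fft_conv1d(signal, kernel):
    # Full linear convolution via FFT: zero-pad both inputs to the
    # output length, multiply their spectra, and transform back.
    n = len(signal) + len(kernel) - 1
    sig_f = np.fft.rfft(signal, n)
    ker_f = np.fft.rfft(kernel, n)
    return np.fft.irfft(sig_f * ker_f, n)

signal = np.random.randn(1024)
kernel = np.random.randn(128)

direct = np.convolve(signal, kernel)   # O(n * k) direct convolution
via_fft = fft_conv1d(signal, kernel)   # O(n log n) FFT convolution

assert np.allclose(direct, via_fft)
```

For a 128-element kernel the FFT route already does far fewer multiplications than the direct sum, which matches the ~100-element crossover observed above.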
Install
Using pip:
pip install fft-conv-pytorch
From source:
git clone https://github.com/fkodom/fft-conv-pytorch.git
cd fft-conv-pytorch
pip install .
Example Usage
import torch
from fft_conv_pytorch import fft_conv, FFTConv1d
# Create dummy data.
# Data shape: (batch, channels, length)
# Kernel shape: (out_channels, in_channels, kernel_size)
# Bias shape: (out_channels,)
# For ordinary 1D convolution, simply set batch=1.
signal = torch.randn(3, 3, 1024 * 1024)
kernel = torch.randn(2, 3, 128)
bias = torch.randn(2)
# Functional execution. (Easiest for generic use cases.)
out = fft_conv(signal, kernel, bias=bias)
# Object-oriented execution. (Requires some extra work, since the
# defined classes were designed for use in neural networks.)
# Named 'conv' here to avoid shadowing the imported fft_conv function.
conv = FFTConv1d(3, 2, 128, bias=True)
conv.weight = torch.nn.Parameter(kernel)
conv.bias = torch.nn.Parameter(bias)
out = conv(signal)
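The same frequency-domain trick generalizes to the 2D and 3D cases: transform along every spatial axis, multiply, and invert. A self-contained 2D sketch in NumPy (again for illustration only; names like `fft_conv2d` and `direct_conv2d` below are ad hoc, not part of the library's API), checked against a naive direct implementation:

```python
import numpy as np

def fft_conv2d(image, kernel):
    # 2D full linear convolution via a 2D real FFT.
    s0 = image.shape[0] + kernel.shape[0] - 1
    s1 = image.shape[1] + kernel.shape[1] - 1
    out = np.fft.irfft2(
        np.fft.rfft2(image, (s0, s1)) * np.fft.rfft2(kernel, (s0, s1)),
        (s0, s1),
    )
    return out

def direct_conv2d(image, kernel):
    # Naive O(N^2 * K^2) reference: shift-and-accumulate.
    s0 = image.shape[0] + kernel.shape[0] - 1
    s1 = image.shape[1] + kernel.shape[1] - 1
    out = np.zeros((s0, s1))
    for i in range(kernel.shape[0]):
        for j in range(kernel.shape[1]):
            out[i:i + image.shape[0], j:j + image.shape[1]] += kernel[i, j] * image
    return out

img = np.random.randn(32, 32)
ker = np.random.randn(5, 5)
assert np.allclose(fft_conv2d(img, ker), direct_conv2d(img, ker))
```

The library's `FFTConv2d` and `FFTConv3d` classes follow the same constructor signature shown for `FFTConv1d` above.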