# fft-conv-pytorch
Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch.
- Faster than direct convolution for large kernels.
- Much slower than direct convolution for small kernels.
- In my local tests, FFT convolution becomes faster once the kernel has roughly 100 or more elements; the exact crossover depends on your machine and PyTorch version.
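The speedup for large kernels comes from the convolution theorem: pointwise multiplication in the frequency domain replaces the sliding dot product, turning O(n·k) work into O(n log n). A minimal torch-only sketch of the idea (this is not this package's implementation; `fft_conv1d` is a hypothetical helper covering a single channel pair with stride 1 and no padding):

```python
import torch
import torch.nn.functional as F

def fft_conv1d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Cross-correlation (what PyTorch calls "convolution") via the FFT.

    Matches F.conv1d for a single (batch, channel) pair with stride 1
    and no padding. `fft_conv1d` is an illustrative name, not the API
    of fft-conv-pytorch.
    """
    n = signal.shape[-1]
    k = kernel.shape[-1]
    # Transform both; the kernel is zero-padded to the signal length.
    fs = torch.fft.rfft(signal, n=n)
    fk = torch.fft.rfft(kernel, n=n)
    # Multiplying by the conjugate kernel spectrum yields circular
    # cross-correlation in the time domain.
    out = torch.fft.irfft(fs * fk.conj(), n=n)
    # The first n - k + 1 samples are free of circular wraparound,
    # so they equal the "valid" direct cross-correlation.
    return out[..., : n - k + 1]
```

For large kernels the two FFTs and one inverse FFT cost far less than sliding the kernel across the signal directly; for small kernels the fixed FFT overhead dominates, which is the trade-off described above.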
## Example Usage
```python
import torch
from fft_conv_pytorch import fft_conv, FFTConv1d

# Create dummy data.
#     Data shape:   (batch, channels, length)
#     Kernel shape: (out_channels, in_channels, kernel_size)
#     Bias shape:   (out_channels,)
# For ordinary 1D convolution, simply set batch=1.
signal = torch.randn(3, 3, 1024 * 1024)
kernel = torch.randn(2, 3, 128)
bias = torch.randn(2)

# Functional execution. (Easiest for generic use cases.)
out = fft_conv(signal, kernel, bias=bias)

# Object-oriented execution. (Requires some extra work, since the
# defined classes were designed for use in neural networks.)
# Note: the module is named `conv` here to avoid shadowing the
# imported `fft_conv` function.
conv = FFTConv1d(3, 2, 128, bias=True)
conv.weight = torch.nn.Parameter(kernel)
conv.bias = torch.nn.Parameter(bias)
out = conv(signal)
```
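The same frequency-domain trick carries over to higher dimensions, since the convolution theorem holds axis-by-axis. As a torch-only sketch of the 2D case (again illustrative, not this package's code; `fft_conv2d` is a hypothetical helper for a single channel pair, stride 1, no padding):

```python
import torch
import torch.nn.functional as F

def fft_conv2d(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """2D cross-correlation via the FFT, matching F.conv2d for a single
    (batch, channel) pair with stride 1 and no padding. Illustrative only."""
    h, w = signal.shape[-2:]
    kh, kw = kernel.shape[-2:]
    # 2D real FFTs over the last two axes; the kernel is zero-padded
    # to the signal's spatial size.
    fs = torch.fft.rfft2(signal, s=(h, w))
    fk = torch.fft.rfft2(kernel, s=(h, w))
    # Conjugate multiplication gives circular cross-correlation.
    out = torch.fft.irfft2(fs * fk.conj(), s=(h, w))
    # Keep only the wraparound-free "valid" region.
    return out[..., : h - kh + 1, : w - kw + 1]
```

In practice you would reach for the package's `FFTConv2d`/`FFTConv3d` classes (or `fft_conv` on higher-dimensional inputs) rather than hand-rolling this, but the sketch shows why a single implementation generalizes across 1D, 2D, and 3D.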