
Differentiable and GPU-enabled fast wavelet transforms in PyTorch


Welcome to the PyTorch wavelet toolbox. This package implements:

  • the fast wavelet transform (fwt), implemented in wavedec.

  • the inverse fwt, available via waverec.

  • the 2d fwt, called wavedec2,

  • and the inverse 2d fwt, waverec2.

  • 1d sparse-matrix fast wavelet transforms with boundary filters.

  • 2d sparse-matrix transforms with separable and non-separable boundary filters (experimental).

  • single- and two-dimensional wavelet packet forward transforms.

  • adaptive wavelet support (experimental).
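Wavelet packets differ from the fwt in one respect: the detail coefficients are split further as well, not only the approximation. The plain-Python sketch below illustrates that full tree with Haar filters. It is not ptwt's implementation; the helper names are hypothetical, and the "a"/"d" path naming is only borrowed from pywt's wavelet packet convention for illustration.

```python
import math

S = 1 / math.sqrt(2)  # Haar filter tap


def haar_step(x):
    """One Haar analysis level: (approximation, detail). Even length assumed."""
    a = [S * (x[2 * k] + x[2 * k + 1]) for k in range(len(x) // 2)]
    d = [S * (x[2 * k] - x[2 * k + 1]) for k in range(len(x) // 2)]
    return a, d


def haar_packets(x, level):
    """Full wavelet packet tree: unlike the fwt, detail nodes are split too.

    Returns a dict mapping paths such as "ad" (approximation, then detail)
    to the coefficient list of every node at the requested level.
    """
    nodes = {"": list(x)}
    for _ in range(level):
        nxt = {}
        for path, coeffs in nodes.items():
            a, d = haar_step(coeffs)
            nxt[path + "a"] = a
            nxt[path + "d"] = d
        nodes = nxt
    return nodes


data = [0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]
packets = haar_packets(data, level=2)
print(sorted(packets))  # ['aa', 'ad', 'da', 'dd']
```

A level-2 packet transform of 16 samples yields four nodes of four coefficients each, whereas the level-2 fwt yields only three coefficient arrays (4, 4 and 8 entries).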

This toolbox supports pywt wavelets. Complete documentation is available at https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/ptwt.html

Installation

Install the toolbox via pip, or clone this repository. To install via pip, type:

$ pip install ptwt

You can remove it later by typing pip uninstall ptwt.

Example usage:

Single dimensional transform

One way to compute fast wavelet transforms is to rely on padding and convolution. Consider the following example:

import torch
import numpy as np
import pywt
import ptwt  # use " from src import ptwt " if you cloned the repo instead of using pip.

# generate an input of even length.
data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
data_torch = torch.from_numpy(data.astype(np.float32))
wavelet = pywt.Wavelet('haar')

# compare the forward fwt coefficients
print(pywt.wavedec(data, wavelet, mode='zero', level=2))
print(ptwt.wavedec(data_torch, wavelet, mode='zero', level=2))

# invert the fwt.
print(ptwt.waverec(ptwt.wavedec(data_torch, wavelet, mode='zero'), wavelet))

The functions wavedec and waverec compute the 1d fwt and its inverse. Internally, both rely on conv1d and its transposed counterpart, conv_transpose1d, from the torch.nn.functional module. This toolbox supports discrete wavelets; see pywt.wavelist(kind='discrete') for the available options. I have tested Daubechies wavelets (db-x) and symlets (sym-x), which are usually a good starting point.
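To make the convolution view concrete, here is a minimal plain-Python sketch of what a stride-2 correlation (the conv1d analogue) and its transpose (the conv_transpose1d analogue) compute for the Haar wavelet. It is not ptwt's code, and it assumes an even-length input so that no padding is needed; all function names are hypothetical.

```python
import math

# Orthogonal Haar filters, used as correlation kernels (as conv1d does).
LO = [1 / math.sqrt(2), 1 / math.sqrt(2)]
HI = [1 / math.sqrt(2), -1 / math.sqrt(2)]


def analysis_step(x):
    """Stride-2 correlation with LO and HI, the conv1d(stride=2) analogue."""
    assert len(x) % 2 == 0, "even length assumed, so no padding is needed"
    half = len(x) // 2
    approx = [sum(LO[j] * x[2 * k + j] for j in range(2)) for k in range(half)]
    detail = [sum(HI[j] * x[2 * k + j] for j in range(2)) for k in range(half)]
    return approx, detail


def synthesis_step(approx, detail):
    """Scatter each coefficient back through the filters (conv_transpose1d analogue)."""
    x = [0.0] * (2 * len(approx))
    for k in range(len(approx)):
        for j in range(2):
            x[2 * k + j] += LO[j] * approx[k] + HI[j] * detail[k]
    return x


def wavedec_haar(x, level):
    """Multi-level decomposition, ordered [cA_n, cD_n, ..., cD_1] as in pywt."""
    details = []
    for _ in range(level):
        x, d = analysis_step(x)
        details.insert(0, d)
    return [x] + details


def waverec_haar(coeffs):
    """Undo wavedec_haar, coarsest level first."""
    x = coeffs[0]
    for d in coeffs[1:]:
        x = synthesis_step(x, d)
    return x


data = [0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]
coeffs = wavedec_haar(data, level=2)
rec = waverec_haar(coeffs)
print(max(abs(a - b) for a, b in zip(data, rec)))  # close to zero (float rounding)
```

Because the Haar filters are orthogonal, the transposed operation inverts the forward one exactly. Longer filters overlap the signal edges, which is where padding (the mode argument) comes in.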

Two-dimensional transform

Analogous to the 1d case, wavedec2 and waverec2 rely on conv2d and its transposed counterpart, conv_transpose2d. To test an example, run:

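import numpy as np
import pywt
import scipy.misc  # newer SciPy versions provide this image as scipy.datasets.face instead
import torch

import ptwt

# channels-first float64 image, shape (3, 768, 1024)
face = np.transpose(scipy.misc.face(), [2, 0, 1]).astype(np.float64)
# add a singleton channel dimension, shape (3, 1, 768, 1024)
pytorch_face = torch.tensor(face).unsqueeze(1)
coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                             level=2, mode="constant")
reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
# the maximum reconstruction error should be close to zero
print(np.max(np.abs(face - reconstruction.squeeze(1).numpy())))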
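When the 2d filters are outer products of the 1d filters, one 2d level can equivalently be computed by running the 1d step along the rows and then along the columns. The plain-Python sketch below does exactly that for Haar and inverts it. It is an illustration, not ptwt's conv2d-based code; the helper names are hypothetical, and the subband order only mirrors pywt's (cA, (cH, cV, cD)) layout.

```python
import math

S = 1 / math.sqrt(2)


def haar_1d(v):
    """One Haar level along a single row or column (even length assumed)."""
    half = len(v) // 2
    return ([S * (v[2 * k] + v[2 * k + 1]) for k in range(half)],
            [S * (v[2 * k] - v[2 * k + 1]) for k in range(half)])


def ihaar_1d(a, d):
    out = []
    for ak, dk in zip(a, d):
        out += [S * (ak + dk), S * (ak - dk)]
    return out


def transpose(m):
    return [list(col) for col in zip(*m)]


def split_columns(m):
    """Apply haar_1d to every column of m; return (lowpass, highpass) matrices."""
    lo, hi = zip(*(haar_1d(col) for col in transpose(m)))
    return transpose(lo), transpose(hi)


def merge_columns(lo, hi):
    return transpose([ihaar_1d(a, d) for a, d in zip(transpose(lo), transpose(hi))])


def haar_2d(img):
    """One 2d level: filter the rows, then the columns of each half."""
    row_lo, row_hi = zip(*(haar_1d(row) for row in img))
    a, h = split_columns(row_lo)  # lowpass along rows
    v, d = split_columns(row_hi)  # highpass along rows
    return a, (h, v, d)


def ihaar_2d(a, details):
    h, v, d = details
    row_lo = merge_columns(a, h)
    row_hi = merge_columns(v, d)
    return [ihaar_1d(lo, hi) for lo, hi in zip(row_lo, row_hi)]


img = [[float(4 * r + c) for c in range(4)] for r in range(4)]
a, (h, v, d) = haar_2d(img)
rec = ihaar_2d(a, (h, v, d))
```

For this linear ramp the diagonal subband d is exactly zero, since the ramp has no component that is high-frequency along both axes.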

Boundary Wavelets with Sparse-Matrices

In addition to convolution and padding approaches, sparse-matrix-based code with boundary wavelet support is available. In contrast to padding, boundary wavelets do not add extra pixels at the edges. Internally, boundary wavelet support relies on torch.sparse.mm. Generate 1d sparse-matrix forward and backward transforms with the MatrixWavedec and MatrixWaverec classes. Returning to the 1d case, try:

import torch
import numpy as np
import pywt
import ptwt  # use " from src import ptwt " if you cloned the repo instead of using pip.

# generate an input of even length.
data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
data_torch = torch.from_numpy(data.astype(np.float32))
# forward
matrix_wavedec = ptwt.MatrixWavedec(pywt.Wavelet("haar"), level=2)
coeff = matrix_wavedec(data_torch)
print(coeff)
# backward
matrix_waverec = ptwt.MatrixWaverec(pywt.Wavelet("haar"))
rec = matrix_waverec(coeff)
print(rec)
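The matrix view can be sketched without PyTorch: below, a single-level Haar analysis matrix is stored as coordinate (COO) triples and applied with a hand-written sparse product standing in for torch.sparse.mm. This is a hypothetical illustration, not the MatrixWavedec code. Haar is chosen because its two-tap filters never reach past the edge of an even-length signal, so the sketch shows the sparse structure but not the boundary-filter construction itself.

```python
import math


def haar_matrix_coo(n):
    """(row, col, value) triples of the one-level Haar analysis matrix for length n.

    Rows 0..n/2-1 yield approximation, rows n/2..n-1 detail coefficients.
    """
    assert n % 2 == 0
    s = 1 / math.sqrt(2)
    half = n // 2
    entries = []
    for k in range(half):
        entries += [(k, 2 * k, s), (k, 2 * k + 1, s)]
        entries += [(half + k, 2 * k, s), (half + k, 2 * k + 1, -s)]
    return entries


def coo_matvec(entries, x, n_rows):
    """Sparse matrix-vector product that touches only the stored non-zeros."""
    y = [0.0] * n_rows
    for r, c, v in entries:
        y[r] += v * x[c]
    return y


def coo_transpose(entries):
    return [(c, r, v) for r, c, v in entries]


data = [0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]
A = haar_matrix_coo(len(data))
coeffs = coo_matvec(A, data, len(data))                # forward: A @ x
rec = coo_matvec(coo_transpose(A), coeffs, len(data))  # inverse: A.T @ y, as A is orthogonal
print(len(A))  # 32 stored entries instead of 16 * 16 = 256
```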

The 2d transforms MatrixWavedec2 and MatrixWaverec2 work similarly. By default, a non-separable transformation is used. To use a separable transformation, pass separable=True to MatrixWavedec2 and MatrixWaverec2. Separable transformations apply a 1d transformation along both axes, which might be faster since fewer matrix entries have to be orthogonalized.
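For filters that need no boundary treatment, such as Haar on an even-sized image, the two routes agree: applying a 1d matrix A along both axes (A X Aᵀ) gives the same coefficients as multiplying the row-major flattened image by the Kronecker product A ⊗ A. The plain-Python sketch below checks this and counts stored non-zeros, which hints at why the separable route involves far fewer entries. Haar needs no orthogonalization at all, and whether ptwt's non-separable matrix has exactly this Kronecker form is an assumption made here for illustration.

```python
import math

s = 1 / math.sqrt(2)
n = 4
# Dense one-level Haar analysis matrix for length 4, approximation rows first.
A = [[s, s, 0, 0],
     [0, 0, s, s],
     [s, -s, 0, 0],
     [0, 0, s, -s]]


def matmul(p, q):
    return [[sum(p[i][k] * q[k][j] for k in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def kron(p, q):
    """Kronecker product; it acts on row-major flattened images."""
    return [[p[i][j] * q[k][l] for j in range(len(p[0])) for l in range(len(q[0]))]
            for i in range(len(p)) for k in range(len(q))]


def nnz(m):
    return sum(1 for row in m for v in row if v != 0)


X = [[float(4 * r + c) for c in range(n)] for r in range(n)]

# Separable route: the 1d matrix along the columns, then along the rows.
Y_sep = matmul(matmul(A, X), transpose(A))

# Non-separable route: one large matrix acting on the flattened image.
K = kron(A, A)
flat = [x for row in X for x in row]
y_flat = [sum(K[i][j] * flat[j] for j in range(n * n)) for i in range(n * n)]

print(nnz(A), nnz(K))  # 8 64
```

The gap grows with image size: for an N×N image the 1d matrix is N×N, whereas the flattened operator is N²×N².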

Adaptive Wavelets

Experimental code to train an adaptive wavelet layer in PyTorch is available in the examples folder. In addition to static wavelets from pywt,

  • adaptive product filters

  • and optimizable orthogonal wavelets are supported.

See https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/main/examples for a complete implementation.
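The examples folder holds the actual training code. As a hedged sketch of the idea behind optimizable orthogonal wavelets, one common approach is to add a penalty to the training loss that pushes a learnable lowpass filter toward the standard orthogonality conditions. The function orthogonality_penalty is hypothetical and is not claimed to be the loss used in the examples; it is written in plain Python here, but the same arithmetic on a torch tensor with requires_grad=True would be differentiable.

```python
import math


def orthogonality_penalty(h):
    """Hypothetical soft constraint for a trainable lowpass filter h.

    Penalizes deviation from two standard conditions of orthogonal wavelets:
      sum(h) == sqrt(2)
      sum_k h[k] * h[k + 2m] == (1 if m == 0 else 0)   (double-shift orthonormality)
    """
    penalty = (sum(h) - math.sqrt(2)) ** 2
    for m in range((len(h) + 1) // 2):
        target = 1.0 if m == 0 else 0.0
        corr = sum(h[k] * h[k + 2 * m] for k in range(len(h) - 2 * m))
        penalty += (corr - target) ** 2
    return penalty


# Daubechies db2 lowpass coefficients, which satisfy both conditions.
r3 = math.sqrt(3)
db2 = [(1 + r3) / (4 * math.sqrt(2)), (3 + r3) / (4 * math.sqrt(2)),
       (3 - r3) / (4 * math.sqrt(2)), (1 - r3) / (4 * math.sqrt(2))]
print(orthogonality_penalty(db2))                   # ~0
print(orthogonality_penalty([0.5, 0.5, 0.5, 0.5]))  # > 0, a moving average is not orthogonal
```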

Testing

The tests folder contains multiple tests to allow independent verification of this toolbox. After cloning the repository, moving into its main directory, and installing tox with pip install tox, run:

$ tox -e py

📖 Citation

If you find this work useful, please consider citing:

@phdthesis{handle:20.500.11811/9245,
  urn = {https://nbn-resolving.org/urn:nbn:de:hbz:5-63361},
  author = {{Moritz Wolter}},
  title = {Frequency Domain Methods in Recurrent Neural Networks for Sequential Data Processing},
  school = {Rheinische Friedrich-Wilhelms-Universität Bonn},
  year = 2021,
  month = jul,
  url = {https://hdl.handle.net/20.500.11811/9245}
}

@thesis{Blanke2021,
  author = {Felix Blanke},
  title = {{Randbehandlung bei Wavelets für Faltungsnetzwerke}},
  type = {Bachelor's Thesis},
  annote = {Gbachelor},
  year = {2021},
  school = {Institut f\"ur Numerische Simulation, Universit\"at Bonn}
}

Download files


Source Distribution

ptwt-0.1.0.tar.gz (36.9 kB)

Uploaded Source

Built Distribution

ptwt-0.1.0-py3-none-any.whl (47.0 kB)

Uploaded Python 3

File details

Details for the file ptwt-0.1.0.tar.gz.

File metadata

  • Download URL: ptwt-0.1.0.tar.gz
  • Upload date:
  • Size: 36.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.0 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.5

File hashes

Hashes for ptwt-0.1.0.tar.gz
Algorithm Hash digest
SHA256 39df9d0817573bba2eb7906eea1ca914d6fea1b7f4ca10c9dfb2e574857b00ae
MD5 97ee16317db39c82a5c496a042d7867d
BLAKE2b-256 ffddaf63d4b88c397656d29ef738da363ebdc9852136dfb13a2808bfbd6e6baa


File details

Details for the file ptwt-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: ptwt-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 47.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.0 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.5

File hashes

Hashes for ptwt-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 2d385ee7bd1376cc3fcebf0f64ebfa4f4f82f0a61dd8bd866f080ba5720186e2
MD5 df68ea10e90f7a8b71aae016df1c2d2a
BLAKE2b-256 d69e78d781f4062cd195464aeb337d529c26f7429b3ab280973e12bb98e8c7a9

