Differentiable and GPU-enabled fast wavelet transforms in PyTorch

PyTorch Wavelet Toolbox (ptwt)

Welcome to the PyTorch (adaptive) wavelet toolbox. This package implements:

  • the fast wavelet transform (fwt): wavedec
  • the inverse fwt: waverec
  • the 2d fwt: wavedec2
  • the inverse 2d fwt: waverec2
  • single- and two-dimensional wavelet packet forward transforms
  • adaptive wavelet support (experimental)
  • sparse-matrix fast wavelet transforms (experimental)
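As a rough illustration of what one level of a 2d fwt such as wavedec2 computes, here is a NumPy sketch using the Haar wavelet. It is not ptwt's implementation: it assumes an even image height and width (so no boundary padding is needed), the helper names haar_1d and haar_2d_level are hypothetical, and the sub-band labels are descriptive (which filter ran along axis 0 and axis 1) rather than pywt's naming.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_1d(x, axis):
    """One Haar analysis step along `axis`; assumes an even length on that axis."""
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    approx = (even + odd) / SQRT2  # low-pass filter, then downsample by 2
    detail = (even - odd) / SQRT2  # high-pass filter, then downsample by 2
    return approx, detail


def haar_2d_level(image):
    """Single-level 2d Haar transform of an (H, W) array with even H and W.

    Filters along axis 0 first, then along axis 1. Returns the low-low band
    and a dict of the three detail bands keyed by "<axis 0 filter>-<axis 1 filter>".
    """
    lo0, hi0 = haar_1d(np.asarray(image, dtype=float), axis=0)
    ll, lh = haar_1d(lo0, axis=1)
    hl, hh = haar_1d(hi0, axis=1)
    return ll, {"low-high": lh, "high-low": hl, "high-high": hh}


image = np.arange(16, dtype=float).reshape(4, 4)
ll, details = haar_2d_level(image)
print(ll.shape)  # (2, 2)
```

Each level halves both spatial dimensions; a multi-level 2d transform would repeat the step on the low-low band. Because the Haar filters are orthonormal, the summed squared coefficients of the four bands equal the squared sum of the input pixels.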

Installation

Install the toolbox via pip, or clone this repository. To install via pip, type:

$ pip install ptwt

You can remove it later by typing pip uninstall ptwt.

Example usage:

import torch
import numpy as np
import pywt
import ptwt  # if you cloned the repo instead of installing via pip, use: from src import ptwt

# generate an input of even length.
data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
data_torch = torch.from_numpy(data.astype(np.float32))
wavelet = pywt.Wavelet('haar')

# compare the forward fwt coefficients
print(pywt.wavedec(data, wavelet, mode='zero', level=2))
print(ptwt.wavedec(data_torch, wavelet, mode='zero', level=2))

# invert the fwt.
print(ptwt.waverec(ptwt.wavedec(data_torch, wavelet, mode='zero', level=2), wavelet))
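For the Haar wavelet, the coefficients printed above can be reproduced by hand. The NumPy sketch below implements a multi-level Haar decomposition and its inverse. It assumes the signal length is even at every level (true for the 12-sample input at level 2), so no boundary padding is needed and the mode argument plays no role. The output ordering [cA_n, cD_n, ..., cD_1] and the sign convention are chosen to match pywt.wavedec; haar_wavedec and haar_waverec are hypothetical names, not part of ptwt.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_wavedec(x, level):
    """Multi-level Haar fwt. Returns [cA_level, cD_level, ..., cD_1]."""
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(level):
        if approx.size % 2:
            raise ValueError("this sketch needs an even length at every level")
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / SQRT2)  # detail (high-pass) coefficients
        approx = (even + odd) / SQRT2  # approximation (low-pass) coefficients
    return [approx] + details[::-1]  # coarsest level first, like pywt


def haar_waverec(coeffs):
    """Inverse of haar_wavedec: undo one level at a time, coarsest first."""
    approx = coeffs[0]
    for detail in coeffs[1:]:
        rec = np.empty(2 * approx.size)
        rec[0::2] = (approx + detail) / SQRT2  # recovers the even samples
        rec[1::2] = (approx - detail) / SQRT2  # recovers the odd samples
        approx = rec
    return approx


data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
coeffs = haar_wavedec(data, level=2)
reconstruction = haar_waverec(coeffs)
# coeffs[0] is close to [3, 9, 3] and coeffs[1] to [-2, 0, 2]
```

Working the first level by hand: the pairs (0, 1), (2, 3), ... give approximations [1, 5, 9, 9, 5, 1]/√2 and details [-1, -1, -1, 1, 1, 1]/√2; repeating the step on the approximations yields [3, 9, 3] and [-2, 0, 2].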

Unit Tests

The tests folder contains unit tests that allow independent verification of this toolbox. After cloning the repository, move into its main directory, install tox with pip install tox, and run:

$ tox -e py
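The repository's own tests are not reproduced here. As a hypothetical example of the kind of independent check such a suite can contain, the pytest-style function below verifies that a Haar analysis step written as pairwise sums and differences agrees with a convolve-and-downsample formulation of the fwt. The filter values are those reported by pywt.Wavelet('haar').dec_lo and .dec_hi, written out so the test only needs NumPy; all function names are illustrative.

```python
import numpy as np

# Haar decomposition filters, as listed by pywt.Wavelet('haar').dec_lo / .dec_hi
DEC_LO = np.array([1.0, 1.0]) / np.sqrt(2.0)
DEC_HI = np.array([-1.0, 1.0]) / np.sqrt(2.0)


def haar_step(x):
    """Pairwise sums and differences; assumes len(x) is even."""
    even, odd = x[0::2], x[1::2]
    return (even + odd) / np.sqrt(2.0), (even - odd) / np.sqrt(2.0)


def conv_downsample(x, filt):
    """Full convolution followed by keeping the odd-indexed outputs."""
    return np.convolve(x, filt, mode="full")[1::2]


def test_haar_step_matches_convolution():
    rng = np.random.default_rng(seed=0)
    for length in (2, 8, 64):
        x = rng.standard_normal(length)
        approx, detail = haar_step(x)
        assert np.allclose(approx, conv_downsample(x, DEC_LO))
        assert np.allclose(detail, conv_downsample(x, DEC_HI))


if __name__ == "__main__":
    test_haar_step_matches_convolution()
    print("ok")
```

Saved in a file whose name starts with test_, a function like this would be collected by pytest, which tox can be configured to run.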

Adaptive Wavelets (experimental)

Code to train an adaptive wavelet layer in PyTorch is available in the examples folder. In addition to static wavelets from pywt, the following are supported:

  • adaptive product filters
  • optimizable orthogonal wavelets
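When filter coefficients are trained rather than taken from pywt, nothing forces them to remain a valid orthogonal wavelet, so some constraint or penalty is needed. The NumPy sketch below shows one possible penalty for a low-pass filter h: it measures how far h is from being orthogonal to its own even shifts (sum_n h[n] h[n + 2k] = 1 for k = 0 and 0 otherwise) and from the normalization sum_n h[n] = √2. It illustrates these conditions only and is not necessarily the loss used in the examples folder; orthogonality_penalty is a hypothetical name.

```python
import numpy as np


def orthogonality_penalty(h):
    """Squared violation of the orthonormal low-pass conditions for filter h.

    Conditions checked:
      sum_n h[n] * h[n + 2k] == (1 if k == 0 else 0)  for every even shift 2k
      sum_n h[n] == sqrt(2)
    Returns (numerically) 0 for a valid orthonormal wavelet low-pass filter.
    """
    h = np.asarray(h, dtype=float)
    n = h.size
    penalty = 0.0
    for k in range((n + 1) // 2):
        corr = np.dot(h[: n - 2 * k], h[2 * k:])  # autocorrelation at lag 2k
        target = 1.0 if k == 0 else 0.0
        penalty += (corr - target) ** 2
    penalty += (h.sum() - np.sqrt(2.0)) ** 2
    return penalty


haar = np.array([1.0, 1.0]) / np.sqrt(2.0)
# closed-form 4-tap Daubechies (db2) low-pass filter
s3 = np.sqrt(3.0)
db2 = np.array([1 + s3, 3 + s3, 3 - s3, 1 - s3]) / (4 * np.sqrt(2.0))
print(orthogonality_penalty(haar) < 1e-12, orthogonality_penalty(db2) < 1e-12)
```

In a PyTorch training loop, the same expression written with tensor operations could be added to the task loss so that gradient descent keeps the learned filters close to orthogonal.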

Sparse-Matrix-Multiplication Transform (experimental)

In addition to the convolution-based fwt implementation, matrix-based code is available. Continuing the example above, try:

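# forward transform: returns the coefficients and the transform matrix
coeff, fwt_matrix = ptwt.matrix_wavedec(data_torch, wavelet, level=2)
print(coeff)
# inverse transform: returns the reconstruction and the inverse transform matrix
rec, ifwt_matrix = ptwt.matrix_waverec(coeff, wavelet, level=2)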
print(rec)
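Because the fwt is linear, one level of it can be written as a single matrix that multiplies the signal. The NumPy sketch below builds that matrix for one Haar level on an even-length signal, stacking the low-pass rows on top of the high-pass rows. It is stored densely for readability, although each row has only two nonzero entries, which is why a sparse format pays off for long signals. It follows pywt's Haar sign convention, its layout may differ from the matrices ptwt returns, and haar_analysis_matrix is a hypothetical name.

```python
import numpy as np


def haar_analysis_matrix(length):
    """Single-level Haar fwt as a (length x length) matrix; length must be even."""
    if length % 2:
        raise ValueError("length must be even")
    half = length // 2
    matrix = np.zeros((length, length))
    c = 1.0 / np.sqrt(2.0)
    for k in range(half):
        matrix[k, 2 * k: 2 * k + 2] = [c, c]  # low-pass row -> cA[k]
        matrix[half + k, 2 * k: 2 * k + 2] = [c, -c]  # high-pass row -> cD[k]
    return matrix


data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0], dtype=float)
fwt_matrix = haar_analysis_matrix(data.size)
coeffs = fwt_matrix @ data  # [cA_1, cD_1] concatenated
rec = fwt_matrix.T @ coeffs  # orthogonal matrix: the inverse is the transpose
print(np.allclose(rec, data))  # True
```

A multi-level transform is a product of such matrices, where each further factor acts only on the approximation part of the previous result.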


Download files

Source Distribution

ptwt-0.0.3.tar.gz (18.6 kB)

Uploaded Source

Built Distribution

ptwt-0.0.3-py3-none-any.whl (18.9 kB)

Uploaded Python 3

File details

Details for the file ptwt-0.0.3.tar.gz.

File metadata

  • Download URL: ptwt-0.0.3.tar.gz
  • Size: 18.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.5

File hashes

Hashes for ptwt-0.0.3.tar.gz
Algorithm Hash digest
SHA256 a192531e57d1c76d9867be5d82bc27ac06b77d51f20a5df6441561c7d35fc21f
MD5 10dc8b65f69a8d747369874c2ced0205
BLAKE2b-256 7b2949094c9e50dfa41234b07c5d2f4210aeb7a7527b0ce3e0d6f0e0eee9ce2b

File details

Details for the file ptwt-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: ptwt-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 18.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.5

File hashes

Hashes for ptwt-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 53227609b9949f217d4c8cfde989ca4a44b56ef07b752c50f0188dba22fb6f11
MD5 301a73e52312d3e62d4e21a92cd84ef1
BLAKE2b-256 b6d66e53f349f25a6033c39cd401d685a95cb36b0f360c3b16e1fa0cd67d1d0b
