Differentiable and GPU-enabled fast wavelet transforms in JAX
Features
1d analysis and synthesis transforms are implemented in src/jaxlets/conv_fwt.py.
2d analysis and synthesis transforms are part of the src/jaxlets/conv_fwt_2d.py module.
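For intuition, one analysis/synthesis level for the Haar wavelet can be sketched in plain jax.numpy. This is a minimal illustration, not the code in conv_fwt.py or conv_fwt_2d.py. The function names haar_analysis_1d, haar_synthesis_1d, haar_analysis_2d and haar_synthesis_2d are hypothetical, and inputs are assumed to have even length along every transformed axis. Strided slicing replaces the general filter convolutions, a shortcut that only works for the two-tap Haar filter. The 2d version applies the 1d transform along the last axis and then along the second-to-last axis. The subband names ll, lh, hl, hh are illustrative and are not claimed to follow pywt's (cA, (cH, cV, cD)) ordering.

```python
import jax.numpy as jnp

INV_SQRT2 = 1.0 / jnp.sqrt(2.0)


def haar_analysis_1d(x):
    """Hypothetical one-level Haar analysis along the last axis (even length)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    approx = (even + odd) * INV_SQRT2  # low-pass coefficients
    detail = (even - odd) * INV_SQRT2  # high-pass coefficients
    return approx, detail


def haar_synthesis_1d(approx, detail):
    """Invert haar_analysis_1d by recomputing and interleaving the samples."""
    even = (approx + detail) * INV_SQRT2
    odd = (approx - detail) * INV_SQRT2
    # Stack to (..., n, 2), then merge the last two axes: e0, o0, e1, o1, ...
    return jnp.stack([even, odd], axis=-1).reshape(*even.shape[:-1], -1)


def haar_analysis_2d(img):
    """Hypothetical separable 2d Haar level: last axis, then second-to-last."""
    lo, hi = haar_analysis_1d(img)
    # Swap axes so the second-to-last axis becomes the last one.
    ll, lh = haar_analysis_1d(lo.swapaxes(-1, -2))
    hl, hh = haar_analysis_1d(hi.swapaxes(-1, -2))
    return tuple(c.swapaxes(-1, -2) for c in (ll, lh, hl, hh))


def haar_synthesis_2d(ll, lh, hl, hh):
    """Invert haar_analysis_2d by undoing the two 1d passes in reverse order."""
    lo = haar_synthesis_1d(ll.swapaxes(-1, -2), lh.swapaxes(-1, -2)).swapaxes(-1, -2)
    hi = haar_synthesis_1d(hl.swapaxes(-1, -2), hh.swapaxes(-1, -2)).swapaxes(-1, -2)
    return haar_synthesis_1d(lo, hi)


signal = jnp.arange(1.0, 9.0)
a, d = haar_analysis_1d(signal)
print(jnp.allclose(haar_synthesis_1d(a, d), signal, atol=1e-5))  # True

image = jnp.arange(1.0, 17.0).reshape(4, 4)
print(jnp.allclose(haar_synthesis_2d(*haar_analysis_2d(image)), image, atol=1e-5))  # True
```

Because every step is an orthonormal Haar level, synthesis undoes analysis up to floating-point rounding, which is why the checks use a small tolerance.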
Installation
This package builds on JAX. Install JAX first by following the instructions at https://github.com/google/jax#installation.
Transform example
import pywt
import numpy as np
import jax.numpy as jnp
import src.jwt as jwt
# generate an input of even length.
data = jnp.array([0., 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
wavelet = pywt.Wavelet('haar')
# compare the forward FWT coefficients of pywt and jwt.
print(pywt.wavedec(np.array(data), wavelet, mode='zero', level=2))
print(jwt.wavedec(data, wavelet, mode='zero', level=2))
# invert the FWT; the result should match the input.
print(jwt.waverec(jwt.wavedec(data, wavelet, mode='zero', level=2), wavelet))
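Because the transforms are written in JAX, they compose with jax.grad and jax.jit. The sketch below illustrates this with a hypothetical pure-JAX multi-level Haar decomposition, haar_wavedec, rather than the jwt module itself. It assumes the input length is divisible by 2**level, so no boundary padding is needed, and it returns coefficients in the [cA_n, cD_n, ..., cD_1] order used by pywt.wavedec. The detail sign convention, (even - odd) / sqrt(2), matches pywt's 'haar' wavelet for such inputs. The loss detail_energy is likewise only an illustrative example.

```python
import jax
import jax.numpy as jnp


def haar_wavedec(x, level):
    """Hypothetical multi-level Haar decomposition (not the jwt API).

    Returns [cA_level, cD_level, ..., cD_1]; assumes len(x) % 2**level == 0.
    """
    approx, details = x, []
    for _ in range(level):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / jnp.sqrt(2.0))
        approx = (even + odd) / jnp.sqrt(2.0)
    return [approx] + details[::-1]


def detail_energy(x):
    """Illustrative scalar loss: energy of all detail coefficients at level 2."""
    coeffs = haar_wavedec(x, level=2)
    return sum(jnp.sum(c ** 2) for c in coeffs[1:])


data = jnp.array([0., 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
coeffs = haar_wavedec(data, level=2)
print(coeffs[0])  # approximately [3, 11, 11, 3]

# Differentiate the loss with respect to the input signal, compiled with jit.
grad_fn = jax.jit(jax.grad(detail_energy))
print(grad_fn(data).shape)  # (16,)
```

Since each Haar level is orthonormal, the total energy of all coefficients equals the energy of the input. This gives a simple sanity check for the gradient: differentiating the total energy should return 2 * data.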
Testing
Unit tests are run with tox. To run them, clone the repository and execute the following:
$ pip install tox
$ git clone https://github.com/v0lta/Jax-Wavelet-Toolbox
$ cd Jax-Wavelet-Toolbox
$ tox
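A unit test in this spirit might check perfect reconstruction on random inputs and confirm that jit-compiled and eager results agree. The sketch below is hypothetical. It does not reproduce the repository's actual tests, and it defines a one-level Haar pair (haar_dwt, haar_idwt) inline so that it runs standalone. pytest collects functions named test_* by default, but whether this project's tox configuration invokes pytest is an assumption.

```python
import jax
import jax.numpy as jnp


def haar_dwt(x):
    """Hypothetical one-level Haar analysis (even-length input)."""
    even, odd = x[0::2], x[1::2]
    return (even + odd) / jnp.sqrt(2.0), (even - odd) / jnp.sqrt(2.0)


def haar_idwt(approx, detail):
    """Inverse of haar_dwt: recompute and interleave even/odd samples."""
    even = (approx + detail) / jnp.sqrt(2.0)
    odd = (approx - detail) / jnp.sqrt(2.0)
    return jnp.stack([even, odd], axis=-1).reshape(-1)


def test_perfect_reconstruction():
    # Random inputs of several even lengths must be recovered exactly (up to rounding).
    key = jax.random.PRNGKey(0)
    for length in (2, 8, 64):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, (length,))
        assert jnp.allclose(haar_idwt(*haar_dwt(x)), x, atol=1e-5)


def test_jit_matches_eager():
    # Compiling the transform with jax.jit should not change its output.
    x = jnp.linspace(-1.0, 1.0, 32)
    eager = haar_dwt(x)
    jitted = jax.jit(haar_dwt)(x)
    for e, j in zip(eager, jitted):
        assert jnp.allclose(e, j, atol=1e-6)


if __name__ == "__main__":
    test_perfect_reconstruction()
    test_jit_matches_eager()
    print("all tests passed")
```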
Goals
In the spirit of JAX, the aim is to be 100% pywt compatible. Whenever possible, interfaces should be the same and results identical.
64-bit floating-point numbers
JAX uses 32-bit floats by default. To enable 64-bit precision, set the jax_enable_x64 configuration flag at startup:
import jax
jax.config.update("jax_enable_x64", True)
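One way to confirm that the flag took effect is to inspect the dtype of a newly created array. This is a small sketch that assumes the flag is set at program startup, before any other JAX computation. Without the flag, JAX creates float32 arrays even from Python floats. That is presumably why tight comparisons against pywt, which computes in float64, need the flag.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before creating any arrays or running computations.
jax.config.update("jax_enable_x64", True)

x = jnp.array([0.1, 0.2, 0.3])
print(x.dtype)  # float64
```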