
JAX-native wavelet transforms


jaxwavelets

Extending PyWavelets to JAX. Differentiable, JIT-compilable, GPU-ready wavelet transforms.

Built on the mathematical foundations of PyWavelets and validated against it to machine precision. jaxwavelets brings the full PyWavelets API to JAX, enabling automatic differentiation, GPU acceleration, and composability with jax.vmap, jax.jit, and jax.pmap.

Features

Transform                    Functions
--------------------------------------------------------------------------
Discrete wavelet             dwt, idwt, dwt2, idwt2, dwtn, idwtn
Multilevel                   wavedec2, waverec2, wavedecn, waverecn
Stationary (undecimated)     swt, iswt, swt2, iswt2, swtn, iswtn
Continuous                   cwt, prepare_cwt, apply_cwt
Fully separable              fswavedecn, fswaverecn
Multiresolution analysis     mra, imra, mra2, imra2, mran, imran
Wavelet packets              wp_decompose, wp_reconstruct, wp_decompose_nd, wp_reconstruct_nd
Thresholding                 threshold, threshold_firm
Utilities                    downcoef, upcoef, qmf, orthogonal_filter_bank

Wavelets: haar, db1-20, sym2-20, coif1-5, plus continuous wavelets (Morlet, Mexican hat, Gaussian 1-8, complex Gaussian 1-8, complex Morlet, Shannon, frequency B-spline).

Usage

from functools import partial

import jax
import jax.numpy as jnp
import jaxwavelets as wt

# Decompose and reconstruct
x = jnp.ones((64, 64))
coeffs = wt.wavedecn(x, 'db4', level=3)
rec = wt.waverecn(coeffs, 'db4')

# Batch via vmap
batch = jnp.ones((10, 64, 64))
batch_coeffs = jax.vmap(partial(wt.wavedecn, wavelet='db4', level=3))(batch)

# Differentiate through the transform
grad = jax.grad(lambda x: jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4')))(x)

# JIT-compile for speed
fast = jax.jit(wt.wavedecn, static_argnames=['wavelet', 'mode', 'level'])
coeffs = fast(x, wavelet='db4', level=3)

Performance

Per-call times for JIT-compiled jaxwavelets on CPU versus PyWavelets' C implementation. The ratio column is jaxwavelets time divided by PyWavelets time, so values below 1.0x mean jaxwavelets is faster:

Transform                      pywt (ms)   jaxwavelets JIT (ms)   ratio (jax/pywt)
------------------------------------------------------------------------------------
dwt 1D (N=4096)                    0.011                  0.023    2.1x
wavedecn 1D (N=4096)               0.065                  0.046    0.7x  ← faster
dwt2 (256x256)                     0.608                  0.287    0.5x  ← faster
wavedecn 2D level=3                0.755                  0.363    0.5x  ← faster
swt 1D level=3 (N=1024)            0.023                  0.025    1.1x
cwt morl 6 scales (N=512)          0.316                  0.139    0.4x  ← faster
cwt cmor 6 scales (N=512)          0.615                  0.254    0.4x  ← faster

On top of this, jaxwavelets supports jax.grad, jax.vmap, jax.pmap, and GPU acceleration.

Installation

pip install jaxwavelets

jaxwavelets has no runtime dependency on PyWavelets: the filter coefficients are pre-extracted and bundled with the package.

Testing

pip install PyWavelets pytest
pytest jaxwavelets/tests/

1189 tests verify numerical agreement with PyWavelets to machine precision.

Composability

Every function operates on a single example. Batching, differentiation, compilation, and distribution compose naturally via JAX transforms:

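from functools import partial

import jax
import jax.numpy as jnp
import jaxwavelets as wt

batch = jnp.ones((10, 64, 64))  # 10 independent 64x64 examples

def loss_fn(x):
    # Scalar loss for a single example: sum of the round-tripped field
    return jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4'))

# Batch over examples
jax.vmap(partial(wt.wavedecn, wavelet='db4'))(batch)

# Per-example gradients
jax.vmap(jax.grad(loss_fn))(batch)

# Distribute across devices (leading axis size must not exceed the device count)
sharded_data = batch[:jax.local_device_count()]
jax.pmap(partial(wt.wavedecn, wavelet='db4'))(sharded_data)

# Nest arbitrarily
jax.jit(jax.vmap(jax.grad(
    lambda x: jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4'))
)))(batch)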

Coefficients are JAX pytrees, so jax.tree_util.tree_map works directly on them.

Design

  • Pure JAX — no numpy, no C extensions
  • Single-example functions — compose with jax.vmap/jax.pmap/jax.grad/jax.jit
  • Pytree coefficients — all outputs are JAX-compatible pytrees
  • Validated against PyWavelets — machine-precision numerical agreement
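To make the single-example design concrete without depending on the library, here is a self-contained pure-JAX sketch of a one-level Haar DWT and its inverse. It is not jaxwavelets' implementation (which covers many wavelets and boundary modes); it assumes an even-length 1D signal, and the function names are hypothetical. The point is that a function written for one example gains batching, differentiation, and compilation from `jax.vmap`, `jax.grad`, and `jax.jit` without any batch-aware code.

```python
import math

import jax
import jax.numpy as jnp

SQRT2 = math.sqrt(2.0)

def haar_dwt(x):
    # One level of the orthonormal Haar DWT on a single even-length 1D signal
    even, odd = x[0::2], x[1::2]
    return (even + odd) / SQRT2, (even - odd) / SQRT2

def haar_idwt(cA, cD):
    # Invert: recover the even/odd samples, then interleave them again
    even = (cA + cD) / SQRT2
    odd = (cA - cD) / SQRT2
    return jnp.stack([even, odd], axis=-1).reshape(-1)

def detail_energy(x):
    # Scalar function of one example, so jax.grad applies directly
    _, cD = haar_dwt(x)
    return jnp.sum(cD ** 2)

batch = jax.random.normal(jax.random.PRNGKey(0), (8, 16))

coeffs = jax.vmap(haar_dwt)(batch)                          # two (8, 8) arrays
rec = jax.vmap(haar_idwt)(*coeffs)                          # (8, 16)
grads = jax.jit(jax.vmap(jax.grad(detail_energy)))(batch)   # (8, 16)
```

Because the detail energy is the sum of `(x[2k] - x[2k+1])**2 / 2`, its gradient with respect to each even sample is `x[2k] - x[2k+1]`, which gives an easy hand check on the composed transforms.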

Acknowledgements

jaxwavelets extends the PyWavelets library to JAX. PyWavelets provides the mathematical reference implementation and filter coefficient database used for validation.


Download files


Source Distribution

jaxwavelets-0.1.0.tar.gz (42.3 kB)

Uploaded Source

Built Distribution


jaxwavelets-0.1.0-py3-none-any.whl (48.1 kB)

Uploaded Python 3

File details

Details for the file jaxwavelets-0.1.0.tar.gz.

File metadata

  • Download URL: jaxwavelets-0.1.0.tar.gz
  • Size: 42.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for jaxwavelets-0.1.0.tar.gz
Algorithm Hash digest
SHA256 faad6dfed4c744ae66a351102da9ce3b94f4535b2f523e3b3664c39d5c73ebfc
MD5 2a0243f1f3246827a06d92246b8b414a
BLAKE2b-256 21b7ad1b04cb8eeab553306c71a7a3651c708b82266d27f75a468af4873f3dbd


File details

Details for the file jaxwavelets-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jaxwavelets-0.1.0-py3-none-any.whl
  • Size: 48.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for jaxwavelets-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 2991f1c986b52d82b3b68c880338cc8126b4c4d816077199fa3850077adab2c5
MD5 d414c9cedc3b1230b869116d588f7220
BLAKE2b-256 4ee611b5f2c08664db917fe21b526bb7d8f0fc938bff6dc974a346b1205618d7

