JAX-native wavelet transforms

jaxwavelets

License: MIT

Extending PyWavelets to JAX. Differentiable, JIT-compilable, GPU-ready wavelet transforms.

Built on the mathematical foundations of PyWavelets and validated against it to machine precision. jaxwavelets brings the full PyWavelets API to JAX, enabling automatic differentiation, GPU acceleration, and composability with jax.vmap, jax.jit, and jax.pmap.

Features

Transform                    Functions
--------------------------------------------------------------------------------------------
Discrete wavelet             dwt, idwt, dwt2, idwt2, dwtn, idwtn
Multilevel                   wavedec2, waverec2, wavedecn, waverecn
Stationary (undecimated)     swt, iswt, swt2, iswt2, swtn, iswtn
Continuous                   cwt, prepare_cwt, apply_cwt
Fully separable              fswavedecn, fswaverecn
Multiresolution analysis     mra, imra, mra2, imra2, mran, imran
Wavelet packets              wp_decompose, wp_reconstruct, wp_decompose_nd, wp_reconstruct_nd
Thresholding                 threshold, threshold_firm
Utilities                    downcoef, upcoef, qmf, orthogonal_filter_bank

Wavelets: haar, db1-20, sym2-20, coif1-5, plus continuous wavelets (Morlet, Mexican hat, Gaussian 1-8, complex Gaussian 1-8, complex Morlet, Shannon, frequency B-spline).
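
To make the dwt/idwt pair concrete, here is a minimal, dependency-free sketch of a single-level orthonormal Haar transform and its inverse. It is an illustration of the arithmetic only, not jaxwavelets' implementation: the function names `haar_dwt` and `haar_idwt` are hypothetical, boundary modes are ignored, and the sign convention for detail coefficients can differ between libraries.

```python
import math

def haar_dwt(x):
    """Single-level orthonormal Haar DWT of an even-length sequence."""
    s = 1.0 / math.sqrt(2.0)
    # Approximation: scaled pairwise sums. Detail: scaled pairwise differences.
    approx = [(x[i] + x[i + 1]) * s for i in range(0, len(x), 2)]
    detail = [(x[i] - x[i + 1]) * s for i in range(0, len(x), 2)]
    return approx, detail

def haar_idwt(approx, detail):
    """Inverse of haar_dwt: interleave the recovered even and odd samples."""
    s = 1.0 / math.sqrt(2.0)
    out = []
    for a, d in zip(approx, detail):
        out.append((a + d) * s)  # even-indexed sample
        out.append((a - d) * s)  # odd-indexed sample
    return out

signal = [4.0, 2.0, 5.0, 5.0, 1.0, 3.0, 0.0, 8.0]
cA, cD = haar_dwt(signal)
reconstructed = haar_idwt(cA, cD)
max_error = max(abs(a - b) for a, b in zip(signal, reconstructed))
```

Each output has half the input length, and because the transform is orthonormal the reconstruction error is at floating-point rounding level and the signal energy is preserved.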

Usage

import jax
import jax.numpy as jnp
import jaxwavelets as wt

# Decompose and reconstruct
x = jnp.ones((64, 64))
coeffs = wt.wavedecn(x, 'db4', level=3)
rec = wt.waverecn(coeffs, 'db4')

# Batch via vmap
from functools import partial
batch = jnp.ones((10, 64, 64))
batch_coeffs = jax.vmap(partial(wt.wavedecn, wavelet='db4', level=3))(batch)

# Differentiate through the transform
grad = jax.grad(lambda x: jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4')))(x)

# JIT-compile for speed
fast = jax.jit(wt.wavedecn, static_argnames=['wavelet', 'mode', 'level'])
coeffs = fast(x, wavelet='db4', level=3)
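
The `static_argnames` in the JIT example matter because strings are not valid traced JAX values, and an argument that controls a Python loop or branch has to be a concrete Python value at trace time. A small sketch of the same pattern follows, using a hypothetical stand-in function (`multilevel_reduce` is not part of jaxwavelets) whose `kind` and `level` arguments play the roles of `wavelet` and `level`:

```python
import jax
import jax.numpy as jnp

def multilevel_reduce(x, kind="mean", level=1):
    """Hypothetical stand-in: repeatedly halve a 1-D array by combining pairs."""
    for _ in range(level):      # Python loop, so level must be concrete at trace time
        pairs = x.reshape(-1, 2)
        if kind == "mean":      # string comparison, resolved while tracing
            x = pairs.mean(axis=1)
        else:
            x = pairs.sum(axis=1)
    return x

# Mark the non-array arguments as static so JAX treats them as compile-time constants.
fast_reduce = jax.jit(multilevel_reduce, static_argnames=["kind", "level"])

x = jnp.arange(16.0)
out = fast_reduce(x, kind="mean", level=2)
```

Each distinct combination of static values triggers a separate compilation, so it is usually worth keeping the set of wavelet, mode, and level values used in a hot loop small.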

Performance

JIT-compiled jaxwavelets on CPU compared with the PyWavelets C implementation. The ratio column is jaxwavelets time divided by pywt time, so values below 1 mean jaxwavelets is faster:

Transform                       pywt        jaxwavelets (JIT)   ratio (jax/pywt)
----------------------------------------------------------------------------------
dwt 1D (N=4096)                 0.011 ms    0.023 ms            2.1x
wavedecn 1D (N=4096)            0.065 ms    0.046 ms            0.7x  ← faster
dwt2 (256x256)                  0.608 ms    0.287 ms            0.5x  ← faster
wavedecn 2D level=3             0.755 ms    0.363 ms            0.5x  ← faster
swt 1D level=3 (N=1024)         0.023 ms    0.025 ms            1.1x
cwt morl 6 scales (N=512)       0.316 ms    0.139 ms            0.4x  ← faster
cwt cmor 6 scales (N=512)       0.615 ms    0.254 ms            0.4x  ← faster

On top of this, jaxwavelets supports jax.grad, jax.vmap, jax.pmap, and GPU acceleration.
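
The README does not say how these timings were taken. If you want to time JIT-compiled code yourself, a common approach is sketched below: run one warm-up call so compilation is excluded, and block on the result so JAX's asynchronous dispatch does not make calls look faster than they are. The helper `time_jitted` and the workload `stand_in` are hypothetical; swap in the transform you care about.

```python
import time

import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=100):
    """Mean wall-clock time per call in milliseconds, excluding compilation."""
    jitted = jax.jit(fn)
    jax.block_until_ready(jitted(*args))  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # Block so the timer measures finished work, not just dispatch.
        jax.block_until_ready(jitted(*args))
    return (time.perf_counter() - start) / repeats * 1e3

def stand_in(x):
    # Hypothetical workload used only to exercise the timer.
    return jnp.cumsum(x) * 0.5

ms = time_jitted(stand_in, jnp.ones(4096))
```

Sub-millisecond timings like these are sensitive to hardware and system load, so compare ratios measured on the same machine instead of absolute numbers.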

Installation

pip install jaxwavelets

jaxwavelets has no runtime dependency on PyWavelets: the filter coefficients are pre-extracted.

Testing

pip install PyWavelets pytest
pytest jaxwavelets/tests/

1189 tests verify numerical agreement with PyWavelets to machine precision.

Composability

Every function operates on a single example. Batching, differentiation, compilation, and distribution compose naturally via JAX transforms:

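from functools import partial

import jax
import jax.numpy as jnp
import jaxwavelets as wt

# batch_of_fields, batch, sharded_data, and loss_fn are placeholders for your own data and loss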

# Batch over examples
jax.vmap(partial(wt.wavedecn, wavelet='db4'))(batch_of_fields)

# Per-example gradients
jax.vmap(jax.grad(loss_fn))(batch)

# Distribute across devices
jax.pmap(partial(wt.wavedecn, wavelet='db4'))(sharded_data)

# Nest arbitrarily
jax.jit(jax.vmap(jax.grad(
    lambda x: jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4'))
)))(batch)
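
As a self-contained illustration of the per-example gradient pattern, the sketch below uses a hypothetical `loss_fn` (the energy of one signal) in place of a wavelet-based loss. Only standard JAX calls are used; nothing here depends on jaxwavelets.

```python
import jax
import jax.numpy as jnp

def loss_fn(x):
    """Hypothetical single-example loss: energy of one 1-D signal."""
    return jnp.sum(x ** 2)

batch = jnp.arange(12.0).reshape(3, 4)  # 3 examples, 4 samples each

# grad is written for one example; vmap lifts it over the leading batch axis.
per_example_grads = jax.vmap(jax.grad(loss_fn))(batch)

# Equivalent here, because the summed loss has no cross-example terms.
grad_of_total = jax.grad(lambda b: jnp.sum(jax.vmap(loss_fn)(b)))(batch)
```

Because `loss_fn` only ever sees a single example, it needs no batch-axis bookkeeping, which is the point of the single-example design.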

Coefficients are JAX pytrees, so jax.tree_util.tree_map works directly on them.
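
For example, soft thresholding of every detail band can be written as a single `tree_map`. The sketch below uses a hand-built stand-in for the coefficients: the layout (an approximation array followed by dicts keyed `'ad'`, `'da'`, `'dd'`) is an assumption that mirrors the PyWavelets `wavedecn` convention and should be checked against the actual jaxwavelets output. The helper `soft_threshold` is also hypothetical and is not the library's `threshold` function.

```python
import jax
import jax.numpy as jnp

def soft_threshold(c, value):
    """Shrink magnitudes toward zero by `value`; zero anything smaller."""
    return jnp.sign(c) * jnp.maximum(jnp.abs(c) - value, 0.0)

# Hand-built stand-in with an ASSUMED layout: [approx, {detail bands}, ...].
coeffs = [
    jnp.array([[4.0, -3.0]]),
    {
        "ad": jnp.array([[0.5, -2.0]]),
        "da": jnp.array([[1.5, 0.2]]),
        "dd": jnp.array([[-0.1, 3.0]]),
    },
]

# Leave the approximation untouched and shrink every detail array.
details = jax.tree_util.tree_map(lambda c: soft_threshold(c, 1.0), coeffs[1:])
shrunk = [coeffs[0]] + details
```

The result has the same tree structure as the input, so under the same layout assumption it could be passed straight back to the matching reconstruction function.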

Design

  • Pure JAX — no numpy, no C extensions
  • Single-example functions — compose with jax.vmap/jax.pmap/jax.grad/jax.jit
  • Pytree coefficients — all outputs are JAX-compatible pytrees
  • Validated against PyWavelets — machine-precision numerical agreement

Acknowledgements

jaxwavelets extends the PyWavelets library to JAX. PyWavelets provides the mathematical reference implementation and filter coefficient database used for validation.
