JAX-native wavelet transforms
jaxwavelets
Extending PyWavelets to JAX. Differentiable, JIT-compilable, GPU-ready wavelet transforms.
Built on the mathematical foundations of PyWavelets and validated against it to machine precision. jaxwavelets brings the full PyWavelets API to JAX, enabling automatic differentiation, GPU acceleration, and composability with jax.vmap, jax.jit, and jax.pmap.
Features
| Transform | Functions |
|---|---|
| Discrete wavelet | dwt, idwt, dwt2, idwt2, dwtn, idwtn |
| Multilevel | wavedec2, waverec2, wavedecn, waverecn |
| Stationary (undecimated) | swt, iswt, swt2, iswt2, swtn, iswtn |
| Continuous | cwt, prepare_cwt, apply_cwt |
| Fully separable | fswavedecn, fswaverecn |
| Multiresolution analysis | mra, imra, mra2, imra2, mran, imran |
| Wavelet packets | wp_decompose, wp_reconstruct, wp_decompose_nd, wp_reconstruct_nd |
| Thresholding | threshold, threshold_firm |
| Utilities | downcoef, upcoef, qmf, orthogonal_filter_bank |
Wavelets: haar, db1-20, sym2-20, coif1-5, plus continuous wavelets (Morlet, Mexican hat, Gaussian 1-8, complex Gaussian 1-8, complex Morlet, Shannon, frequency B-spline).
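A single-level discrete wavelet transform splits a signal into approximation (low-pass) and detail (high-pass) coefficients, and the inverse recombines them. Below is a minimal JAX sketch for the Haar wavelet only. It assumes an even-length 1-D input so no boundary extension is needed. `haar_dwt` and `haar_idwt` are hypothetical helpers, not the jaxwavelets `dwt`/`idwt` implementation.

```python
import jax.numpy as jnp

# Orthonormal Haar analysis: pairwise sums/differences scaled by 1/sqrt(2).
# Assumes an even-length 1-D input, so no boundary extension is needed.
def haar_dwt(x):
    even, odd = x[0::2], x[1::2]
    cA = (even + odd) / jnp.sqrt(2.0)  # approximation (low-pass)
    cD = (even - odd) / jnp.sqrt(2.0)  # detail (high-pass)
    return cA, cD

# Inverse: undo the sums/differences, then interleave even and odd samples
def haar_idwt(cA, cD):
    even = (cA + cD) / jnp.sqrt(2.0)
    odd = (cA - cD) / jnp.sqrt(2.0)
    return jnp.stack([even, odd], axis=-1).reshape(-1)

x = jnp.arange(8.0)
cA, cD = haar_dwt(x)
x_rec = haar_idwt(cA, cD)  # recovers x up to float32 rounding
```

Because the Haar step is orthonormal, the inverse is just the transpose of the forward step. Longer filters such as `db4` follow the same pattern but need a signal-extension `mode` at the boundaries.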
Usage
```python
from functools import partial

import jax
import jax.numpy as jnp
import jaxwavelets as wt

# Decompose and reconstruct
x = jnp.ones((64, 64))
coeffs = wt.wavedecn(x, 'db4', level=3)
rec = wt.waverecn(coeffs, 'db4')

# Batch via vmap
batch = jnp.ones((10, 64, 64))
batch_coeffs = jax.vmap(partial(wt.wavedecn, wavelet='db4', level=3))(batch)

# Differentiate through the transform
grad = jax.grad(lambda x: jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4')))(x)

# JIT-compile for speed
fast = jax.jit(wt.wavedecn, static_argnames=['wavelet', 'mode', 'level'])
coeffs = fast(x, wavelet='db4', level=3)
```
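The JIT example marks `wavelet`, `mode` and `level` as static. Here is a minimal sketch of why `level` in particular cannot be traced: when the number of decomposition levels drives a Python loop, it must be a concrete integer at trace time. `haar_wavedec` is a hypothetical Haar-only helper, not jaxwavelets' implementation.

```python
import jax
import jax.numpy as jnp

def haar_step(x):
    # One orthonormal Haar analysis step on an even-length 1-D array
    s = jnp.sqrt(2.0)
    return (x[0::2] + x[1::2]) / s, (x[0::2] - x[1::2]) / s

def haar_wavedec(x, level):
    # `level` drives a Python loop, so it must be a concrete int at trace time
    details = []
    for _ in range(level):
        x, d = haar_step(x)
        details.append(d)
    # PyWavelets-style ordering: [cA_n, cD_n, ..., cD_1]
    return [x] + details[::-1]

fast = jax.jit(haar_wavedec, static_argnames=['level'])
coeffs = fast(jnp.ones(64), level=3)  # lengths 8, 8, 16, 32
```

Without `static_argnames`, the call fails at trace time because `range()` cannot consume a traced value. Each distinct static value also triggers its own compilation, so varying `level` often costs recompiles.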
Performance
JIT-compiled jaxwavelets on CPU vs PyWavelets (C implementation). The ratio is jaxwavelets time divided by PyWavelets time, so values below 1 mean jaxwavelets is faster:

| Transform | pywt | jaxwavelets (JIT) | ratio |
|---|---|---|---|
| dwt 1D (N=4096) | 0.011 ms | 0.023 ms | 2.1x |
| wavedecn 1D (N=4096) | 0.065 ms | 0.046 ms | 0.7x (faster) |
| dwt2 (256x256) | 0.608 ms | 0.287 ms | 0.5x (faster) |
| wavedecn 2D, level=3 | 0.755 ms | 0.363 ms | 0.5x (faster) |
| swt 1D, level=3 (N=1024) | 0.023 ms | 0.025 ms | 1.1x |
| cwt morl, 6 scales (N=512) | 0.316 ms | 0.139 ms | 0.4x (faster) |
| cwt cmor, 6 scales (N=512) | 0.615 ms | 0.254 ms | 0.4x (faster) |
On top of this, jaxwavelets supports jax.grad, jax.vmap, jax.pmap, and GPU acceleration.
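The table does not state how it was timed. JAX compiles on the first call and dispatches work asynchronously, so a fair comparison presumably excludes compilation and waits for results before stopping the clock. Below is a sketch of one such harness, using a stand-in Haar transform and a hypothetical `time_jitted` helper. It is not the script behind the numbers above, and timings will vary by machine.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def haar_dwt(x):
    # Stand-in transform for timing; not jaxwavelets' implementation
    s = jnp.sqrt(2.0)
    return (x[0::2] + x[1::2]) / s, (x[0::2] - x[1::2]) / s

def time_jitted(fn, *args, number=100):
    # Warm-up call triggers compilation so it is excluded from the timing
    jax.block_until_ready(fn(*args))
    # Dispatch is asynchronous: block so each call's work is actually measured
    total = timeit.timeit(lambda: jax.block_until_ready(fn(*args)), number=number)
    return total / number * 1e3  # mean milliseconds per call

x = jnp.ones(4096)
ms = time_jitted(haar_dwt, x)
print(f"haar dwt (N=4096): {ms:.3f} ms")
```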
Installation
```bash
pip install jaxwavelets
```
No runtime dependency on PyWavelets. Filter coefficients are pre-extracted.
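With pre-extracted coefficients, each filter bank is a small table of constants. Below is a sketch for `db2`, assuming the PyWavelets convention for its reconstruction low-pass filter. The remaining three filters are derived with the quadrature-mirror relation behind the `qmf` utility. The constants and the `qmf` helper here are illustrative, not read from jaxwavelets.

```python
import jax.numpy as jnp

# Daubechies-2 reconstruction low-pass filter, hard-coded as constants.
# Closed form: (1+sqrt3, 3+sqrt3, 3-sqrt3, 1-sqrt3) / (4*sqrt2)
DB2_REC_LO = jnp.array([0.48296291314469025, 0.836516303737469,
                        0.22414386804185735, -0.12940952255092145])

def qmf(h):
    # Quadrature mirror filter: reverse, then negate every other tap
    g = h[::-1]
    return g.at[1::2].set(-g[1::2])

rec_lo = DB2_REC_LO
rec_hi = qmf(rec_lo)
dec_lo = rec_lo[::-1]  # analysis filters are time-reversed synthesis filters
dec_hi = rec_hi[::-1]

# Orthogonality under a shift of two samples (should be ~0)
shift_orth = rec_lo[0] * rec_lo[2] + rec_lo[1] * rec_lo[3]
```

The low-pass taps sum to sqrt(2) and have unit energy. The high-pass taps sum to zero and are orthogonal to the low-pass taps, which is what makes perfect reconstruction possible.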
Testing
```bash
pip install PyWavelets pytest
pytest jaxwavelets/tests/
```
1189 tests verify numerical agreement with PyWavelets to machine precision.
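PyWavelets computes in float64, while JAX defaults to float32. Checks at machine precision therefore presumably need JAX's 64-bit mode. Below is a minimal sketch of such a check. It uses a self-contained Haar round trip (`haar_roundtrip` is a hypothetical helper, not a jaxwavelets function) rather than a comparison against PyWavelets.

```python
import jax
import jax.numpy as jnp

# JAX computes in float32 by default; enable float64 before creating arrays
jax.config.update("jax_enable_x64", True)

def haar_roundtrip(x):
    # Forward then inverse orthonormal Haar step on an even-length 1-D array
    s = jnp.sqrt(2.0)
    cA, cD = (x[0::2] + x[1::2]) / s, (x[0::2] - x[1::2]) / s
    return jnp.stack([(cA + cD) / s, (cA - cD) / s], axis=-1).reshape(-1)

x = jax.random.normal(jax.random.PRNGKey(0), (1024,), dtype=jnp.float64)
rec = haar_roundtrip(x)
max_err = float(jnp.max(jnp.abs(rec - x)))
print(f"max round-trip error: {max_err:.1e}")
```

In float32 the same round trip typically leaves errors around 1e-7. In float64 it shrinks to the order of 1e-16.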
Composability
Every function operates on a single example. Batching, differentiation, compilation, and distribution compose naturally via JAX transforms:
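```python
from functools import partial

import jax
import jax.numpy as jnp
import jaxwavelets as wt

batch_of_fields = batch = jnp.ones((10, 64, 64))
# pmap maps over the leading axis, which must not exceed the device count
sharded_data = jnp.ones((jax.local_device_count(), 64, 64))
# Example scalar loss through a decompose/reconstruct round trip
loss_fn = lambda x: jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4'))

# Batch over examples
jax.vmap(partial(wt.wavedecn, wavelet='db4'))(batch_of_fields)
# Per-example gradients
jax.vmap(jax.grad(loss_fn))(batch)
# Distribute across devices
jax.pmap(partial(wt.wavedecn, wavelet='db4'))(sharded_data)
# Nest arbitrarily
jax.jit(jax.vmap(jax.grad(
    lambda x: jnp.sum(wt.waverecn(wt.wavedecn(x, 'db4'), 'db4'))
)))(batch)
```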
Coefficients are JAX pytrees, so jax.tree_util.tree_map works directly on them.
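Below is a sketch of thresholding detail coefficients with `tree_map`. It assumes the PyWavelets-style `wavedecn` layout (`[cA, {'ad': ..., 'da': ..., 'dd': ...}, ...]`) and uses hand-made arrays. `soft` is a hypothetical hand-written soft threshold, used in place of the library's `threshold` function to keep the example self-contained.

```python
import jax
import jax.numpy as jnp

# Hypothetical coefficients laid out like wavedecn output for a 2-D input
coeffs = [
    jnp.full((4, 4), 3.0),
    {'ad': jnp.full((4, 4), 0.5), 'da': jnp.full((4, 4), -2.0), 'dd': jnp.zeros((4, 4))},
]

def soft(c, t):
    # Soft thresholding: shrink magnitudes by t, zeroing anything below t
    return jnp.sign(c) * jnp.maximum(jnp.abs(c) - t, 0.0)

# Threshold detail coefficients only; keep the approximation cA untouched
denoised = [coeffs[0]] + jax.tree_util.tree_map(lambda c: soft(c, 1.0), coeffs[1:])
n_leaves = len(jax.tree_util.tree_leaves(coeffs))  # 1 array + 3 dict entries
```

Because the nested list/dict structure is preserved, the result can be passed straight back to a reconstruction function that expects the same layout.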
Design
- Pure JAX: no NumPy, no C extensions
- Single-example functions: compose with jax.vmap, jax.pmap, jax.grad, and jax.jit
- Pytree coefficients: all outputs are JAX-compatible pytrees
- Validated against PyWavelets: machine-precision numerical agreement
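In PyWavelets, the n-dimensional `dwtn` is separable. A 1-D step is applied along each axis in turn, and each output subband is labeled by the path it took. Below is a pure-JAX sketch for the Haar wavelet. It assumes even lengths along every axis and uses PyWavelets-style string keys (`'a'` for approximation, `'d'` for detail, one letter per axis). `haar_axis` and `haar_dwtn` are hypothetical helpers.

```python
import jax.numpy as jnp

def haar_axis(x, axis):
    # One orthonormal Haar step along `axis` (length along it must be even)
    x = jnp.moveaxis(x, axis, 0)
    a = (x[0::2] + x[1::2]) / jnp.sqrt(2.0)
    d = (x[0::2] - x[1::2]) / jnp.sqrt(2.0)
    return jnp.moveaxis(a, 0, axis), jnp.moveaxis(d, 0, axis)

def haar_dwtn(x):
    # Apply the 1-D step along each axis in turn; key letters record the
    # path taken ('a' = low-pass, 'd' = high-pass), one letter per axis
    coeffs = {'': x}
    for axis in range(x.ndim):
        new = {}
        for key, c in coeffs.items():
            a, d = haar_axis(c, axis)
            new[key + 'a'] = a
            new[key + 'd'] = d
        coeffs = new
    return coeffs

x = jnp.arange(16.0).reshape(4, 4)
c = haar_dwtn(x)  # dict with keys 'aa', 'ad', 'da', 'dd', each 2x2
```

For a 2-D input this produces four subbands. Because each Haar step is orthonormal, the total energy of the subbands equals that of the input. The output is a plain dict, which JAX already treats as a pytree.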
Acknowledgements
jaxwavelets extends the PyWavelets library to JAX. PyWavelets provides the mathematical reference implementation and filter coefficient database used for validation.
Download files
File details
Details for the file jaxwavelets-0.1.10.tar.gz.
File metadata
- Download URL: jaxwavelets-0.1.10.tar.gz
- Upload date:
- Size: 42.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e547405ebc1e14cfc591aa265552a75f3e3db016c89d094bcabc7b3f08b7a6f0 |
| MD5 | 4b999c0714f8c216bfea3edd32e52646 |
| BLAKE2b-256 | d96d547fe8474d6ce4ddcebe4a0d195f0430e9a1bc681bd44ba937ab224d2739 |
Provenance
The following attestation bundles were made for jaxwavelets-0.1.10.tar.gz:
- Publisher: ci.yml on handley-lab/jaxwavelets
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: jaxwavelets-0.1.10.tar.gz
  - Subject digest: e547405ebc1e14cfc591aa265552a75f3e3db016c89d094bcabc7b3f08b7a6f0
- Sigstore transparency entry: 1316981379
- Sigstore integration time:
- Permalink: handley-lab/jaxwavelets@c71996a7dc611a82e1549e21220cb7d53a844267
- Branch / Tag: refs/tags/v0.1.10
- Owner: https://github.com/handley-lab
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci.yml@c71996a7dc611a82e1549e21220cb7d53a844267
- Trigger Event: push
File details
Details for the file jaxwavelets-0.1.10-py3-none-any.whl.
File metadata
- Download URL: jaxwavelets-0.1.10-py3-none-any.whl
- Upload date:
- Size: 48.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c278427781bf50b2a95ca2de29385d5bed788a185251eb86411973496ea26182 |
| MD5 | 012de75c985b40e62933ce17ecd9d811 |
| BLAKE2b-256 | c20a3d53589b4fcfe3087777cc05da50e419efddac49509b077182638e65e3e4 |
Provenance
The following attestation bundles were made for jaxwavelets-0.1.10-py3-none-any.whl:
- Publisher: ci.yml on handley-lab/jaxwavelets
- Statement:
  - Statement type: https://in-toto.io/Statement/v1
  - Predicate type: https://docs.pypi.org/attestations/publish/v1
  - Subject name: jaxwavelets-0.1.10-py3-none-any.whl
  - Subject digest: c278427781bf50b2a95ca2de29385d5bed788a185251eb86411973496ea26182
- Sigstore transparency entry: 1316981387
- Sigstore integration time:
- Permalink: handley-lab/jaxwavelets@c71996a7dc611a82e1549e21220cb7d53a844267
- Branch / Tag: refs/tags/v0.1.10
- Owner: https://github.com/handley-lab
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci.yml@c71996a7dc611a82e1549e21220cb7d53a844267
- Trigger Event: push