Adaptive Wavelet Transform
WavSpA: Wavelet Space Attention for Enhancing Transformer's Long Sequence Learning
Welcome to the official repository of the WavSpA paper. WavSpA couples adaptive wavelet transforms with Transformer attention to improve learning on long sequences. The implementation is written in JAX with Flax.
Installation
Set up your environment with the provided requirements.txt:
```
$ pip install -r requirements.txt
$ pip install wavspa
```
Note: This codebase supports JAX version 0.3.13.
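Later JAX releases introduced breaking API changes, so if pip resolves a different version it may help to pin JAX explicitly (the matching jaxlib wheel depends on your platform; see the JAX installation docs):

```
$ pip install "jax==0.3.13"
```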
Our experiments
We demonstrate substantial gains in performance with WavSpA across various attention-based architectures. The framework offers three parametrization options:
- Adaptive Wavelet (AdaWavSpA)
- Orthogonal Adaptive Wavelet (OrthoWavSpA)
- Wavelet Lifting (LiftWavSpA)
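These options map onto string checks in the model's setup() (see Example Usage below). The sketch here is an inference from those checks, not a documented config interface, so treat the field name and values as assumptions:

```python
# Hypothetical sketch: the parametrization appears to be selected by the
# `wavelet` string that setup() checks for substrings.
config.wavelet = "db"     # "db"    -> AdaWavSpA (trainable Daubechies-initialized filters)
                          # "ortho" -> OrthoWavSpA (filters built from rotation angles)
                          # "lift"  -> LiftWavSpA (learned lifting predict/estimate filters)
```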
Accuracy (%) on the Long Range Arena (LRA) tasks is shown below; each `+` row applies the corresponding WavSpA parametrization to the base model above it:
Models | ListOps | Text | Retrieval | Image | Pathfinder | Avg | Avg (w/o Retrieval)
---|---|---|---|---|---|---|---
Transformer | 36.37 | 64.27 | 57.46 | 42.44 | 71.40 | 54.39 | 53.62
+ AdaWavSpA | 55.40 | 81.60 | 79.27 | 55.58 | 81.12 | 70.59 | 68.43
+ OrthoWavSpA | 45.95 | 81.63 | 71.52 | 49.29 | 81.13 | 65.90 | 64.50
+ LiftWavSpA | 42.95 | 75.63 | 56.45 | 42.48 | 81.73 | 59.85 | 60.70
Longformer | 35.63 | 62.85 | 56.89 | 42.22 | 69.71 | 53.46 | 52.60
+ AdaWavSpA | 49.30 | 79.73 | 58.57 | 50.84 | 79.48 | 63.66 | 64.93
+ OrthoWavSpA | 39.45 | 78.41 | 79.93 | 49.93 | 79.47 | 54.96 | 54.96
+ LiftWavSpA | 39.40 | 78.00 | 53.27 | 40.95 | 75.80 | 57.48 | 58.54
Linformer | 35.70 | 53.94 | 52.27 | 38.47 | 66.44 | 49.36 | 48.64
+ AdaWavSpA | 37.15 | 54.75 | 61.09 | 34.93 | 65.66 | 50.72 | 48.12
+ OrthoWavSpA | 38.05 | 56.93 | 60.25 | 39.45 | 65.35 | 52.01 | 49.95
+ LiftWavSpA | 37.30 | 54.43 | 70.73 | 34.66 | 63.49 | 52.12 | 47.47
Linear Att. | 16.13 | 65.90 | 53.09 | 42.32 | 75.91 | 50.67 | 50.06
+ AdaWavSpA | 38.90 | 76.82 | 71.38 | 54.81 | 79.68 | 64.32 | 62.55
+ OrthoWavSpA | 39.55 | 79.45 | 69.65 | 49.93 | 78.09 | 55.86 | 55.86
+ LiftWavSpA | 38.35 | 73.39 | 54.06 | 44.39 | 74.46 | 56.93 | 57.65
Performer | 18.01 | 65.40 | 53.82 | 42.77 | 77.05 | 51.41 | 50.81
+ AdaWavSpA | 46.05 | 80.93 | 71.16 | 52.06 | 77.17 | 65.47 | 64.05
+ OrthoWavSpA | 39.80 | 79.10 | 57.67 | 48.78 | 78.09 | 60.69 | 61.44
+ LiftWavSpA | 39.85 | 75.96 | 52.75 | 39.97 | 76.20 | 56.95 | 58.00
Example Usage
For implementation details, see lra_benchmarks/models/wavspa/wavspa_learn.py. Two pieces are central: how the wavelet filters are initialized and how the transform wraps attention. Filter initialization happens in setup():
```python
# From lra_benchmarks/models/wavspa/wavspa_learn.py. Assumes
# `import jax.numpy as jnp`, `from flax import linen as nn`, and
# `from jax.experimental import sparse`; db_init and sin_init are
# filter initializers defined in the same module.
def setup(self):
    assert self.wlen % 2 == 0, "wavelet length must be even"
    self.eps = 1e-4
    if "lift" in self.wavelet:
        # lifting scheme: learn separate estimate and predict filters
        self.adawave_est = self.param('adawave_est', nn.initializers.normal(stddev=0.02),
                                      (self.wlen, self.wav_dim), self.dtype)
        self.adawave_pred = self.param('adawave_pred', nn.initializers.normal(stddev=0.02),
                                       (self.wlen, self.wav_dim), self.dtype)
    elif "ortho" in self.wavelet:
        # orthogonal parametrization: learn rotation angles; S is the cyclic
        # shift matrix used when building the filter bank from the angles
        L = self.wlen // 2
        S = jnp.zeros(shape=[2 * L, 2 * L], dtype=int)
        i = jnp.asarray(range(2 * L))
        j = jnp.asarray(range(1, 2 * L + 1)) % (2 * L)
        S = S.at[i, j].set(1)
        self.S = sparse.BCOO.fromdense(S, nse=2 * L)
        self.S_inv = jnp.linalg.inv(S.astype(self.dtype))  # cast so the inverse is float
        self.thetas = self.param('thetas', nn.initializers.uniform(2 * jnp.pi),
                                 (L, self.wav_dim), self.dtype)
    elif "db" in self.wavelet:
        # trainable filters initialized from Daubechies coefficients
        self.adawave = self.param('adawave', db_init, (self.wlen, self.wav_dim), self.dtype)
    elif "sin" in self.wavelet:
        self.adawave = self.param('adawave', sin_init, (self.wlen, self.wav_dim), self.dtype)
    else:
        # default: fixed (non-trainable) Daubechies wavelet
        self.adawave = db_init(key=None, shape=(self.wlen, self.wav_dim), dtype=self.dtype)
```
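db_init and sin_init live in the same module and are not shown above. As a rough illustration only (the repository's actual initializer may differ), a Daubechies-style initializer matching Flax's `(key, shape, dtype)` convention could look like this; `db_init_sketch` is a hypothetical name:

```python
import jax.numpy as jnp

# Hypothetical sketch of a Daubechies-style filter initializer; the
# repository's db_init may differ in both coefficients and layout.
def db_init_sketch(key, shape, dtype=jnp.float32):
    wlen, wav_dim = shape
    # Daubechies-2 (db2) low-pass coefficients
    root3 = jnp.sqrt(3.0)
    h = jnp.array([1 + root3, 3 + root3, 3 - root3, 1 - root3]) / (4 * jnp.sqrt(2.0))
    # Pad (or truncate) to the requested filter length, then tile across channels
    h = jnp.pad(h, (0, max(wlen - h.shape[0], 0)))[:wlen]
    return jnp.tile(h[:, None], (1, wav_dim)).astype(dtype)
```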
The forward and inverse wavelet transforms then bracket the attention computation: attention runs independently on each decomposition level, and the reconstruction is trimmed back to the input length:
```python
# Forward transform: decompose the sequence into per-level coefficients.
z = wavspa.wavedec_learn(x, wavelet, level=self.level)
# Apply self-attention to the coefficients at each decomposition level.
for level in range(len(z)):
    z[level] = nn.SelfAttention(num_heads=self.num_heads,
                                dtype=self.dtype,
                                qkv_features=self.qkv_dim,
                                kernel_init=nn.initializers.xavier_uniform(),
                                bias_init=nn.initializers.normal(stddev=1e-6),
                                use_bias=False,
                                broadcast_dropout=False,
                                dropout_rate=self.attention_dropout_rate,
                                decode=False)(z[level], deterministic=deterministic)
# Inverse transform, trimmed back to the original sequence length.
z = wavspa.waverec_learn(z, wavelet)[:, :inputs.shape[1], :]
```
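For orientation, here is a minimal round-trip sketch of the two calls above. It assumes wavspa exposes wavedec_learn/waverec_learn with the signatures shown, that the filter argument is an array of shape (wlen, wav_dim) with wav_dim equal to the model dimension (matching setup()), and that decomposition may pad the sequence, hence the trim:

```python
import jax
import jax.numpy as jnp
import wavspa

# Assumed shapes: inputs (batch, seq_len, model_dim), filter (wlen, wav_dim).
x = jax.random.normal(jax.random.PRNGKey(0), (2, 1000, 64))
wavelet = jax.random.normal(jax.random.PRNGKey(1), (8, 64))

coeffs = wavspa.wavedec_learn(x, wavelet, level=4)   # list of per-level coefficients
print([c.shape for c in coeffs])
recon = wavspa.waverec_learn(coeffs, wavelet)[:, :x.shape[1], :]
assert recon.shape == x.shape
```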
Datasets
The experiments use the LRA, D2A, and CodeXGLUE datasets. To run a task, invoke the corresponding train_best.py script with the appropriate configuration:
```
PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/listops/train_best.py \
    --config=lra_benchmarks/listops/configs/wavspa-exp0.py \
    --model_dir=/tmp/listops \
    --task_name=basic \
    --data_dir=$HOME/lra_data/listops/
```
Citation
If you find our work useful, please cite our paper:
```bibtex
@inproceedings{zhuang2023wavspa,
  title={WavSpA: Wavelet Space Attention for Boosting Transformers' Long Sequence Learning Ability},
  author={Yufan Zhuang and Zihan Wang and Fangbo Tao and Jingbo Shang},
  booktitle={UniReps: the First Workshop on Unifying Representations in Neural Models at NeurIPS 2023},
  year={2023},
  url={https://openreview.net/forum?id=yC6b3hqyf8}
}
```