
Adaptive wavelets

Wavelets that adapt to the given data (and optionally to a pre-trained model). This yields models that are faster, more compressible, and more interpretable.

Quickstart

Installation: pip install git+https://github.com/Yu-Group/adaptive-wavelets.git, or clone the repo and run python setup.py install from the repo directory.

You can then use the core functions (see the simplest examples in notebooks/demo_simple_2d.ipynb and notebooks/demo_simple_1d.ipynb). See the docs for more information on the arguments to these functions.

Given some data X, you can run the following:

from awave.utils.misc import get_wavefun
from awave.transform2d import DWT2d

wt = DWT2d(wave='db5', J=4)
wt.fit(X=X, lr=1e-1, num_epochs=10)  # this function alternatively accepts a dataloader
X_sparse = wt(X)  # uses the learned adaptive wavelet
phi, psi, x = get_wavefun(wt)  # can also inspect the learned adaptive wavelet
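
As noted in the snippet above, fit can also consume a dataloader rather than a raw tensor. A minimal sketch, assuming X is a torch.Tensor of images and that the loader is passed via a train_loader keyword (the keyword name is our assumption; check the docs if it differs):

import torch
from torch.utils.data import DataLoader, TensorDataset

from awave.transform2d import DWT2d

# Wrap the data in a standard torch DataLoader; batch_size here is arbitrary.
loader = DataLoader(TensorDataset(X), batch_size=64, shuffle=True)

wt = DWT2d(wave='db5', J=4)
wt.fit(train_loader=loader, lr=1e-1, num_epochs=10)  # keyword name assumed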

To distill a pretrained model named model, simply pass it as an additional argument to the fit function:

wt.fit(X=X, pretrained_model=model,
       lr=1e-1, num_epochs=10,
       lamL1attr=5) # control how much to regularize the model's attributions
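
Here model can be any trained torch module that operates on the same inputs X. For concreteness, a hypothetical stand-in (the architecture below is arbitrary and for illustration only; in practice you would load a network already trained on your task):

import torch.nn as nn

# Hypothetical pretrained model, purely for illustration.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
)
model.eval()  # the model stays fixed; only the wavelet parameters are trained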

Background

Official code for using / reproducing AWD from the paper "Adaptive wavelet distillation from neural networks through interpretations" (Ha et al., 2021).

Recent deep-learning models have achieved impressive prediction performance, but often sacrifice interpretability and computational efficiency. Interpretability is crucial in many disciplines, such as science and medicine, where models must be carefully vetted or where interpretation is the goal itself. Moreover, interpretable models are concise and often yield computational efficiency. Here, we propose adaptive wavelet distillation (AWD), a method which aims to distill information from a trained neural network into a wavelet transform. Specifically, AWD penalizes feature attributions of a neural network in the wavelet domain to learn an effective multi-resolution wavelet transform. The resulting model is highly predictive, concise, computationally efficient, and has properties (such as a multi-scale structure) which make it easy to interpret. In close collaboration with domain experts, we showcase how AWD addresses challenges in two real-world settings: cosmological parameter inference and molecular-partner prediction. In both cases, AWD yields a scientifically interpretable and concise model which gives predictive performance better than state-of-the-art neural networks. Moreover, AWD identifies predictive features that are scientifically meaningful in the context of respective domains.
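
To make this concrete, the objective can be sketched as reconstruction fidelity plus coefficient sparsity plus an L1 penalty on attributions in the wavelet domain. The following is a conceptual sketch only, not the package's actual training code: the transform / inverse callables and the saliency-style attribution are assumptions standing in for the interpretation scores used in the paper.

import torch

def awd_loss_sketch(transform, inverse, model, x, lam_sparse=0.1, lam_attr=5.0):
    # Conceptual sketch of the AWD objective, not awave's internals.
    # transform / inverse: differentiable wavelet analysis / synthesis callables.
    coeffs = transform(x)                          # coefficients in the wavelet domain
    recon = ((inverse(coeffs) - x) ** 2).mean()    # reconstruction fidelity
    sparsity = coeffs.abs().mean()                 # keep the representation concise

    # Attribution penalty: saliency-style scores of the pretrained model's
    # output with respect to the wavelet coefficients; penalizing their L1
    # norm concentrates the model's signal on a few coefficients.
    out = model(inverse(coeffs)).sum()
    attr, = torch.autograd.grad(out, coeffs, create_graph=True)
    attr_penalty = (attr * coeffs).abs().mean()

    return recon + lam_sparse * sparsity + lam_attr * attr_penalty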

This package also provides an implementation of "Learning Sparse Wavelet Representations" (Recoskie & Mann, 2018).

Related work

  • TRIM (ICLR 2020 workshop pdf, github) - uses simple reparameterizations to calculate disentangled importances for transformations of the input (e.g., assigning importances to different frequencies)
  • ACD (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • CDEP (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

If this package is useful for you, please cite the following!

@article{ha2021adaptive,
  title={Adaptive wavelet distillation from neural networks through interpretations},
  author={Ha, Wooseok and Singh, Chandan and Lanusse, Francois and Song, Eli and Dang, Song and He, Kangmin and Upadhyayula, Srigokul and Yu, Bin},
  journal={arXiv preprint arXiv:2107.09145},
  year={2021}
}


