WaveMix - Pytorch

Project description

WaveMix

Resource-efficient Token Mixing for Images using 2D Discrete Wavelet Transform

WaveMix Architecture

WaveMix-Lite

We propose WaveMix, a novel neural architecture for computer vision that is resource-efficient yet generalizable and scalable. WaveMix networks achieve comparable or better accuracy than state-of-the-art convolutional neural networks, vision transformers, and token mixers on several tasks, establishing new benchmarks for segmentation on Cityscapes and for classification on Places-365, five EMNIST datasets, and iNAT-mini. Remarkably, WaveMix architectures require fewer parameters than the previous state of the art to reach these benchmarks. Moreover, when controlled for the number of parameters, WaveMix requires less GPU RAM, which translates into savings in time, cost, and energy. To achieve these gains we use a multi-level two-dimensional discrete wavelet transform (2D-DWT) in the WaveMix blocks, which has the following advantages: (1) it reorganizes spatial information based on three strong image priors (scale-invariance, shift-invariance, and sparseness of edges), (2) it does so losslessly and without adding parameters, (3) it reduces the spatial size of feature maps, which cuts the memory and time required for forward and backward passes, and (4) it expands the receptive field faster than convolutions do. The whole architecture is a stack of self-similar, resolution-preserving WaveMix blocks, which allows architectural flexibility for various tasks and levels of resource availability.
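The lossless, parameter-free spatial-size reduction of the 2D-DWT can be sketched with a single-level Haar transform in NumPy (an illustrative stand-in for the wavelet implementation the package actually uses):

```python
import numpy as np

def haar_dwt2(x):
    """Single-level 2D Haar DWT: splits an image into four half-resolution subbands."""
    a = x[0::2, 0::2]; b = x[0::2, 1::2]
    c = x[1::2, 0::2]; d = x[1::2, 1::2]
    ll = (a + b + c + d) / 2   # approximation (low-low)
    lh = (a - b + c - d) / 2   # horizontal detail
    hl = (a + b - c - d) / 2   # vertical detail
    hh = (a - b - c + d) / 2   # diagonal detail
    return ll, lh, hl, hh

def haar_idwt2(ll, lh, hl, hh):
    """Inverse transform: recovers the input exactly (the transform is lossless)."""
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w))
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x

img = np.random.rand(256, 256)
ll, lh, hl, hh = haar_dwt2(img)
print(ll.shape)                                      # (128, 128): spatial size halved
print(np.allclose(haar_idwt2(ll, lh, hl, hh), img))  # True: no information lost
```

Note that the transform has zero learnable parameters, and each level halves the feature-map resolution, which is where the memory and receptive-field advantages come from.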

Task | Dataset | Metric | Value
Semantic Segmentation | Cityscapes | Single-scale mIoU | 82.70% (SOTA)
Image Classification | ImageNet-1k | Accuracy | 74.93%

Parameter Efficiency

Task | Model | Parameters
99% accuracy on MNIST | WaveMix Lite-8/10 | 3,566
90% accuracy on Fashion MNIST | WaveMix Lite-8/5 | 7,156
80% accuracy on CIFAR-10 | WaveMix Lite-32/7 | 37,058
90% accuracy on CIFAR-10 | WaveMix Lite-64/6 | 520,106

The high parameter efficiency is obtained by replacing deconvolution (transposed-convolution) layers with parameter-free upsampling.
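The saving can be sketched with back-of-the-envelope parameter counts (the 144-channel width and kernel sizes here are illustrative, not the exact layer shapes used in the package):

```python
def conv_params(c_in, c_out, k):
    """Parameter count of a (transposed) convolution: k*k weights per channel pair, plus bias."""
    return c_in * c_out * k * k + c_out

channels = 144  # illustrative width

# Learned 2x upsampling via a 4x4 stride-2 transposed convolution:
deconv = conv_params(channels, channels, 4)

# Parameter-free 2x upsampling (e.g. a fixed resize) followed by a
# 3x3 convolution to mix the upsampled features:
upsample_conv = 0 + conv_params(channels, channels, 3)

print(deconv)         # 331920
print(upsample_conv)  # 186768
```

The upsampling itself contributes no parameters at all; any parameters in this path come only from the ordinary convolutions around it.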

This is an implementation of the code from the following papers: Openreview Paper, ArXiv Paper 1, ArXiv Paper 2

Install

$ pip install wavemix

Usage

Semantic Segmentation

import torch
from wavemix.SemSegment import WaveMix

model = WaveMix(
    num_classes=20,
    depth=16,
    mult=2,
    ff_channel=256,
    final_dim=256,
    dropout=0.5,
    level=4,
    stride=2
)

img = torch.randn(1, 3, 256, 256)

preds = model(img) # (1, 20, 256, 256)

Image Classification

import torch
from wavemix.classification import WaveMix

model = WaveMix(
    num_classes=1000,
    depth=16,
    mult=2,
    ff_channel=192,
    final_dim=192,
    dropout=0.5,
    level=3,
    patch_size=4,
)
img = torch.randn(1, 3, 256, 256)

preds = model(img) # (1, 1000)

Single Image Super-resolution

import torch
from wavemix.sisr import WaveMix

model = WaveMix(
    depth=4,
    mult=2,
    ff_channel=144,
    final_dim=144,
    dropout=0.5,
    level=1,
)

img = torch.randn(1, 3, 256, 256)
out = model(img) # (1, 3, 512, 512)

To use a single WaveMix block

import torch
from wavemix import Level1Waveblock

Parameters

  • num_classes: int.
    Number of classes to classify/segment.
  • depth: int.
    Number of WaveMix blocks.
  • mult: int.
    Channel-expansion factor in the MLP (FeedForward) layer.
  • ff_channel: int.
    Number of output channels from the MLP (FeedForward) layer.
  • final_dim: int.
    Final dimension of the output tensor after the initial conv layers; this is the channel dimension of the tensor fed to the WaveMix blocks.
  • dropout: float between 0 and 1, default 0.
    Dropout rate.
  • level: int.
    Number of levels of the 2D wavelet transform used in the WaveMix blocks. Currently supports levels 1 to 4.
  • stride: int.
    Stride of the initial convolutional layers, which reduce the input resolution before it is fed to the WaveMix blocks.
  • initial_conv: str.
    Chooses between a strided convolution ('strided') and a patchifying convolution ('pachify') in the initial conv layer. Used for classification.
  • patch_size: int.
    Size of each non-overlapping patch when using the patchifying convolution. Should be a multiple of 4.

Cite the following papers

@misc{p2022wavemix,
    title={WaveMix: Multi-Resolution Token Mixing for Images},
    author={Pranav Jeevan P and Amit Sethi},
    year={2022},
    url={https://openreview.net/forum?id=tBoSm4hUWV}
}

@misc{jeevan2022wavemix,
    title={WaveMix: Resource-efficient Token Mixing for Images},
    author={Pranav Jeevan and Amit Sethi},
    year={2022},
    eprint={2203.03689},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

@misc{jeevan2023wavemix,
      title={WaveMix: A Resource-efficient Neural Network for Image Analysis}, 
      author={Pranav Jeevan and Kavitha Viswanathan and Anandu A S and Amit Sethi},
      year={2023},
      eprint={2205.14375},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
