
Vision Xformers


ViX


Vision Xformers: Efficient Attention for Image Classification


We use linear attention mechanisms to replace the quadratic attention in ViT for image classification. We show that models using linear attention and CNN embedding layers need fewer parameters and less GPU memory to achieve good accuracy. These improvements can help democratize the use of transformers among practitioners who are limited by data and compute.
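For intuition, here is a minimal sketch of how linear attention avoids the quadratic cost. It uses the ELU+1 feature map of Katharopoulos et al. purely for illustration; the approximations used in this package (Performer, Nyströmformer, Linformer) differ in detail.

import torch
import torch.nn.functional as F

def quadratic_attention(q, k, v):
    # Standard softmax attention materializes an (N x N) score matrix: O(N^2).
    scores = F.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return scores @ v

def linear_attention(q, k, v, eps=1e-6):
    # Kernelized attention: computing k^T v first gives a (d x d) summary,
    # so the cost grows linearly with the number of tokens N.
    q, k = F.elu(q) + 1, F.elu(k) + 1
    kv = k.transpose(-2, -1) @ v               # (d, d) summary of keys/values
    z = q @ k.sum(dim=-2).unsqueeze(-1) + eps  # per-token normalizer, (N, 1)
    return (q @ kv) / z

q = k = v = torch.randn(1, 1024, 32)    # 1024 tokens (32x32 image, patch_size 1), head dim 32
print(linear_attention(q, k, v).shape)  # torch.Size([1, 1024, 32])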

Hybrid ViX uses convolutional layers instead of a linear layer for generating embeddings.
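A hypothetical sketch of such a convolutional stem (for illustration only; the actual stems in the Hybrid models may differ):

import torch
import torch.nn as nn

# A small CNN produces the token embeddings instead of a linear patch projection.
conv_embed = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
)

x = torch.randn(1, 3, 32, 32)
tokens = conv_embed(x).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 64, 128]): an 8x8 grid of 128-dim tokens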

Rotary Position Embedding (RoPE) is also used in our models in place of 1D learnable position embeddings.
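RoPE encodes position by rotating pairs of query/key channels through position-dependent angles instead of adding a learned vector. A simplified sketch of the RoFormer idea (not necessarily the exact implementation in this package):

import torch

def rope(x):
    # x: (..., num_tokens, dim). Rotate each consecutive channel pair
    # by an angle that grows with token position.
    n, d = x.shape[-2], x.shape[-1]
    pos = torch.arange(n, dtype=x.dtype).unsqueeze(-1)           # (n, 1)
    freq = 10000 ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)  # (d/2,)
    angle = pos * freq                                           # (n, d/2)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack([x1 * angle.cos() - x2 * angle.sin(),
                           x1 * angle.sin() + x2 * angle.cos()], dim=-1)
    return rotated.flatten(-2)

q = torch.randn(4, 64, 32)  # (heads, tokens, head_dim)
print(rope(q).shape)        # torch.Size([4, 64, 32])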

Nomenclature: We replace the X in ViX with the first letter of the attention mechanism used. E.g., when we use Performer in ViX, we replace the X with P, calling it ViP (Vision Performer).

The 'Hybrid' prefix is used for models that use convolutional layers instead of a linear embedding layer.

We have added RoPE to the names of models that use Rotary Position Embedding.

The code for using all of these models for classification on the CIFAR-10 and Tiny ImageNet datasets is provided.

Models

  • Vision Transformer (ViT)
  • Vision Linformer (ViL)
  • Vision Performer (ViP)
  • Vision Nyströmformer (ViN)
  • FNet
  • Hybrid Vision Transformer (HybridViT)
  • Hybrid Vision Linformer (HybridViL)
  • Hybrid Vision Performer (HybridViP)
  • Hybrid Vision Nyströmformer (HybridViN)
  • Hybrid FNet
  • LeViN (Replacing Transformer in LeViT with Nyströmformer)
  • LeViP (Replacing Transformer in LeViT with Performer)
  • CvN (Replacing Transformer in CvT with Nyströmformer)
  • CvP (Replacing Transformer in CvT with Performer)
  • CCN (Replacing Transformer in CCT with Nyströmformer)
  • CCP (Replacing Transformer in CCT with Performer)

We have adapted the code for ViT and the linear transformers from @lucidrains.

Install

$ pip install vision-xformer

Usage

Image Classification

Vision Nyströmformer (ViN)

import torch
from vision_xformer import ViN

model = ViN(
    image_size = 32,
    patch_size = 1,
    num_classes = 10,
    dim = 128,
    depth = 4,
    heads = 4,
    mlp_dim = 256,
    num_landmarks = 256,
    pool = 'cls',
    channels = 3,
    dropout = 0.,
    emb_dropout = 0.,
    dim_head = 32
)

img = torch.randn(1, 3, 32, 32)

preds = model(img) # (1, 10)

Vision Performer (ViP)

import torch
from vision_xformer import ViP

model = ViP(
    image_size = 32,
    patch_size = 1,
    num_classes = 10,
    dim = 128,
    depth = 4,
    heads = 4,
    mlp_dim = 256,
    dropout = 0.25,
    dim_head = 32
)

img = torch.randn(1, 3, 32, 32)

preds = model(img) # (1, 10)

Vision Linformer (ViL)

import torch
from vision_xformer import ViL

model = ViL(
    image_size = 32,
    patch_size = 1,
    num_classes = 10,
    dim = 128,
    depth = 4,
    heads = 4,
    mlp_dim = 256,
    dropout = 0.25,
    dim_head = 32
)

img = torch.randn(1, 3, 32, 32)

preds = model(img) # (1, 10)
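All three models return raw class logits, so they drop into a standard classification loop. A minimal CIFAR-10 training sketch (the hyperparameters here are illustrative, not the settings from our paper):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from vision_xformer import ViP

model = ViP(image_size = 32, patch_size = 1, num_classes = 10, dim = 128,
            depth = 4, heads = 4, mlp_dim = 256, dropout = 0.25, dim_head = 32)
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

loader = DataLoader(
    datasets.CIFAR10('.', train = True, download = True,
                     transform = transforms.ToTensor()),
    batch_size = 64, shuffle = True)

model.train()
for imgs, labels in loader:
    optimizer.zero_grad()
    loss = F.cross_entropy(model(imgs), labels)  # logits -> loss
    loss.backward()
    optimizer.step()
    break  # one step shown; loop over the full loader for several epochs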

Parameters

  • image_size: int.
    Size of the input image. If you have rectangular images, make sure your image size is the maximum of the width and height.
  • patch_size: int.
    Size of each patch. image_size must be divisible by patch_size.
  • num_classes: int.
    Number of classes to classify.
  • dim: int.
    Final dimension of token embeddings after the linear layer.
  • depth: int.
    Number of layers.
  • heads: int.
    Number of heads in multi-head attention.
  • mlp_dim: int.
    Embedding dimension in the MLP (FeedForward) layer.
  • num_landmarks: int.
    Number of landmark points. Use one-fourth the number of patches (see the worked example after this list).
  • pool: str.
    Pool type: either 'cls' (cls token) or 'mean' (mean pooling).
  • dropout: float between [0, 1], default 0.
    Dropout rate.
  • dim_head: int.
    Embedding dimension of tokens in each head of multi-head attention.

More information about these models can be found in our papers: ArXiv Paper, WACV 2022 Paper.

If you wish to cite this, please use:

@misc{jeevan2021vision,
      title={Vision Xformers: Efficient Attention for Image Classification}, 
      author={Pranav Jeevan and Amit Sethi},
      year={2021},
      eprint={2107.02239},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@InProceedings{Jeevan_2022_WACV,
    author    = {Jeevan, Pranav and Sethi, Amit},
    title     = {Resource-Efficient Hybrid X-Formers for Vision},
    booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
    month     = {January},
    year      = {2022},
    pages     = {2982-2990}
}
