mega-vit - PyTorch

Project description

MegaVit

The open-source implementation of the model from "Scaling Vision Transformers to 22 Billion Parameters".

Paper Link

Appreciation

  • Lucidrains
  • Agorians

Install

pip install mega-vit

Usage

  • Simple usage:
import torch
from mega_vit.main import MegaVit

v = MegaVit(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
print(preds)
  • Hyperparameters as stated in the paper:
import torch
from mega_vit.main import MegaVit

v = MegaVit(
    image_size = 224,
    patch_size = 14,
    num_classes = 1000,
    dim = 6144,
    depth = 48,
    heads = 48,
    mlp_dim = 24576, # MLP hidden dim reported in the paper for ViT-22B
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)
print(preds)
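
Both examples print raw class scores of shape (1, 1000); assuming the head behaves like a standard ViT classifier returning unnormalized logits, turning them into probabilities and a top-5 prediction is ordinary PyTorch post-processing (this snippet is illustrative, not part of the mega_vit API):

probs = preds.softmax(dim=-1)      # (1, 1000) class probabilities
top5 = probs.topk(5, dim=-1)       # five most likely class indices and their scores
print(top5.indices, top5.values)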

Model Architecture

  • A standard ViT backbone with parallel attention/MLP layers, QK (query/key) normalization, and omitted biases, as sketched below.
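
A minimal sketch of such a block, for illustration only: attention and the MLP are computed in parallel from the same normalized input, queries and keys are LayerNorm-ed per head, and the linear projections carry no biases. The class name ParallelBlock and all shapes are assumptions for this sketch, not the package's exact implementation.

import torch
from torch import nn
from torch.nn import functional as F

class ParallelBlock(nn.Module):
    """Illustrative ViT-22B-style block (hypothetical, not the exact mega_vit code)."""

    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.heads = heads
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)   # biases omitted
        self.q_norm = nn.LayerNorm(dim // heads)             # QK normalization
        self.k_norm = nn.LayerNorm(dim // heads)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.mlp_in = nn.Linear(dim, mlp_dim, bias=False)
        self.mlp_out = nn.Linear(mlp_dim, dim, bias=False)

    def forward(self, x):
        b, n, d = x.shape
        h = self.heads
        y = self.norm(x)                                      # one norm feeds both branches

        # attention branch with normalized queries and keys
        q, k, v = self.to_qkv(y).chunk(3, dim=-1)
        q, k, v = (t.reshape(b, n, h, d // h).transpose(1, 2) for t in (q, k, v))
        q, k = self.q_norm(q), self.k_norm(k)
        attn = F.softmax(q @ k.transpose(-2, -1) / (d // h) ** 0.5, dim=-1)
        attn_out = self.attn_out((attn @ v).transpose(1, 2).reshape(b, n, d))

        # MLP branch computed in parallel from the same normalized input
        mlp_out = self.mlp_out(F.gelu(self.mlp_in(y)))

        return x + attn_out + mlp_out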

Dataset Strategy

The paper trains ViT-22B on a version of the JFT dataset that has been extended to around 4 billion images. JFT is a large-scale dataset scraped from the internet, originally containing over 300 million images labeled with a hierarchical taxonomy of 30,000 categories.

The authors do not provide full details on how the dataset was extended from the original JFT to 4 billion images. However, the goal seems to be creating a larger and more diverse training set to support scaling up the model size. Pre-training on larger datasets enables learning more robust and generalizable visual representations.

The authors evaluate ViT-22B on a comprehensive set of 39 datasets covering various domains like image classification, dense prediction tasks, video, and fairness benchmarks. Using such a diverse evaluation suite allows them to thoroughly assess the scalability and transferability of ViT-22B across different domains and data distributions.

Below is a table summarizing some of the key datasets used in the paper:

| Dataset | Domain | Images | Classes |
| --- | --- | --- | --- |
| JFT (training set) | Internet images | ~4 billion | 30,000 |
| ImageNet | Natural images | 1.28M | 1000 |
| ImageNet-C | Corrupted ImageNet images | 1.28M | 1000 |
| ImageNet-R | Hard ImageNet images | 30K | 200 |
| ImageNet-A | Adversarial ImageNet images | 7.5K | 200 |
| ObjectNet | Natural images | 113K | 113 |
| CIFAR-10 | Tiny natural images | 60K | 10 |
| CIFAR-100 | Tiny natural images | 60K | 100 |
| ADE20K | Scene parsing | 25K | 150 |
| Kinetics-400 | Human action videos | 400K | 400 |
| CelebA | Celebrity faces | 202K | 40 |

License

MIT

Citations

@misc{2302.05442,
  author = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Patrick Collier and Alexey Gritsenko and Vighnesh Birodkar and Cristina Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Pavetić and Dustin Tran and Thomas Kipf and Mario Lučić and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
  title  = {Scaling Vision Transformers to 22 Billion Parameters},
  year   = {2023},
  eprint = {arXiv:2302.05442},
}

Todo

  • Add flash attention, with LayerNorm before attention and LayerNorm for the query/key values.
  • Basic training script on CIFAR.
  • Like any large-scale model, ViT-22B makes it difficult to understand how it arrived at a specific decision, which could lead to a lack of trust and accountability. Add a mechanism to backtrack its decisions.
  • Create logic to train the decoder for 300k steps with a batch size of 64 using Adam (Kingma and Ba, 2015), clipping gradients to a global norm of 0.05 to stabilize training; linearly increase the learning rate from 0 to 0.0002 over 2,500 steps, then decay it back to 0 with a cosine schedule (Loshchilov and Hutter, 2017). A sketch of this schedule follows this list.
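
A minimal sketch of that optimizer and learning-rate schedule in plain PyTorch. The model here is a placeholder; only the hyperparameters named in the item above (Adam, batch size 64, 300k steps, 0.05 gradient clipping, 2,500-step warmup to 0.0002, cosine decay to 0) come from the source.

import math
import torch

model = torch.nn.Linear(512, 512)                       # placeholder for the decoder
optimizer = torch.optim.Adam(model.parameters(), lr=0.0)

total_steps, warmup_steps, peak_lr = 300_000, 2_500, 2e-4

def lr_at(step):
    # linear warmup from 0 to 2e-4 over 2,500 steps, then cosine decay back to 0
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

for step in range(total_steps):
    for group in optimizer.param_groups:
        group["lr"] = lr_at(step)
    # forward/backward on a batch of 64 would go here, then:
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.05)
    # optimizer.step(); optimizer.zero_grad()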

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mega_vit-0.0.4.tar.gz (6.7 kB)

Uploaded Source

Built Distribution

mega_vit-0.0.4-py3-none-any.whl (6.5 kB)

Uploaded Python 3

File details

Details for the file mega_vit-0.0.4.tar.gz.

File metadata

  • Download URL: mega_vit-0.0.4.tar.gz
  • Upload date:
  • Size: 6.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for mega_vit-0.0.4.tar.gz

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 34a5ab030cd8522df973a196ad3db31145e3cf7e327fb1847fee8975882025e9 |
| MD5 | 60b6805fec0ac45f2c6fb96d75d44d02 |
| BLAKE2b-256 | c86d6cac8f682369b898e2d1a2ee770c356faa61f4e7695ea311e0280262d237 |

See more details on using hashes here.

File details

Details for the file mega_vit-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: mega_vit-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 6.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for mega_vit-0.0.4-py3-none-any.whl

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | a1dc9c669723e14c3c69d3bdd323ab83e3210db4cf2c3e10a9622311418578cf |
| MD5 | 13b66bb7f9b732809fbf5325f876ef1e |
| BLAKE2b-256 | 495dfe3a4307bac83dfcca4824b3066e4941e073c99151037890861a905bd34f |

See more details on using hashes here.
