Token Labeling Toolbox for training image models

All Tokens Matter: Token Labeling for Training Better Vision Transformers (arxiv)

This is a PyTorch implementation of our paper.

Comparison between the proposed LV-ViT and other recent works based on transformers. Note that we only show models whose model sizes are under 100M.

Our code is based on pytorch-image-models by Ross Wightman.

Update

2021.7: Add script to generate label data.

2021.6: Support pip install tlt to use our Token Labeling Toolbox for image models.

2021.6: Release training code and segmentation model.

2021.4: Release LV-ViT models.

LV-ViT Models

Model      Layers   Dim   Image resolution   Params    Top-1 (%)   Download
LV-ViT-S   16       384   224                26.15M    83.3        link
LV-ViT-S   16       384   384                26.30M    84.4        link
LV-ViT-M   20       512   224                55.83M    84.0        link
LV-ViT-M   20       512   384                56.03M    85.4        link
LV-ViT-M   20       512   448                56.13M    85.5        link
LV-ViT-L   24       768   448                150.47M   86.2        link
LV-ViT-L   24       768   512                150.66M   86.4        link

Requirements

torch>=1.4.0 torchvision>=0.5.0 pyyaml scipy timm==0.4.5

Data preparation: ImageNet with the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
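
Before launching a long run, it can help to confirm that a directory matches the layout above. The following is a minimal sketch, not part of the toolbox: the function name summarize_imagenet_dir is hypothetical, and it assumes only the train/ and val/ class-folder structure shown above.

```python
from pathlib import Path


def summarize_imagenet_dir(root):
    """Count class folders and .JPEG files under root/train and root/val."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Each class is a sub-folder named by its WordNet ID, e.g. n01440764.
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        num_images = sum(
            1
            for class_dir in class_dirs
            for f in class_dir.iterdir()
            if f.suffix.upper() == ".JPEG"
        )
        summary[split] = {"classes": len(class_dirs), "images": num_images}
    return summary
```

Calling summarize_imagenet_dir("/path/to/imagenet") returns a dictionary with the class and image counts for each split, which can be compared against the counts you expect for your copy of the dataset.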

Validation

Replace /path/to/imagenet/val with your ImageNet validation set path and /path/to/checkpoint with the checkpoint path:

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint

Label data

We provide dense label maps generated by NFNet-F6 on Google Drive and Baidu Yun (password: y6j2). Because NFNet-F6 is trained on ImageNet data only, no extra training data is involved.

Training

Train the LV-ViT-S:

If only 4 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 /path/to/imagenet --model lvvit_s -b 256 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema

If 8 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
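
The 4-GPU and 8-GPU commands differ only in the GPU count and the -b value. Assuming -b is the per-GPU batch size, as in the pytorch-image-models training script this code is based on (an assumption worth checking against your copy of the script), both settings give the same global batch size. A quick arithmetic sketch, with a hypothetical helper name:

```python
def global_batch_size(num_gpus, per_gpu_batch):
    """Samples per optimizer step, assuming -b is per GPU (one process per GPU)."""
    return num_gpus * per_gpu_batch


print(global_batch_size(4, 256))  # 4-GPU command: 1024
print(global_batch_size(8, 128))  # 8-GPU command: 1024
```

If you run on a different number of GPUs, keeping this product fixed is one way to stay close to the settings above without retuning the learning rate.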

Train the LV-ViT-M and LV-ViT-L (run on 8 GPUs):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_m -b 128 --apex-amp --img-size 224 --drop-path 0.2 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_l -b 128 --lr 1.e-3 --aa rand-n3-m9-mstd0.5-inc1 --apex-amp --img-size 224 --drop-path 0.3 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema

If you want to train our LV-ViT on images with 384x384 resolution, please use --img-size 384 --token-label-size 24.
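
The two pairs quoted in this section (224 with 14, and 384 with 24) are both consistent with one token label per 16x16-pixel patch. The helper below is a sketch under that assumption: the stride of 16 is inferred from these two pairs rather than stated in this README, and token_label_size is a hypothetical name.

```python
def token_label_size(img_size, stride=16):
    """Side length of the token-label grid, assuming one label per stride x stride patch."""
    if img_size % stride != 0:
        raise ValueError(f"img_size {img_size} is not a multiple of {stride}")
    return img_size // stride


for size in (224, 384):
    print(f"--img-size {size} --token-label-size {token_label_size(size)}")
```

This prints the two flag combinations used in this README; for other resolutions, treat the result as a starting guess and verify it against the training code.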

Fine-tuning

To fine-tune the pre-trained LV-ViT-S on images with 384x384 resolution:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 64 --apex-amp --img-size 384 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 24 --lr 5.e-6 --min-lr 5.e-6 --weight-decay 1.e-8 --finetune /path/to/checkpoint

To fine-tune the pre-trained LV-ViT-S on other datasets without token labeling:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/dataset --model lvvit_s -b 64 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes $NUM_CLASSES --finetune /path/to/checkpoint
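
The command above keeps --token-label enabled but sets --dense-weight 0.0, which is what removes the token-level supervision. A minimal sketch of why, assuming the total loss is the class-token loss plus the dense weight times the token-level loss, as formulated in the paper. The exact expression in the training code may differ, and the function name and numbers here are purely illustrative.

```python
def combined_loss(cls_loss, token_loss, dense_weight):
    """Weighted sum of the class-token loss and the auxiliary token-labeling loss."""
    return cls_loss + dense_weight * token_loss


# With dense_weight = 0.0 the token-level term vanishes, so only the
# image-level label contributes to the loss.
print(combined_loss(cls_loss=2.0, token_loss=3.0, dense_weight=0.0))  # 2.0
print(combined_loss(cls_loss=2.0, token_loss=3.0, dense_weight=0.5))  # 3.5
```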

Segmentation

Our segmentation models are built on the MMSegmentation toolkit. The model and config files are under the seg/ folder, which follows the same folder structure as MMSegmentation. You can simply drop these files into an MMSegmentation checkout to get started.

git clone https://github.com/open-mmlab/mmsegmentation # and install

cp seg/mmseg/models/backbones/vit.py mmsegmentation/mmseg/models/backbones/
cp -r seg/configs/lvvit mmsegmentation/configs/

# test upernet+lvvit_s (add --aug-test to test on multi scale)
cd mmsegmentation
./tools/dist_test.sh configs/lvvit/upernet_lvvit_s_512x512_160k_ade20k.py /path/to/checkpoint 8 --eval mIoU [--aug-test]

Backbone   Method    Crop size   Lr schd   mIoU   mIoU (ms)   Pixel Acc.   Params   Download
LV-ViT-S   UperNet   512x512     160k      47.9   48.6        83.1         44M      link
LV-ViT-M   UperNet   512x512     160k      49.4   50.6        83.5         77M      link
LV-ViT-L   UperNet   512x512     160k      50.9   51.8        84.1         209M     link

Visualization

We apply the visualization method from this repo to visualize the parts of the image that led to a certain classification for DeiT-Base and our LV-ViT-S. The parts of the image that the network uses to make its decision are highlighted in red.

Label generation

To generate token label data for training:

python3 generate_label.py /path/to/imagenet/train /path/to/save/label_top5_train_nfnet --model dm_nfnet_f6 --pretrained --img-size 576 -b 32 --crop-pct 1.0
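
The output folder name label_top5_train_nfnet suggests that only the five highest-scoring classes are kept per location. That is an inference from the name, not something this README states. Under that assumption, the selection step for a single location can be sketched in plain Python; top_k_scores is a hypothetical helper and the scores are made-up numbers.

```python
def top_k_scores(scores, k=5):
    """Return the k largest (class_index, score) pairs from one score vector."""
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy score vector for a single location with 8 classes (made-up numbers).
scores = [0.01, 0.40, 0.05, 0.02, 0.30, 0.10, 0.04, 0.08]
print(top_k_scores(scores))
# [(1, 0.4), (4, 0.3), (5, 0.1), (7, 0.08), (2, 0.05)]
```

Keeping a few (index, score) pairs instead of a full score vector per location keeps the stored label data much smaller than a dense map over all classes.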

Reference

If you use this repo or find it useful, please consider citing:

@article{jiang2021all,
  title={All Tokens Matter: Token Labeling for Training Better Vision Transformers},
  author={Jiang, Zihang and Hou, Qibin and Yuan, Li and Zhou, Daquan and Shi, Yujun and Jin, Xiaojie and Wang, Anran and Feng, Jiashi},
  journal={arXiv preprint arXiv:2104.10858},
  year={2021}
}

Related projects

T2T-ViT, Re-labeling ImageNet, MMSegmentation, Transformer Explainability.
