Token Labeling Toolbox for training image models
Project description
All Tokens Matter: Token Labeling for Training Better Vision Transformers (arxiv)
This is a PyTorch implementation of our paper.
Figure: Comparison between the proposed LV-ViT and other recent transformer-based models. Only models with fewer than 100M parameters are shown.
Our code is based on pytorch-image-models by Ross Wightman.
Update
2021.7: Add script to generate label data.
2021.6: Support pip install tlt to use our Token Labeling Toolbox for image models.
2021.6: Release training code and segmentation model.
2021.4: Release LV-ViT models.
LV-ViT Models
Model | Layers | Dim | Image resolution | Params | Top-1 (%) | Download |
---|---|---|---|---|---|---|
LV-ViT-S | 16 | 384 | 224 | 26.15M | 83.3 | link |
LV-ViT-S | 16 | 384 | 384 | 26.30M | 84.4 | link |
LV-ViT-M | 20 | 512 | 224 | 55.83M | 84.0 | link |
LV-ViT-M | 20 | 512 | 384 | 56.03M | 85.4 | link |
LV-ViT-M | 20 | 512 | 448 | 56.13M | 85.5 | link |
LV-ViT-L | 24 | 768 | 448 | 150.47M | 86.2 | link |
LV-ViT-L | 24 | 768 | 512 | 150.66M | 86.4 | link |
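After pip install tlt, the checkpoints above can be loaded in a few lines. A minimal sketch, assuming the models are exposed under tlt.models and that the downloaded file is a plain state dict (check the package and the checkpoint contents for the actual layout):

import torch
from tlt.models import lvvit_s  # assumed import path; see the tlt package for the actual model registry

model = lvvit_s(pretrained=False)
state_dict = torch.load('/path/to/checkpoint', map_location='cpu')
model.load_state_dict(state_dict, strict=False)  # strict=False in case of auxiliary or EMA keys
model.eval()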
Requirements
torch>=1.4.0 torchvision>=0.5.0 pyyaml scipy timm==0.4.5
Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this script.
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
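As a quick sanity check (not part of the repo), the layout above can be verified with torchvision's standard ImageFolder dataset before launching training:

from torchvision.datasets import ImageFolder

train_set = ImageFolder('/path/to/imagenet/train')
val_set = ImageFolder('/path/to/imagenet/val')
print(len(train_set.classes), 'classes')   # expect 1000
print(len(train_set), 'training images')   # expect ~1.28M
print(len(val_set), 'validation images')   # expect 50000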
Validation
Replace /path/to/imagenet/val with your ImageNet validation set path and /path/to/checkpoint with the model checkpoint path:
CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint
Label data
We provide the dense label maps generated by NFNet-F6 on Google Drive and Baidu Yun (password: y6j2). Since NFNet-F6 is trained purely on ImageNet data, no extra training data is involved.
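For intuition, here is a minimal sketch of how such a dense label map is typically combined with the usual class-token objective during training; the tensor shapes and the 0.5 weighting are illustrative assumptions (the actual loss and its --dense-weight handling live in the training code):

import torch.nn.functional as F

def token_labeling_loss(cls_logits, token_logits, target, dense_target, dense_weight=0.5):
    # cls_logits:   (B, C)    prediction from the class token
    # token_logits: (B, N, C) per-patch predictions from the auxiliary head
    # target:       (B,)      image-level labels
    # dense_target: (B, N, C) soft per-patch labels taken from the dense label map
    cls_loss = F.cross_entropy(cls_logits, target)
    # soft cross-entropy between each patch prediction and its dense label
    token_loss = (-dense_target * F.log_softmax(token_logits, dim=-1)).sum(dim=-1).mean()
    return cls_loss + dense_weight * token_loss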
Training
Train the LV-ViT-S:
If only 4 GPUs are available:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 /path/to/imagenet --model lvvit_s -b 256 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
If 8 GPUs are available:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
Train the LV-ViT-M and LV-ViT-L (run on 8 GPUs):
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_m -b 128 --apex-amp --img-size 224 --drop-path 0.2 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_l -b 128 --lr 1.e-3 --aa rand-n3-m9-mstd0.5-inc1 --apex-amp --img-size 224 --drop-path 0.3 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
If you want to train LV-ViT on images at 384x384 resolution, use --img-size 384 --token-label-size 24 (with a patch size of 16, a 384x384 input yields a 24x24 grid of tokens, so the token label map must be 24x24).
Fine-tuning
To fine-tune the pre-trained LV-ViT-S on images at 384x384 resolution:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 64 --apex-amp --img-size 384 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 24 --lr 5.e-6 --min-lr 5.e-6 --weight-decay 1.e-8 --finetune /path/to/checkpoint
To fine-tune the pre-trained LV-ViT-S on other datasets without token labeling (--dense-weight 0.0 disables the token-labeling loss):
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/dataset --model lvvit_s -b 64 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes $NUM_CLASSES --finetune /path/to/checkpoint
Segmentation
Our segmentation models are fully based on the MMSegmentation toolkit. The model and config files are under the seg/ folder and follow MMSegmentation's directory structure, so you can simply drop them in to get started:
git clone https://github.com/open-mmlab/mmsegmentation # and install
cp seg/mmseg/models/backbones/vit.py mmsegmentation/mmseg/models/backbones/
cp -r seg/configs/lvvit mmsegmentation/configs/
# test upernet+lvvit_s (add --aug-test to test at multiple scales)
cd mmsegmentation
./tools/dist_test.sh configs/lvvit/upernet_lvvit_s_512x512_160k_ade20k.py /path/to/checkpoint 8 --eval mIoU [--aug-test]
Backbone | Method | Crop size | LR schedule | mIoU | mIoU (ms) | Pixel Acc. | Params | Download |
---|---|---|---|---|---|---|---|---|
LV-ViT-S | UperNet | 512x512 | 160k | 47.9 | 48.6 | 83.1 | 44M | link |
LV-ViT-M | UperNet | 512x512 | 160k | 49.4 | 50.6 | 83.5 | 77M | link |
LV-ViT-L | UperNet | 512x512 | 160k | 50.9 | 51.8 | 84.1 | 209M | link |
Visualization
We apply the visualization method from this repo to visualize the parts of the image that led to a certain classification for DeiT-Base and our LV-ViT-S. The parts of the image used by the network to make the decision are highlighted in red.
Label generation
To generate token label data for training:
python3 generate_label.py /path/to/imagenet/train /path/to/save/label_top5_train_nfnet --model dm_nfnet_f6 --pretrained --img-size 576 -b 32 --crop-pct 1.0
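To make sure the generation step worked, one quick check is to load one of the saved label files; the per-image .pt layout and the path below are assumptions, so adjust them to whatever generate_label.py actually writes:

import torch

label = torch.load('/path/to/save/label_top5_train_nfnet/n01440764/n01440764_10026.pt',
                   map_location='cpu')
print(type(label))
if torch.is_tensor(label):
    # for a tensor, the trailing two dimensions should be the spatial label-map grid
    print(label.shape, label.dtype)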
Reference
If you use this repo or find it useful, please consider citing:
@article{jiang2021all,
title={All Tokens Matter: Token Labeling for Training Better Vision Transformers},
author={Jiang, Zihang and Hou, Qibin and Yuan, Li and Zhou, Daquan and Shi, Yujun and Jin, Xiaojie and Wang, Anran and Feng, Jiashi},
journal={arXiv preprint arXiv:2104.10858},
year={2021}
}
Related projects
T2T-ViT, Re-labeling ImageNet, MMSegmentation, Transformer Explainability.
File details
Details for the file tlt-0.2.0.tar.gz.
File metadata
- Download URL: tlt-0.2.0.tar.gz
- Upload date:
- Size: 33.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.9.6
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 9cdaec3b07af206e6a3f2872cb883736e3f358994b077c248b71f711d1f18433 |
MD5 | 2b3bd0a63b1ef9be30e571211c3fd9aa |
BLAKE2b-256 | 5688576c55000f2c5eaf6e5f48340776e8373e63dd57d88d513d395832556413 |
File details
Details for the file tlt-0.2.0-py3-none-any.whl.
File metadata
- Download URL: tlt-0.2.0-py3-none-any.whl
- Upload date:
- Size: 33.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.9.6
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 722097c8af2fc5c45fc7eb0a363da2fac5b6c7cacfa92b982a96db5b4ec97b42 |
MD5 | 9f0b961919652e08d722fa94d46f3ab6 |
BLAKE2b-256 | fabcca153cf4796c6668091c6ca5c679ef8228f7db5d690fb3f2b2824e88b619 |