Python module to train and test on CIFAR-10.

Description

This repository contains the Python package vitcifar10, a Vision Transformer (ViT) baseline for training and testing on CIFAR-10. The implementation supports CUDA only (no CPU training).

The aim of this repository is not to provide a highly flexible implementation, but one that can serve as a research baseline with test results close to the state of the art.

The code in this repository relies on timm, a Python package that provides implementations and pretrained weights for many state-of-the-art models.
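
For context, timm exposes a single factory function, timm.create_model, to build such models. The sketch below shows the general pattern; the specific ViT variant is an assumption, not necessarily the one used by vitcifar10:

    # Sketch: build a pretrained ViT with a 10-class head via timm.
    # The model name below is an assumption; vitcifar10 may use another variant.
    import timm
    import torch

    model = timm.create_model(
        "vit_base_patch16_224",  # hypothetical choice of ViT variant
        pretrained=True,         # load ImageNet-pretrained weights
        num_classes=10,          # CIFAR-10 has 10 classes
    )
    model = model.cuda()  # this package is CUDA-only

    x = torch.randn(1, 3, 224, 224, device="cuda")  # dummy batch at ViT resolution
    logits = model(x)  # shape: (1, 10)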

Use within a Docker container

If you do not have Docker, you can find an installation tutorial here.

  1. Build vitcifar10 Docker image:

    $ git clone https://github.com/luiscarlosgph/vitcifar10.git
    $ cd vitcifar10/docker
    $ docker build -t vitcifar10 .
    
  2. Launch vitcifar10 container:

    $ docker run --name wild_vitcifar10 --runtime=nvidia -v /dev/shm:/dev/shm vitcifar10:latest &
    
  3. Get a terminal inside the container (you can execute this command multiple times to get multiple container terminals):

    $ docker exec -it wild_vitcifar10 /bin/zsh
    $ cd $HOME
    
  4. Launch CIFAR-10 training:

    $ python3 -m vitcifar10.train --lr 1e-4 --opt adam --nepochs 200 --bs 16 --cpdir checkpoints --logdir logs --cpint 5 --data ./data
    
  5. Launch CIFAR-10 testing:

    $ python3 -m vitcifar10.test --data ./data --resume checkpoints/model_best.pt
    

If you want to kill the container, run $ docker kill wild_vitcifar10.

To remove it, execute $ docker rm wild_vitcifar10.

Install with pip

$ pip install vitcifar10 --user

Install from source

$ git clone https://github.com/luiscarlosgph/vitcifar10.git
$ cd vitcifar10
$ python3 setup.py install

Train on CIFAR-10

  • Launch training:

    $ python3 -m vitcifar10.train --lr 1e-4 --opt adam --nepochs 200 --bs 16 --cpdir checkpoints --logdir logs --cpint 5 --data ./data
    
  • Resume training from a checkpoint:

    $ python3 -m vitcifar10.train --lr 1e-4 --opt adam --nepochs 200 --bs 16 --cpdir checkpoints --logdir logs --cpint 5 --data ./data --resume checkpoints/epoch_21.pt
    
  • Options (a minimal argparse sketch of these flags follows this list):

    • --lr: learning rate.
    • --opt: optimizer, choose either sgd or adam.
    • --nepochs: number of epochs to execute.
    • --bs: training batch size.
    • --cpdir: checkpoint directory.
    • --logdir: path to the directory where the Tensorboard logs will be saved.
    • --cpint: interval of epochs to save a checkpoint of the training process.
    • --resume: path to the checkpoint file you want to resume.
    • --data: path to the directory where the dataset will be stored.
    • --seed: fix the random seed for reproducibility.
  • Launch Tensorboard:

    $ python3 -m tensorboard.main --logdir logs --bind_all
    
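
For reference, here is a minimal argparse sketch that mirrors the training flags listed above. It is illustrative only, not the package's actual parser; which flags are required and their defaults are assumptions:

    # Sketch of a CLI parser matching the documented training options.
    import argparse

    def parse_train_args():
        parser = argparse.ArgumentParser(description="Train a ViT on CIFAR-10.")
        parser.add_argument("--lr", type=float, required=True, help="Learning rate.")
        parser.add_argument("--opt", choices=["sgd", "adam"], required=True, help="Optimizer.")
        parser.add_argument("--nepochs", type=int, required=True, help="Number of epochs.")
        parser.add_argument("--bs", type=int, required=True, help="Training batch size.")
        parser.add_argument("--cpdir", type=str, required=True, help="Checkpoint directory.")
        parser.add_argument("--logdir", type=str, required=True, help="TensorBoard log directory.")
        parser.add_argument("--cpint", type=int, required=True, help="Epochs between checkpoints.")
        parser.add_argument("--data", type=str, required=True, help="Dataset directory.")
        parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume from.")
        parser.add_argument("--seed", type=int, default=None, help="Random seed.")
        return parser.parse_args()

    if __name__ == "__main__":
        print(parse_train_args())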

Test on CIFAR-10

$ python3 -m vitcifar10.test --data ./data --resume checkpoints/model_best.pt --bs 1

  • Options:
    • --data: path to the directory where the dataset will be stored.
    • --resume: path to the checkpoint file you want to test.
    • --bs: testing batch size.
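
To make the testing step concrete, below is a minimal evaluation sketch. It assumes the checkpoint stores the weights under a "model" key and that images are resized to the ViT input resolution; both the key and the preprocessing are assumptions, and the package's actual code may differ:

    # Sketch: evaluate a checkpoint on the CIFAR-10 test split.
    import timm
    import torch
    import torchvision
    import torchvision.transforms as T

    device = "cuda"
    transform = T.Compose([
        T.Resize(224),  # assumed ViT input size
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # illustrative stats
    ])
    testset = torchvision.datasets.CIFAR10("./data", train=False, download=True,
                                           transform=transform)
    loader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

    model = timm.create_model("vit_base_patch16_224", num_classes=10).to(device)
    ckpt = torch.load("checkpoints/model_best.pt", map_location=device)
    model.load_state_dict(ckpt["model"])  # the "model" key is an assumption
    model.eval()

    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images.to(device)).argmax(dim=1)
            correct += (preds.cpu() == labels).sum().item()
    print(f"Test accuracy: {correct / len(testset):.4f}")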

Perform inference on a single image

After training, you can classify images such as this dog or this cat as follows:

$ python3 -m vitcifar10.inference --image data/dog.jpg --model checkpoints/model_best.pt 
It is a dog!
$ python3 -m vitcifar10.inference --image data/cat.jpg --model checkpoints/model_best.pt
It is a cat!
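
A rough picture of what such an inference script does, under the same assumptions as above (model variant, checkpoint key, and preprocessing are all guesses, not the package's exact code):

    # Sketch: classify a single image with a trained checkpoint.
    import timm
    import torch
    import torchvision.transforms as T
    from PIL import Image

    CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                       "dog", "frog", "horse", "ship", "truck"]

    preprocess = T.Compose([
        T.Resize((224, 224)),  # assumed ViT input resolution
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # illustrative stats
    ])

    model = timm.create_model("vit_base_patch16_224", num_classes=10).cuda()
    ckpt = torch.load("checkpoints/model_best.pt", map_location="cuda")
    model.load_state_dict(ckpt["model"])  # checkpoint key is an assumption
    model.eval()

    img = preprocess(Image.open("data/dog.jpg").convert("RGB")).unsqueeze(0).cuda()
    with torch.no_grad():
        label = CIFAR10_CLASSES[model(img).argmax(dim=1).item()]
    print(f"It is a {label}!")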

Training | validation | testing splits

This code uses torchvision to download and load the CIFAR-10 dataset. The constructor of the class torchvision.datasets.CIFAR10 has a boolean parameter called train.

In our code we set train=True to obtain the images for training and validation, using 90% for training (45K images) and 10% for validation (5K images). The validation set is used to select the best model during training (it could also be used for hyperparameter tuning or early stopping). For testing, we set train=False; the testing set contains 10K images.
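
A minimal sketch of this split, assuming a plain random 90/10 partition via torch.utils.data.random_split (the package's exact splitting logic may differ):

    # Sketch: 45K/5K/10K training/validation/testing splits of CIFAR-10.
    import torch
    import torchvision

    full_train = torchvision.datasets.CIFAR10("./data", train=True, download=True)
    n_val = len(full_train) // 10                    # 5,000 validation images
    n_train = len(full_train) - n_val                # 45,000 training images
    train_set, val_set = torch.utils.data.random_split(
        full_train, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)  # illustrative fixed seed
    )

    test_set = torchvision.datasets.CIFAR10("./data", train=False, download=True)
    print(len(train_set), len(val_set), len(test_set))  # 45000 5000 10000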

Author

Luis C. Garcia Peraza Herrera (luiscarlos.gph@gmail.com).

License

This repository is shared under an MIT license.
