Python module to train and test on CIFAR-10.
Description
This repository contains the Python package vitcifar10, which provides Vision Transformer (ViT) baseline code for training and testing on CIFAR-10. The implementation only supports CUDA; training on the CPU is not possible.
The aim of this repository is not a highly flexible implementation, but a research baseline whose test results are close to the state of the art.
The code in this repository relies on timm, a Python package that provides implementations and pretrained weights for many state-of-the-art models.
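For readers new to ViTs, the following sketch shows the tokenization step that sets a ViT apart from a CNN: the image is cut into non-overlapping square patches, and each patch is flattened into one input vector. It is plain Python for illustration only, not code from vitcifar10 or timm. The patch sizes (4 on native 32x32 CIFAR-10 images, 16 on images resized to 224x224) are assumptions chosen to show the arithmetic, because this README does not state which ViT variant or input resolution the package uses.

```python
# Illustrative ViT patch tokenization in plain Python (not vitcifar10 or timm code).
from typing import List

Image = List[List[List[float]]]  # height x width x channels


def num_patches(height: int, width: int, patch: int) -> int:
    """Number of non-overlapping patch tokens a ViT produces for one image."""
    if height % patch or width % patch:
        raise ValueError("image size must be divisible by the patch size")
    return (height // patch) * (width // patch)


def patchify(image: Image, patch: int) -> List[List[float]]:
    """Split an image into flattened patches, scanning patches in row-major order.

    Each flattened patch holds patch * patch * channels values. A standard ViT
    then projects each one linearly to the embedding dimension, prepends a
    class token, and adds position embeddings.
    """
    height, width = len(image), len(image[0])
    num_patches(height, width, patch)  # raises if the sizes do not divide evenly
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            patches.append([
                value
                for row in image[top:top + patch]
                for pixel in row[left:left + patch]
                for value in pixel
            ])
    return patches


if __name__ == "__main__":
    # A blank 32x32 RGB image, the native CIFAR-10 resolution.
    img = [[[0.0, 0.0, 0.0] for _ in range(32)] for _ in range(32)]
    tokens = patchify(img, patch=4)
    print(len(tokens), len(tokens[0]))  # 64 48
    # A 224x224 input (e.g. an upsampled CIFAR-10 image) with 16x16 patches.
    print(num_patches(224, 224, 16))  # 196
```

The sequence length grows quickly with resolution: 64 tokens at 32x32 with 4x4 patches versus 196 tokens at 224x224 with 16x16 patches, which is one reason the choice of input size matters for ViT training cost.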
Use within a Docker container
If you do not have Docker, install it first by following a Docker installation tutorial.
- Build the vitcifar10 Docker image:
  $ git clone https://github.com/luiscarlosgph/vitcifar10.git
  $ cd vitcifar10/docker
  $ docker build -t vitcifar10 .
- Run the vitcifar10 container:
  $ docker run TODO
Install with pip
$ pip install vitcifar10
Install from source
$ git clone https://github.com/luiscarlosgph/vitcifar10.git
$ cd vitcifar10
$ python3 setup.py install
Train on CIFAR-10
- Launch training:
  $ python3 -m vitcifar10.train --lr 1e-4 --opt adam --nepochs 200 --bs 16 --cpdir checkpoints --logdir logs --cpint 5 --data ./data
- Resume training from a checkpoint:
  $ python3 -m vitcifar10.train --lr 1e-4 --opt adam --nepochs 200 --bs 16 --cpdir checkpoints --logdir logs --cpint 5 --data ./data --resume checkpoints/epoch_21.pt
- Options:
  --lr: learning rate.
  --opt: optimizer, either sgd or adam.
  --nepochs: number of epochs to run.
  --bs: training batch size.
  --cpdir: checkpoint directory.
  --logdir: path to the directory where the TensorBoard logs will be saved.
  --cpint: interval, in epochs, between saved checkpoints.
  --resume: path to the checkpoint file to resume training from.
  --data: path to the directory where the dataset will be stored.
- Launch TensorBoard:
  $ python3 -m tensorboard.main --logdir logs --bind_all
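The training options listed above can be summarized as a hypothetical argparse parser, together with two quantities the example command implies: optimizer steps per epoch and the number of checkpoints that --cpint 5 produces over 200 epochs. This is a sketch, not vitcifar10's actual parser. The real defaults are not documented here, so every option except --resume is marked required, and the step and checkpoint counts assume that all 50,000 CIFAR-10 training images are used, the last partial batch is kept, and a checkpoint is saved at every multiple of --cpint.

```python
# Hypothetical mirror of the documented vitcifar10.train options; the real
# parser, its defaults, and its checkpoint schedule may differ.
import argparse
import math
from typing import List

CIFAR10_TRAIN_IMAGES = 50_000  # size of the official CIFAR-10 training split


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Sketch of the vitcifar10.train options")
    p.add_argument("--lr", type=float, required=True, help="learning rate")
    p.add_argument("--opt", choices=["sgd", "adam"], required=True, help="optimizer")
    p.add_argument("--nepochs", type=int, required=True, help="number of epochs to run")
    p.add_argument("--bs", type=int, required=True, help="training batch size")
    p.add_argument("--cpdir", required=True, help="checkpoint directory")
    p.add_argument("--logdir", required=True, help="TensorBoard log directory")
    p.add_argument("--cpint", type=int, required=True, help="epochs between checkpoints")
    p.add_argument("--data", required=True, help="dataset directory")
    p.add_argument("--resume", default=None, help="checkpoint file to resume from")
    return p


def steps_per_epoch(batch_size: int, num_images: int = CIFAR10_TRAIN_IMAGES) -> int:
    # Assumption: every training image is used and the last partial batch is kept.
    return math.ceil(num_images / batch_size)


def checkpoint_epochs(nepochs: int, cpint: int) -> List[int]:
    # Assumption: epochs are counted from 1 and a checkpoint is saved at each
    # multiple of cpint.
    return list(range(cpint, nepochs + 1, cpint))


if __name__ == "__main__":
    argv = ("--lr 1e-4 --opt adam --nepochs 200 --bs 16 --cpdir checkpoints "
            "--logdir logs --cpint 5 --data ./data").split()
    args = build_parser().parse_args(argv)
    spe = steps_per_epoch(args.bs)
    print(spe, spe * args.nepochs)  # 3125 625000
    print(len(checkpoint_epochs(args.nepochs, args.cpint)))  # 40
```

Under these assumptions, the example command runs 3,125 steps per epoch (625,000 in total) and writes 40 periodic checkpoints, which is worth knowing when sizing the checkpoint directory.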
Test on CIFAR-10
$ python3 -m vitcifar10.test --data ./data --resume checkpoints/model_best.pt
- Options:
  --data: path to the directory where the dataset will be stored.
  --resume: path to the checkpoint file you want to test.
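Both the training and the test scripts take a checkpoint path through --resume. The helper below is a hypothetical convenience for choosing one automatically and is not part of vitcifar10. It assumes the file names used in this README's examples (model_best.pt and epoch_<N>.pt inside the checkpoint directory); if the package names its files differently, the pattern must be adapted.

```python
# Hypothetical helper (not part of vitcifar10): choose a checkpoint for --resume,
# assuming the names model_best.pt and epoch_<N>.pt seen in this README.
import re
from pathlib import Path
from typing import Optional

EPOCH_RE = re.compile(r"^epoch_(\d+)\.pt$")


def pick_checkpoint(cpdir: str, prefer_best: bool = True) -> Optional[Path]:
    """Return model_best.pt if preferred and present, else the highest epoch_<N>.pt."""
    folder = Path(cpdir)
    best = folder / "model_best.pt"
    if prefer_best and best.is_file():
        return best
    latest, latest_epoch = None, -1
    for path in folder.glob("epoch_*.pt"):
        match = EPOCH_RE.match(path.name)
        # Compare epochs numerically so that epoch_100.pt beats epoch_21.pt.
        if match and int(match.group(1)) > latest_epoch:
            latest, latest_epoch = path, int(match.group(1))
    return latest


if __name__ == "__main__":
    ckpt = pick_checkpoint("checkpoints")
    if ckpt is None:
        print("no checkpoint found")
    else:
        print(f"python3 -m vitcifar10.test --data ./data --resume {ckpt}")
```

For testing, the default prefers model_best.pt; to resume training from the most recent epoch instead, call pick_checkpoint("checkpoints", prefer_best=False).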
Author
Luis C. Garcia Peraza Herrera (luiscarlos.gph@gmail.com).
License
This repository is shared under an MIT license.
File details
Details for the file vitcifar10-0.0.1.tar.gz.
File metadata
- Download URL: vitcifar10-0.0.1.tar.gz
- Upload date:
- Size: 10.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 039efac590a8f3f1c4dd9002d4395cdadbe867bf8c2d132f53796feecda093cb
MD5 | b476ce94617fcff8d21f7bc61ad1518d
BLAKE2b-256 | 5667deadb5d8e07bc0c3545c8946671838d7b189a2012cae814bf3009b94e3e7
File details
Details for the file vitcifar10-0.0.1-py3.9.egg.
File metadata
- Download URL: vitcifar10-0.0.1-py3.9.egg
- Upload date:
- Size: 25.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | ae22a3f49569a7008c8c00cfdc232b03c7946aa8ebd94955a1d61b452176cbe4
MD5 | 86680565a63b853b1117081012016998
BLAKE2b-256 | 21b29a2733dc1e6b42b9bd23bdf57f7e92c353736515ebeac402513607227a02
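To check that a downloaded release file matches the SHA256 digests listed above, you can hash it locally with Python's standard library. This is a generic sketch, not tooling shipped with vitcifar10; it compares against the source distribution's digest, so substitute the egg's SHA256 when checking that file instead.

```python
# Generic sketch: verify a downloaded file against a published SHA256 digest.
import hashlib
import sys

# SHA256 of vitcifar10-0.0.1.tar.gz, copied from the table above.
SDIST_SHA256 = "039efac590a8f3f1c4dd9002d4395cdadbe867bf8c2d132f53796feecda093cb"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected_hex: str) -> bool:
    # Normalize the expected digest so copied values with stray case/spaces match.
    return sha256_of(path) == expected_hex.strip().lower()


if __name__ == "__main__" and len(sys.argv) == 2:
    # Usage: python3 verify_sha256.py vitcifar10-0.0.1.tar.gz
    print("OK" if verify(sys.argv[1], SDIST_SHA256) else "MISMATCH")
```

The script name verify_sha256.py is only an example. pip can perform the same check at install time through its hash-checking mode (--require-hashes with a requirements file that pins each package with --hash=sha256:...).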