A vision transformer for training on MNIST
Python package mnistvit
A PyTorch-only implementation of a vision transformer (ViT) for training on MNIST that reaches 99.65% test accuracy with the default parameters and without pre-training. The ViT architecture and training parameters are easy to configure, and code for hyperparameter optimization is included.
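A ViT treats an image as a sequence of patches. For a 28×28 MNIST digit and a patch size of 7 (an illustrative assumption, not necessarily mnistvit's default), that gives (28/7)² = 16 patches of 7 × 7 = 49 pixels each. Each flattened patch is then typically projected linearly to an embedding before the transformer layers. The stdlib sketch below shows only the patch-splitting step and is not mnistvit's actual code.

```python
# Minimal sketch of ViT-style patch extraction for a 28x28 MNIST image.
# PATCH_SIZE = 7 is a hypothetical choice, not necessarily mnistvit's default.
from typing import List

IMAGE_SIZE = 28
PATCH_SIZE = 7  # must divide IMAGE_SIZE evenly


def image_to_patches(image: List[List[float]], patch_size: int = PATCH_SIZE) -> List[List[float]]:
    """Split a square image into non-overlapping patches, each flattened row-major.

    Patches come back in raster order (left to right, top to bottom),
    the usual token order before position embeddings are added.
    """
    size = len(image)
    if any(len(row) != size for row in image):
        raise ValueError("image must be square")
    if size % patch_size != 0:
        raise ValueError("patch_size must divide the image size")
    patches = []
    for top in range(0, size, patch_size):
        for left in range(0, size, patch_size):
            patch = [
                image[r][c]
                for r in range(top, top + patch_size)
                for c in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


if __name__ == "__main__":
    # Dummy image whose pixel value encodes its (row, column) position.
    img = [[r * IMAGE_SIZE + c for c in range(IMAGE_SIZE)] for r in range(IMAGE_SIZE)]
    tokens = image_to_patches(img)
    print(len(tokens), len(tokens[0]))  # 16 49
```

In a real ViT this splitting is usually fused with the linear projection (for example as a strided convolution), but the resulting token sequence is the same.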
Requirements
The package requires Python 3.10 or later together with the torch and torchvision packages. Hyperparameter optimization additionally requires ray[tune] and optuna. The ViT implementation itself requires only torch.
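To see which of these dependencies are available in the current environment, a quick stdlib check works. This is a sketch, not part of mnistvit; the helper name `check_modules` is hypothetical, and finding the top-level `ray` module does not by itself confirm that the `tune` extra's dependencies are installed.

```python
# Sketch: report whether the README's dependencies are importable.
import importlib.util
import sys
from typing import Dict, Iterable

REQUIRED = ["torch", "torchvision"]
# ray[tune] installs under the top-level name "ray"; this check cannot
# tell whether the "tune" extra's own dependencies are present.
TUNING = ["ray", "optuna"]


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    print("Python >= 3.10:", sys.version_info >= (3, 10))
    for group, names in (("training", REQUIRED), ("tuning", TUNING)):
        status = check_modules(names)
        missing = [n for n, ok in status.items() if not ok]
        print(f"{group}: " + ("OK" if not missing else "missing " + ", ".join(missing)))
```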
Installation
To install the mnistvit package from PyPI, run:
pip install mnistvit
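To confirm the installation from Python, the standard library's `importlib.metadata` can report the installed version. A minimal sketch; the helper name `installed_version` is hypothetical.

```python
# Sketch: check whether mnistvit is installed and report its version.
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_version("mnistvit")
    print(f"mnistvit {v}" if v else "mnistvit is not installed")
```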
Usage
To train a model with default parameters:
python -m mnistvit.train
The script writes the model configuration to config.json and the trained model to model.pt. Run it with -h for a list of options.
To evaluate the test set accuracy of the model stored in model.pt with the configuration in config.json:
python -m mnistvit.predict --use-accuracy
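Test set accuracy is the fraction of test images whose predicted class equals the label. MNIST's test set has 10,000 images, so 99.65% accuracy corresponds to 35 misclassified digits. The stdlib sketch below illustrates the metric; it is not the package's evaluation code.

```python
# Sketch of the accuracy metric (not mnistvit's implementation).
from typing import Sequence


def accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predictions that equal the corresponding label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy of an empty set")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


if __name__ == "__main__":
    # 99.65% of MNIST's 10,000 test images leaves 35 errors.
    n_test = 10_000
    print(round(n_test * (1 - 0.9965)))  # 35
    print(accuracy([3, 1, 4, 1], [3, 1, 4, 7]))  # 0.75
```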
To predict the class of the digit stored in the file sample.jpg:
python -m mnistvit.predict --image-file sample.jpg
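A classifier like this outputs one score (logit) per digit class. The predicted class is the index of the largest score, and a softmax converts the scores into probabilities. The README does not say whether mnistvit.predict reports probabilities; the sketch below only illustrates the standard softmax/argmax step, using made-up logits.

```python
# Sketch of turning 10 class logits into a predicted digit (illustrative only).
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_digit(logits: Sequence[float]) -> Tuple[int, float]:
    """Return (predicted class, its probability) for 10 digit logits."""
    if len(logits) != 10:
        raise ValueError("expected one logit per digit class 0-9")
    probs = softmax(logits)
    digit = max(range(10), key=probs.__getitem__)
    return digit, probs[digit]


if __name__ == "__main__":
    # Hypothetical logits for one image; class 7 has the largest score.
    example = [0.1, -1.2, 0.3, 0.0, -0.5, 0.2, -2.0, 4.1, 0.4, -0.3]
    digit, p = predict_digit(example)
    print(digit, round(p, 3))
```

Since softmax is monotonic, the argmax of the probabilities equals the argmax of the raw logits; the softmax only matters when a confidence value is wanted.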
To run hyperparameter optimization with the default search parameters:
python -m mnistvit.tune
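mnistvit.tune relies on ray[tune] and optuna, and its search space is not described here. To illustrate the basic idea of hyperparameter search without those libraries, the sketch below runs a plain random search, a simpler technique than what optuna typically uses. It samples a hypothetical search space and scores each configuration with a toy objective standing in for validation accuracy. None of the parameter names or ranges come from the package.

```python
# Random-search sketch; search space and objective are hypothetical.
import math
import random
from typing import Callable, Dict, Tuple

SEARCH_SPACE = {
    "lr": ("log_uniform", 1e-4, 1e-2),
    "depth": ("int", 2, 8),
    "dropout": ("uniform", 0.0, 0.3),
}


def sample(space: Dict[str, tuple], rng: random.Random) -> Dict[str, float]:
    """Draw one configuration from the search space."""
    config = {}
    for name, (kind, lo, hi) in space.items():
        if kind == "log_uniform":
            config[name] = math.exp(rng.uniform(math.log(lo), math.log(hi)))
        elif kind == "int":
            config[name] = rng.randint(lo, hi)  # inclusive bounds
        elif kind == "uniform":
            config[name] = rng.uniform(lo, hi)
        else:
            raise ValueError(f"unknown distribution {kind!r}")
    return config


def random_search(objective: Callable[[Dict[str, float]], float],
                  space: Dict[str, tuple], n_trials: int,
                  seed: int = 0) -> Tuple[Dict[str, float], float]:
    """Evaluate n_trials random configurations; return the best one and its score."""
    rng = random.Random(seed)
    best_config, best_score = None, -math.inf
    for _ in range(n_trials):
        config = sample(space, rng)
        score = objective(config)  # higher is better, e.g. validation accuracy
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


def toy_objective(config: Dict[str, float]) -> float:
    """Stand-in for 'train a model and return validation accuracy'."""
    return (-abs(math.log10(config["lr"]) + 3)
            - 0.05 * abs(config["depth"] - 6)
            - config["dropout"])


if __name__ == "__main__":
    best, score = random_search(toy_objective, SEARCH_SPACE, n_trials=50)
    print(best, round(score, 3))
```

Sampling the learning rate log-uniformly spreads trials evenly across orders of magnitude, which is the usual choice for scale-type parameters.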
License
mnistvit is released under the GPLv3 license; see the LICENSE file.
File details
Details for the file mnistvit-1.3.1.tar.gz.
File metadata
- Download URL: mnistvit-1.3.1.tar.gz
- Upload date:
- Size: 24.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a6df414974d819bb9ac100a9a6a87746c335059ff8321c1a3296414a7ec2ee7e |
| MD5 | d79d2694019ceda429a87d306d80a529 |
| BLAKE2b-256 | fadf8db335ae75933ef00d05e4903490db543fb6ca9359b6a9f88cc25b6880bc |
File details
Details for the file mnistvit-1.3.1-py3-none-any.whl.
File metadata
- Download URL: mnistvit-1.3.1-py3-none-any.whl
- Upload date:
- Size: 26.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d3b7d24cafcefc1ca45b3ff11e75203a0bf0878a1d18cb26e1b54df9d99687dc |
| MD5 | 17f7dab00e51eeea2f1023d8df2cae42 |
| BLAKE2b-256 | c960517d8fea322864f767a24beec389a4206898866259fb02a094360881859c |
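To check that a downloaded file matches the SHA256 digest listed above, hash it locally and compare. A stdlib sketch; the path in the usage block is simply the source distribution's file name and assumes the file sits in the current directory.

```python
# Sketch: verify a downloaded distribution against its published SHA256 digest.
import hashlib
from pathlib import Path

# SHA256 digest of mnistvit-1.3.1.tar.gz as listed above.
EXPECTED_SDIST_SHA256 = "a6df414974d819bb9ac100a9a6a87746c335059ff8321c1a3296414a7ec2ee7e"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Compute the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path, expected_hex: str) -> bool:
    """Return True if the file's SHA256 digest matches expected_hex (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()


if __name__ == "__main__":
    sdist = Path("mnistvit-1.3.1.tar.gz")
    if sdist.exists():
        print("OK" if verify(sdist, EXPECTED_SDIST_SHA256) else "MISMATCH")
    else:
        print(f"{sdist} not found")
```

pip can enforce the same check at install time through its hash-checking mode, by listing the digest with `--hash=sha256:...` next to the pinned requirement in a requirements file.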