A flexible Vision Transformer (ViT) model for your needs.
Project description
Flexible Vision Transformer
A flexible PyTorch implementation of the Vision Transformer (ViT) model for image classification tasks, inspired by the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" by Dosovitskiy et al.
Overview
This repository provides a modular and customizable Vision Transformer (ViT) model that adapts the Transformer architecture for image classification. By treating an image as a sequence of patches, the model leverages self-attention mechanisms to capture global contextual relationships within the image.
Features
- Patch Embedding: Divides images into fixed-size patches and embeds them.
- Positional Embedding: Adds positional information to patch embeddings to retain spatial structure.
- Transformer Encoder Blocks: Utilizes multi-head self-attention and feed-forward networks with residual connections and layer normalization.
- Classification Head: Outputs class probabilities from the encoded features.
- Configurable Parameters: Easily adjust model dimensions, number of layers, attention heads, and more.
- Checkpointing: Save and load model checkpoints during training.
- Visualization: Utility functions to visualize image samples (a minimal sketch follows this list).
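The visualization utilities themselves are not documented in this README; the function below is a minimal, hypothetical sketch of such a helper (show_samples is an assumed name, not the package's actual API).
import matplotlib.pyplot as plt
import torch

def show_samples(images: torch.Tensor, n: int = 8) -> None:
    """Plot the first n images of a (B, C, H, W) tensor in a single row."""
    n = min(n, images.shape[0])
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2), squeeze=False)
    for ax, img in zip(axes[0], images[:n]):
        # Move channels last and clamp to [0, 1] for display.
        ax.imshow(img.permute(1, 2, 0).clamp(0, 1).cpu().numpy())
        ax.axis("off")
    plt.show()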
Installation
Clone the Repository and Install Dependencies
git clone https://github.com/T4ras123/Flexible-ViT.git
cd Flexible-ViT
pip install -r requirements.txt
Install via PyPI
pip install vision-transformer
Usage
Training the Model
Train the ViT model using the provided train.py script with default parameters:
python train.py --data_path /path/to/dataset --epochs 100
Customizing Training Parameters
You can customize the training process by providing additional command-line arguments:
python train.py \
--data_path ./data \
--epochs 200 \
--learning_rate 0.0005 \
--batch_size 64 \
--image_size 224 \
--patch_size 16 \
--emb_dim 768 \
--n_layers 12 \
--heads 12 \
--dropout 0.1
Available Arguments
- --data_path: Path to the dataset.
- --epochs: Number of training epochs.
- --learning_rate: Learning rate for the optimizer.
- --batch_size: Number of samples per batch.
- --image_size: Dimension of input images (default: 144).
- --patch_size: Size of each image patch (default: 4).
- --emb_dim: Embedding dimension (default: 32).
- --n_layers: Number of Transformer encoder layers (default: 6).
- --heads: Number of attention heads (default: 2).
- --dropout: Dropout rate (default: 0.1).
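For reference, the arguments above map naturally onto argparse. The sketch below uses the documented defaults where they are given and placeholder values (marked in comments) where they are not; it is illustrative and may not match the exact contents of train.py.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Train a Vision Transformer")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset")
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")  # placeholder default
    parser.add_argument("--learning_rate", type=float, default=0.0005, help="Optimizer learning rate")  # placeholder default
    parser.add_argument("--batch_size", type=int, default=64, help="Samples per batch")  # placeholder default
    parser.add_argument("--image_size", type=int, default=144, help="Dimension of input images")
    parser.add_argument("--patch_size", type=int, default=4, help="Size of each image patch")
    parser.add_argument("--emb_dim", type=int, default=32, help="Embedding dimension")
    parser.add_argument("--n_layers", type=int, default=6, help="Number of Transformer encoder layers")
    parser.add_argument("--heads", type=int, default=2, help="Number of attention heads")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate")
    return parser.parse_args()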
Loading a Saved Model
Load a previously saved model checkpoint:
import torch
from ViT.train import ViT, load_model
import torch.optim as optim
model = ViT(
    ch=3,
    img_size=224,
    patch_size=16,
    emb_dim=768,
    n_layers=12,
    out_dim=1000,
    dropout=0.1,
    heads=12
).to('cuda')
optimizer = optim.AdamW(model.parameters(), lr=0.0005)
epoch, loss = load_model(model, optimizer, 'ViT/models/vit_checkpoint.pt')
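The checkpoint layout is not documented here. Since load_model restores both the model and the optimizer and returns an (epoch, loss) pair, it plausibly wraps something like the sketch below; the field names are assumptions, not the package's actual format.
import torch

def save_checkpoint(model, optimizer, epoch, loss, path):
    # Assumed layout: a single dict holding model/optimizer state plus training progress.
    torch.save({
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["loss"]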
Evaluating the Model
Evaluate the trained model on the test dataset:
python evaluate.py --data_path /path/to/dataset --model_path ViT/models/vit_checkpoint.pt
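The internals of evaluate.py are not shown in this README; a typical top-1 accuracy loop for a classifier like this looks roughly like the following sketch (the data loader is assumed to yield (images, labels) batches).
import torch

@torch.no_grad()
def evaluate(model, loader, device="cuda"):
    """Return top-1 accuracy of `model` over `loader`."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total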
Model Architecture
The Vision Transformer model consists of the following components:
- Patch Embedding: Converts input images into a sequence of flattened patch embeddings (see the worked sketch after this list).
- Positional Embedding: Adds positional information to each patch embedding.
- Transformer Encoder Blocks: Comprises layers of multi-head self-attention and feed-forward networks with residual connections and layer normalization.
- Classification Head: Maps the encoded features to output class probabilities.
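To make the patch arithmetic concrete: a 224x224 image split into 16x16 patches yields (224/16)^2 = 196 patches, and prepending a class token, as in the original ViT paper, gives a sequence of 197 embeddings. The snippet below sketches these first two stages with einops; it is illustrative and does not reproduce the repository's PatchEmbedding class.
import torch
import torch.nn as nn
from einops import rearrange

img = torch.randn(1, 3, 224, 224)
patch_size, emb_dim = 16, 768

# Split into 14x14 = 196 patches, flatten each 16x16x3 patch, then project to emb_dim.
patches = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size)
tokens = nn.Linear(patch_size * patch_size * 3, emb_dim)(patches)  # (1, 196, 768)

# Prepend a class token and add learnable positional embeddings (standard ViT practice;
# the repository's exact layout may differ).
cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim)).expand(tokens.shape[0], -1, -1)
tokens = torch.cat([cls_token, tokens], dim=1)
tokens = tokens + nn.Parameter(torch.zeros(1, tokens.shape[1], emb_dim))
print(tokens.shape)  # torch.Size([1, 197, 768])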
Key Components
- PatchEmbedding: Splits the image into patches and projects them into an embedding space.
- Attention: Implements the multi-head self-attention mechanism.
- FeedForward: A two-layer fully connected network with GELU activation and dropout.
- Block: Combines attention and feed-forward layers with layer normalization and residual connections (a minimal sketch follows this list).
- ViT: The main Vision Transformer model class that assembles all components.
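As an illustration of what Block combines, here is a minimal pre-norm encoder block that uses torch.nn.MultiheadAttention as a stand-in for the repository's own Attention class. The actual implementation may differ in details such as normalization placement and the MLP expansion ratio.
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Illustrative Transformer encoder block (not the repository's Block class)."""
    def __init__(self, emb_dim=768, heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, mlp_ratio * emb_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_ratio * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention and feed-forward sub-layers, each wrapped in a residual connection.
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x

tokens = torch.randn(1, 197, 768)
print(EncoderBlock()(tokens).shape)  # torch.Size([1, 197, 768])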
Example Code
import torch
from ViT.train import ViT
model = ViT(
    ch=3,
    img_size=224,
    patch_size=16,
    emb_dim=768,
    n_layers=12,
    out_dim=1000,
    dropout=0.1,
    heads=12
)
inputs = torch.randn(1, 3, 224, 224)
outputs = model(inputs)
print(outputs.shape) # torch.Size([1, 1000])
Requirements
- Python ≥ 3.8
- PyTorch
- torchvision
- einops
- matplotlib
- numpy
Install Dependencies
pip install -r requirements.txt
License
This project is licensed under the MIT License - see the LICENSE file for details.
References
- Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 2021.
- Vaswani et al., "Attention Is All You Need", 2017.
Citation
If you use this implementation in your research, please cite:
@misc{vision-transformer,
  author = {vover},
  title = {Flexible Vision Transformer Implementation},
  year = {2024},
  publisher = {vover},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/T4ras123/Flexible-ViT}},
}
File details
Details for the file vision_transformer-1.0.0.tar.gz.
File metadata
- Download URL: vision_transformer-1.0.0.tar.gz
- Upload date:
- Size: 6.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0f94c3553be06244469d1ac85e384ab5824ac7f9644d980a5c79d505b36e772d
MD5 | d56929b04f403fdd59eece22abf65aa4
BLAKE2b-256 | 68fdb56c4606da8b6afee3b3ce0f23aa410f79478e0418ae323be9603234ab09
File details
Details for the file vision_transformer-1.0.0-py3-none-any.whl.
File metadata
- Download URL: vision_transformer-1.0.0-py3-none-any.whl
- Upload date:
- Size: 6.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0a6db14c47af5257af98c6d3f58bd8d7601f8aa38132956e1b29db5a2f3e246b
MD5 | 17dd8a95a5af2f7cb29b8732e7f3bffa
BLAKE2b-256 | a51c3b5751ac464e76b81f4b68f2234b048e3faeb56560c1da5c77289a905867