Vitax: Vision Transformers in Flax
VITAX: An open-source platform for training and inference of Vision Transformers (ViT) with the new and elegant Flax NNX API.
This library provides a clean, from-scratch implementation of the Vision Transformer model and makes it easy to leverage powerful pretrained models from the Hugging Face Hub for your own computer vision tasks.
Core Features
- Modern Flax API: Built entirely using `flax.nnx`, offering a more intuitive, object-oriented, and explicit way to build neural networks in JAX.
- Hugging Face Integration: Seamlessly load pretrained ViT weights from `google/vit-*` models for transfer learning and fine-tuning.
- Custom Models: Easily create and train Vision Transformer models from scratch with custom configurations.
- Simple & Efficient Training: Includes a straightforward, JIT-compiled training and evaluation pipeline using `optax` for optimization.
- Modular Design: The code is well-structured, separating the model definition, weight loading, and training logic for clarity and extensibility.
Installation
You can install vitax directly from PyPI:
```bash
pip install vitax
```
Model creation with Vitax
Vitax offers several ways to create a Vision Transformer:
Load a Pretrained Model for Fine-Tuning
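This is the most common use case. You can load a model pretrained on ImageNet-21k and adapt its final classification layer to your own dataset (e.g., CIFAR-100 with 100 classes). Note that the `google/vit-base-patch16-224` checkpoint was pretrained on ImageNet-21k and then fine-tuned on ImageNet-1k.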
```python
from vitax.models import create_model

# Load a base pretrained ViT model and adapt it for 100 classes
model = create_model(
    'google/vit-base-patch16-224',
    num_classes=100,
    pretrained=True
)
```
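Conceptually, adapting a pretrained model means keeping the backbone weights and swapping the classification head for a freshly initialized layer sized for the new label set. The NumPy sketch below illustrates only that idea on a hypothetical flat parameter dictionary; the key names (`classifier/kernel`, `classifier/bias`) and the `replace_head` helper are assumptions for illustration, not Vitax's internal structure. ViT-Base has a hidden size of 768, and this checkpoint's original head has 1000 ImageNet-1k outputs, so the new head maps 768 features to 100 classes.

```python
import numpy as np

def replace_head(params, hidden_size, num_classes, seed=0):
    """Return a copy of `params` with a freshly initialized classification head.

    `params` is a hypothetical flat dict of name -> array. Only the
    'classifier/*' entries are replaced; all backbone arrays are reused as-is.
    """
    rng = np.random.default_rng(seed)
    new_params = dict(params)  # shallow copy: backbone arrays are shared, not modified
    # Small random kernel and zero bias for the new (hidden_size -> num_classes) layer
    new_params["classifier/kernel"] = rng.normal(
        0.0, 0.02, size=(hidden_size, num_classes)
    ).astype(np.float32)
    new_params["classifier/bias"] = np.zeros((num_classes,), dtype=np.float32)
    return new_params

# Hypothetical pretrained parameters: one backbone entry plus a 1000-class head
pretrained = {
    "encoder/layer0/attention/query/kernel": np.ones((768, 768), dtype=np.float32),
    "classifier/kernel": np.ones((768, 1000), dtype=np.float32),
    "classifier/bias": np.ones((1000,), dtype=np.float32),
}

finetune = replace_head(pretrained, hidden_size=768, num_classes=100)
print(finetune["classifier/kernel"].shape)  # (768, 100)
```

Everything except the head keeps its pretrained values, which is what makes fine-tuning on a small dataset practical.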
Create a Model from Scratch (Random Weights)
If you want to train a model from the ground up, you can create one with random weights. You can either use a standard configuration or define your own.
Using a standard model configuration:
```python
from vitax.models import create_model

# Create a 'vit-base-patch16-224' architecture with random weights
model = create_model(
    'google/vit-base-patch16-224',
    num_classes=10,  # For a 10-class dataset like CIFAR-10
    pretrained=False
)
```
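The name `vit-base-patch16-224` encodes the input resolution (224×224) and the patch size (16×16). The NumPy sketch below shows the patch-splitting step a ViT performs before its linear patch embedding: a 224×224×3 image becomes (224/16)² = 196 patches of 16·16·3 = 768 values each, and prepending the [CLS] token gives a sequence length of 197. The `patchify` function is a hypothetical illustration, not part of the Vitax API; real implementations typically fuse this step with the projection as a strided convolution.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into non-overlapping, flattened patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    with patches ordered row by row across the image.
    """
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch size"
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, C) -> (gh, gw, p, p, C): group pixels by the patch they belong to
    patches = image.reshape(gh, patch_size, gw, patch_size, c).transpose(0, 2, 1, 3, 4)
    # Flatten each patch in (row, column, channel) order
    return patches.reshape(gh * gw, patch_size * patch_size * c)

image = np.random.default_rng(0).random((224, 224, 3), dtype=np.float32)
patches = patchify(image, patch_size=16)
num_patches = patches.shape[0]
seq_len = num_patches + 1  # +1 for the [CLS] token
print(patches.shape, seq_len)  # (196, 768) 197
```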
Using a fully custom architecture:
```python
from vitax.models import create_model

# Define a custom configuration for a smaller model, compatible with HuggingFace ViT config
custom_config = {
    'image_size': 224,
    'patch_size': 16,
    'num_hidden_layers': 6,    # Fewer layers
    'num_attention_heads': 8,  # Fewer attention heads
    'intermediate_size': 500,  # Smaller MLP dimension
    'hidden_size': 128,        # Embedding dimension
}

# Create the custom model with random weights
custom_model = create_model(
    name_or_config=custom_config,
    num_classes=10,
    pretrained=False
)
```
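Before training a custom architecture, it can help to sanity-check the configuration and estimate its size. The sketch below is a hypothetical helper, not part of Vitax. It checks the two divisibility constraints a ViT needs (`hidden_size` divisible by `num_attention_heads`, `image_size` divisible by `patch_size`) and counts parameters assuming the standard layout of Hugging Face's `ViTForImageClassification`: a biased patch projection, a [CLS] token, learned position embeddings, encoder layers with biased Q/K/V/output projections, a two-layer MLP and two LayerNorms, a final LayerNorm, and a linear head with no pooler. Vitax's implementation may differ, so treat the result as an estimate. For the base configuration with a 1000-class head it gives 86,567,656, the size usually reported for ViT-B/16.

```python
def count_vit_params(config, num_classes, num_channels=3):
    """Estimate the parameter count of a standard ViT classifier (hypothetical helper)."""
    h = config["hidden_size"]
    heads = config["num_attention_heads"]
    p = config["patch_size"]
    img = config["image_size"]
    mlp = config["intermediate_size"]
    layers = config["num_hidden_layers"]

    # Structural constraints of the architecture
    if h % heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    if img % p != 0:
        raise ValueError("image_size must be divisible by patch_size")

    num_patches = (img // p) ** 2
    # Patch projection (weights + bias), [CLS] token, position embeddings for patches + [CLS]
    embeddings = (p * p * num_channels) * h + h + h + (num_patches + 1) * h
    # Q, K, V and output projections, each h x h with a bias
    attention = 4 * (h * h + h)
    # Two-layer MLP: h -> mlp -> h, with biases
    mlp_block = (h * mlp + mlp) + (mlp * h + h)
    # Two LayerNorms per layer, each with a scale and a bias vector
    norms = 2 * (2 * h)
    encoder = layers * (attention + mlp_block + norms)
    final_norm = 2 * h
    head = h * num_classes + num_classes
    return embeddings + encoder + final_norm + head

custom_config = {
    'image_size': 224, 'patch_size': 16, 'num_hidden_layers': 6,
    'num_attention_heads': 8, 'intermediate_size': 500, 'hidden_size': 128,
}
base_config = {
    'image_size': 224, 'patch_size': 16, 'num_hidden_layers': 12,
    'num_attention_heads': 12, 'intermediate_size': 3072, 'hidden_size': 768,
}

print(count_vit_params(custom_config, num_classes=10))  # 1296450
print(count_vit_params(base_config, num_classes=1000))  # 86567656
```

Under these assumptions the custom configuration has roughly 1.3M parameters, about 67 times fewer than ViT-Base, which makes it a practical starting point for small datasets such as CIFAR-10.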
Contributing
Contributions are welcome! If you find a bug or have a feature request, please open an issue on the GitHub repository.
License
This project is licensed under the MIT License. See the LICENSE file for details.
Download files
Source Distribution
Built Distribution
File details
Details for the file vitax-0.0.10.tar.gz.
File metadata
- Download URL: vitax-0.0.10.tar.gz
- Upload date:
- Size: 10.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 67bb6559d230e4ae7f5b1be050b9a956dfa1de7a683c8a37ab5a49bbfbfe6b16 |
| MD5 | c67faba0f2eac322c65ea553a21d868b |
| BLAKE2b-256 | 8c2fbcdf9a920e7d6a8608e5b51d320bb3b48e0629a0bdc29dce6c5db2eedf9e |
File details
Details for the file vitax-0.0.10-py3-none-any.whl.
File metadata
- Download URL: vitax-0.0.10-py3-none-any.whl
- Upload date:
- Size: 10.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e5c0d43f291b83d771c0a8eb8a2dccfa4e25321f7eafa50a20a347d69e322b9e |
| MD5 | aac4ec4abd44ec2152810856ab10f66c |
| BLAKE2b-256 | ab7038b6fdc7a7b9f2b9f848e43b533ee5e8bf3197156eb232171ceb10db0762 |
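To confirm that a downloaded distribution matches the SHA256 digests listed above, you can hash it locally and compare. The sketch below uses only Python's standard `hashlib`; the helper names `sha256_of_file` and `matches_published_digest` are illustrative, not part of Vitax. Alternatively, pip can enforce digests itself: a requirements line such as `vitax==0.0.10 --hash=sha256:<digest>` combined with `pip install --require-hashes -r requirements.txt` rejects any file whose hash does not match.

```python
import hashlib
import os

# SHA256 digests published for vitax 0.0.10 (copied from the file details above)
EXPECTED_SHA256 = {
    "vitax-0.0.10.tar.gz": "67bb6559d230e4ae7f5b1be050b9a956dfa1de7a683c8a37ab5a49bbfbfe6b16",
    "vitax-0.0.10-py3-none-any.whl": "e5c0d43f291b83d771c0a8eb8a2dccfa4e25321f7eafa50a20a347d69e322b9e",
}

def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_published_digest(path):
    """True if the file's SHA256 equals the published digest for its file name."""
    expected = EXPECTED_SHA256[os.path.basename(path)]
    return sha256_of_file(path) == expected

# Usage, after e.g. `pip download vitax==0.0.10 --no-deps`:
# matches_published_digest("vitax-0.0.10-py3-none-any.whl")
```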