
Vitax: Vision Transformers in Flax


VITAX


VITAX: An open-source platform for training and inference of Vision Transformers (ViT) with the new and elegant Flax NNX API.

This library provides a clean, from-scratch implementation of the Vision Transformer model and makes it easy to leverage powerful pretrained models from the Hugging Face Hub for your own computer vision tasks.
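At its core, a Vision Transformer cuts the image into fixed-size, non-overlapping patches and flattens each patch into one token before the transformer layers run. The NumPy sketch below shows that step for a 224×224 RGB image with 16×16 patches, the geometry of google/vit-base-patch16-224. It illustrates the technique and is not Vitax's internal code. It assumes channels-last (H, W, C) input, and `patchify` is a hypothetical helper. In the model, each flattened patch is then linearly projected to `hidden_size`.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Hypothetical helper: split an (H, W, C) image into flattened patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    with patches ordered left-to-right, top-to-bottom.
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, C) -> (gh, gw, p, p, C): group pixels by patch
    patches = image.reshape(gh, patch_size, gw, patch_size, c).transpose(0, 2, 1, 3, 4)
    # One row per patch
    return patches.reshape(gh * gw, patch_size * patch_size * c)

image = np.random.rand(224, 224, 3).astype(np.float32)
tokens = patchify(image, patch_size=16)
print(tokens.shape)  # (196, 768): 14 * 14 patches, each 16 * 16 * 3 values
```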

Core Features

  • Modern Flax API: Built entirely using flax.nnx, offering a more intuitive, object-oriented, and explicit way to build neural networks in JAX.
  • Hugging Face Integration: Seamlessly load pretrained ViT weights from google/vit-* models for transfer learning and fine-tuning.
  • Custom Models: Easily create and train Vision Transformer models from scratch with custom configurations.
  • Simple & Efficient Training: Includes a straightforward, JIT-compiled training and evaluation pipeline that uses optax for optimization.
  • Modular Design: The code is well-structured, separating the model definition, weight loading, and training logic for clarity and extensibility.

Installation

You can install vitax directly from PyPI:

pip install vitax

Model creation with Vitax

Vitax offers several ways to create a Vision Transformer:

Load a Pretrained Model for Fine-Tuning

This is the most common use case: load a model pretrained on ImageNet-21k (google/vit-base-patch16-224 was additionally fine-tuned on ImageNet-1k) and adapt its final classification layer to your dataset (e.g., CIFAR-100 with 100 classes).

from vitax.models import create_model

# Load a base pretrained ViT model and adapt it for 100 classes
model = create_model(
    'google/vit-base-patch16-224',
    num_classes=100,
    pretrained=True
)
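Adapting the final classification layer means that the pretrained backbone is reused and only the head changes shape. For example, the 1,000-class ImageNet-1k head of google/vit-base-patch16-224 becomes a `hidden_size × num_classes` layer. The NumPy sketch below illustrates the idea under two assumptions. First, the head is a single dense layer applied to the [CLS] token embedding. Second, the new head is zero-initialized, following the original ViT paper's fine-tuning recipe. Vitax may initialize it differently, and `new_classifier_head` is a hypothetical helper.

```python
import numpy as np

HIDDEN_SIZE = 768  # embedding width of ViT-Base

def new_classifier_head(hidden_size: int, num_classes: int):
    """Hypothetical helper: a fresh dense head (weights, bias) for fine-tuning.

    Zero initialization is an assumption taken from the ViT paper's recipe,
    not necessarily what Vitax does.
    """
    w = np.zeros((hidden_size, num_classes), dtype=np.float32)
    b = np.zeros(num_classes, dtype=np.float32)
    return w, b

w, b = new_classifier_head(HIDDEN_SIZE, num_classes=100)

# Stand-in for the backbone's [CLS] embeddings for a batch of 4 images
cls_embeddings = np.random.rand(4, HIDDEN_SIZE).astype(np.float32)
logits = cls_embeddings @ w + b
print(logits.shape)     # (4, 100)
print(w.size + b.size)  # 76900 new head parameters (768 * 100 + 100)
```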

Create a Model from Scratch (Random Weights)

If you want to train a model from the ground up, you can create one with random weights. You can either use a standard configuration or define your own.

Using a standard model configuration:

from vitax.models import create_model

# Create a 'vit-base-patch16-224' architecture with random weights
model = create_model(
    'google/vit-base-patch16-224',
    num_classes=10, # For a 10-class dataset like CIFAR-10
    pretrained=False
)
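For reference, the standard 'google/vit-base-patch16-224' architecture uses the ViT-Base hyperparameters below, which are also the defaults of Hugging Face's `ViTConfig`. The worked count sketches how many weights a randomly initialized model of this size has. It assumes a standard ViT with biased linear layers, a learned [CLS] token and position embeddings, a final LayerNorm, no pooler, and a dense head. Vitax's exact layer set may differ slightly, and `count_vit_params` is a hypothetical helper, not part of Vitax.

```python
def count_vit_params(cfg: dict, num_classes: int, num_channels: int = 3) -> int:
    """Hypothetical helper: parameter count of a standard ViT classifier.

    num_attention_heads does not change the count, because the heads
    split hidden_size rather than adding new weights.
    """
    d, p, m = cfg["hidden_size"], cfg["patch_size"], cfg["intermediate_size"]
    num_patches = (cfg["image_size"] // p) ** 2

    patch_embed = p * p * num_channels * d + d  # patch projection + bias
    cls_token = d
    pos_embed = (num_patches + 1) * d           # +1 position for [CLS]
    per_layer = (
        2 * d                                   # LayerNorm before attention
        + 4 * (d * d + d)                       # Q, K, V and output projections
        + 2 * d                                 # LayerNorm before the MLP
        + (d * m + m) + (m * d + d)             # two-layer MLP
    )
    final_norm = 2 * d
    head = d * num_classes + num_classes
    return (patch_embed + cls_token + pos_embed
            + cfg["num_hidden_layers"] * per_layer + final_norm + head)

# ViT-Base/16 hyperparameters (the Hugging Face ViTConfig defaults)
vit_base_config = {
    'image_size': 224,
    'patch_size': 16,
    'num_hidden_layers': 12,
    'num_attention_heads': 12,
    'intermediate_size': 3072,
    'hidden_size': 768,
}

print(count_vit_params(vit_base_config, num_classes=10))  # 85806346, about 86M
```

Nearly all of these parameters sit in the 12 transformer layers (about 7.1M each), which is why shrinking `num_hidden_layers`, `hidden_size`, and `intermediate_size` is the main lever for a smaller custom model.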

Using a fully custom architecture:

from vitax.models import create_model

# Define a custom configuration for a smaller model, compatible with the Hugging Face ViT config
custom_config = {
    'image_size': 224,
    'patch_size': 16,
    'num_hidden_layers': 6,       # Fewer layers (ViT-Base uses 12)
    'num_attention_heads': 8,     # Fewer attention heads (ViT-Base uses 12)
    'intermediate_size': 500,     # Smaller MLP dimension (ViT-Base uses 3072)
    'hidden_size': 128,           # Embedding dimension (ViT-Base uses 768)
}

# Create the custom model with random weights
custom_model = create_model(
    name_or_config=custom_config,
    num_classes=10,
    pretrained=False
)
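A custom configuration has to satisfy a few constraints of the standard ViT architecture. `image_size` must be divisible by `patch_size` so that the patches tile the image, and `hidden_size` must be divisible by `num_attention_heads` because each head works on `hidden_size / num_attention_heads` dimensions. The check below is a sketch with a hypothetical helper, applied to the custom configuration from the example. It assumes a [CLS] token as in the standard ViT and does not claim that Vitax performs these checks itself.

```python
def describe_vit_config(cfg: dict) -> dict:
    """Hypothetical helper: validate a ViT config and derive a few sizes."""
    if cfg["image_size"] % cfg["patch_size"] != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if cfg["hidden_size"] % cfg["num_attention_heads"] != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    num_patches = (cfg["image_size"] // cfg["patch_size"]) ** 2
    return {
        "num_patches": num_patches,
        "sequence_length": num_patches + 1,  # patches plus the [CLS] token
        "head_dim": cfg["hidden_size"] // cfg["num_attention_heads"],
    }

# Same values as the custom configuration above
custom_config = {
    'image_size': 224,
    'patch_size': 16,
    'num_hidden_layers': 6,
    'num_attention_heads': 8,
    'intermediate_size': 500,
    'hidden_size': 128,
}

print(describe_vit_config(custom_config))
# {'num_patches': 196, 'sequence_length': 197, 'head_dim': 16}
```

With `hidden_size` 128 and 8 heads, each head is only 16 dimensions wide. Setting `num_attention_heads` to 6 with the same `hidden_size` would fail the check, because 128 is not divisible by 6.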

Contributing

Contributions are welcome! If you find a bug or have a feature request, please open an issue on the GitHub repository.

License

This project is licensed under the MIT License. See the LICENSE file for details.


Download files


Source Distribution

vitax-0.0.10.tar.gz (10.2 kB)

Uploaded Source

Built Distribution


vitax-0.0.10-py3-none-any.whl (10.0 kB)

Uploaded Python 3

File details

Details for the file vitax-0.0.10.tar.gz.

File metadata

  • Download URL: vitax-0.0.10.tar.gz
  • Upload date:
  • Size: 10.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.3

File hashes

Hashes for vitax-0.0.10.tar.gz
Algorithm Hash digest
SHA256 67bb6559d230e4ae7f5b1be050b9a956dfa1de7a683c8a37ab5a49bbfbfe6b16
MD5 c67faba0f2eac322c65ea553a21d868b
BLAKE2b-256 8c2fbcdf9a920e7d6a8608e5b51d320bb3b48e0629a0bdc29dce6c5db2eedf9e
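The SHA256 digest above can be checked locally with Python's standard `hashlib` module. This minimal sketch assumes that the archive was saved under its original file name in the current directory. `sha256_of` is a hypothetical helper name.

```python
import hashlib
import os

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# SHA256 listed above for vitax-0.0.10.tar.gz
EXPECTED = "67bb6559d230e4ae7f5b1be050b9a956dfa1de7a683c8a37ab5a49bbfbfe6b16"
path = "vitax-0.0.10.tar.gz"  # assumed download location

if os.path.exists(path):
    actual = sha256_of(path)
    print("OK" if actual == EXPECTED else f"MISMATCH: {actual}")
```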


File details

Details for the file vitax-0.0.10-py3-none-any.whl.

File metadata

  • Download URL: vitax-0.0.10-py3-none-any.whl
  • Upload date:
  • Size: 10.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.3

File hashes

Hashes for vitax-0.0.10-py3-none-any.whl
Algorithm Hash digest
SHA256 e5c0d43f291b83d771c0a8eb8a2dccfa4e25321f7eafa50a20a347d69e322b9e
MD5 aac4ec4abd44ec2152810856ab10f66c
BLAKE2b-256 ab7038b6fdc7a7b9f2b9f848e43b533ee5e8bf3197156eb232171ceb10db0762

