Vision Transformer Zoo
A clean, extensible, and reusable factory for creating HuggingFace-based Vision Transformer models (ViT, DeiT, DINO, DINOv2, DINOv3, and CLIP Vision) with flexible heads, easy backbone freezing, attention weight extraction, and seamless PyTorch Lightning integration.
Features
- Easy model construction via build_model(...): create models in minutes
- Flexible head support: Linear, MLP, or custom heads
- Common interface for different ViT flavors from HuggingFace
- Backbone freezing - freeze all backbone parameters
- Attention weights - easy extraction of attention weights
- Embedding extraction - get embeddings without classification head
- PyTorch Lightning ready - works seamlessly with Lightning modules
- Model registry - easy extensibility for new models
Installation
Local development install
    git clone git@github.com:jbindaAI/vit_zoo.git
    cd vit_zoo
    pip install -e .
Quick Start
Example 1: Simple Classification (Lightning-ready)
    from vit_zoo.factory import build_model

    # Create a model with 10 classes, freeze backbone
    model = build_model("vanilla_vit", head=10, freeze_backbone=True)

    # Use in PyTorch Lightning
    import pytorch_lightning as pl

    class ViTLightningModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = build_model("vanilla_vit", head=10, freeze_backbone=True)

        def forward(self, x):
            return self.model(x)
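For end-to-end training, the module above still needs a training step and an optimizer. A minimal sketch using standard Lightning hooks; the loss, optimizer, and learning rate are illustrative choices, not part of vit_zoo:

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl
    from vit_zoo.factory import build_model

    class ViTClassifier(pl.LightningModule):
        def __init__(self, num_classes: int = 10, lr: float = 1e-3):
            super().__init__()
            self.model = build_model("vanilla_vit", head=num_classes, freeze_backbone=True)
            self.lr = lr

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            images, labels = batch
            loss = F.cross_entropy(self(images), labels)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            # With a frozen backbone, only head parameters require gradients
            trainable = [p for p in self.parameters() if p.requires_grad]
            return torch.optim.Adam(trainable, lr=self.lr)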
Example 2: Custom MLP Head
    from vit_zoo.factory import build_model
    from vit_zoo.components import MLPHead
    import torch.nn as nn

    # Create MLP head with string activation (most common)
    mlp_head = MLPHead(
        input_dim=768,  # Must match backbone embedding dimension
        hidden_dims=[512, 256],
        output_dim=100,
        dropout=0.1,
        activation="gelu"  # String literal: 'relu', 'gelu', or 'tanh'
    )

    # Or use custom nn.Module activation
    mlp_head_custom = MLPHead(
        input_dim=768,
        hidden_dims=[512, 256],
        output_dim=100,
        activation=nn.SiLU()  # Any PyTorch activation module
    )

    model = build_model("dino_v2_vit", head=mlp_head)
Example 3: Embedding Extraction Only
    from vit_zoo.factory import build_model

    # No head - just extract embeddings
    model = build_model("clip_vit", head=None)
    outputs = model(images, output_embeddings=True)
    embeddings = outputs["embeddings"]  # Shape: (batch_size, embedding_dim)
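Extracted embeddings plug directly into downstream tasks such as retrieval. A small sketch computing a cosine-similarity matrix with plain PyTorch (the random tensor stands in for a batch of real preprocessed images):

    import torch
    import torch.nn.functional as F
    from vit_zoo.factory import build_model

    model = build_model("clip_vit", head=None)
    images = torch.randn(4, 3, 224, 224)  # placeholder batch

    with torch.no_grad():
        emb = model(images, output_embeddings=True)["embeddings"]

    emb = F.normalize(emb, dim=-1)  # unit-length embeddings
    similarity = emb @ emb.T        # (4, 4) cosine-similarity matrix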
Example 4: Attention Weights
    from vit_zoo.factory import build_model

    # For attention weights, you may need to set attn_implementation='eager'
    model = build_model(
        "vanilla_vit",
        head=10,
        config_kwargs={"attn_implementation": "eager"}
    )
    outputs = model(images, output_attentions=True)
    predictions = outputs["predictions"]  # Shape: (batch_size, num_classes)
    attentions = outputs["attentions"]  # Tuple of attention tensors (may be None if not supported)
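A common next step is reducing the tuple to a per-patch saliency signal. This sketch assumes the usual HuggingFace attention shape of (batch, heads, tokens, tokens) with the CLS token at index 0; verify both for the backbone you use:

    import torch
    from vit_zoo.factory import build_model

    model = build_model("vanilla_vit", head=10, config_kwargs={"attn_implementation": "eager"})
    images = torch.randn(2, 3, 224, 224)  # stand-in for preprocessed images
    attentions = model(images, output_attentions=True)["attentions"]

    last = attentions[-1]                  # final layer: (batch, heads, tokens, tokens)
    cls_to_patches = last[:, :, 0, 1:]     # attention from the CLS token to each patch token
    saliency = cls_to_patches.mean(dim=1)  # head-averaged saliency: (batch, num_patches)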
Example 5: Custom Head Class
You can subclass BaseHead to create any custom head architecture:
    from vit_zoo.factory import build_model
    from vit_zoo.components import BaseHead
    import torch.nn as nn
    import torch

    class MyCustomHead(BaseHead):
        """Custom head - can be MLP, UNET decoder, attention-based, etc."""

        def __init__(self, input_dim: int, num_classes: int):
            super().__init__()
            self._input_dim = input_dim  # Store for input_dim property
            self.fc1 = nn.Linear(input_dim, 256)
            self.fc2 = nn.Linear(256, num_classes)

        @property
        def input_dim(self) -> int:
            """Returns the input dimension of the head."""
            return self._input_dim

        def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
            x = torch.relu(self.fc1(embeddings))
            return self.fc2(x)

    # Use custom head - input_dim will be validated automatically
    head = MyCustomHead(input_dim=768, num_classes=10)  # vanilla_vit has 768-dim embeddings
    model = build_model("vanilla_vit", head=head)  # Validates input_dim matches
Important: All custom heads must implement the input_dim property. The factory will automatically validate that the head's input_dim matches the backbone's embedding dimension, helping catch dimension mismatches early.
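To see the check fire, build with a head whose input_dim disagrees with the backbone, reusing MyCustomHead from above. The specific exception class the factory raises is not documented here, so the sketch catches a generic Exception:

    bad_head = MyCustomHead(input_dim=512, num_classes=10)  # wrong: vanilla_vit embeds at 768
    try:
        build_model("vanilla_vit", head=bad_head)
    except Exception as err:  # the library's exact exception type is an assumption
        print(f"Rejected at build time: {err}")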
Example 6: Override Model Name
    from vit_zoo.factory import build_model

    # Use a different model variant from the registry default
    model = build_model(
        "vanilla_vit",
        model_name="google/vit-large-patch16-224",  # Override default
        head=10
    )
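Note that a larger checkpoint changes the embedding dimension: ViT-Large embeds at 1024 dimensions rather than 768, so a custom head must be sized to match (an int head, as above, takes its input dimension from the backbone):

    from vit_zoo.factory import build_model
    from vit_zoo.components import MLPHead

    large_head = MLPHead(input_dim=1024, hidden_dims=[512], output_dim=10)  # 1024-dim ViT-Large embeddings
    model = build_model(
        "vanilla_vit",
        model_name="google/vit-large-patch16-224",
        head=large_head,
    )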
Example 7: Direct Usage (Any HuggingFace Model)
    from vit_zoo.factory import build_model
    from transformers import ViTModel

    # Use any HuggingFace model directly without registry
    model = build_model(
        model_name="google/vit-base-patch16-224",
        backbone_cls=ViTModel,
        head=10
    )
API Reference
build_model()
Main factory function to create ViT models.
    build_model(
        model_type: Optional[str] = None,
        model_name: Optional[str] = None,
        backbone_cls: Optional[Type[ViTBackboneProtocol]] = None,
        head: Optional[Union[int, BaseHead]] = None,
        freeze_backbone: bool = False,
        load_pretrained: bool = True,
        backbone_dropout: float = 0.0,
        config_kwargs: Optional[Dict[str, Any]] = None,
    ) -> ViTModel
Parameters:
- model_type: Optional registered model type ("vanilla_vit", "deit_vit", "dino_vit", "dino_v2_vit", "dinov2_reg_vit", "dinov3_vit", "clip_vit"). If provided, uses the default backbone class and model name from the registry. When model_type is provided, backbone_cls is ignored and cannot be overridden.
- model_name: Optional HuggingFace model identifier. If model_type is provided, this overrides the default model name from the registry. If model_type is not provided, this is required along with backbone_cls.
- backbone_cls: Optional HuggingFace model class. Required if model_type is not provided. Ignored if model_type is provided (the registry default is always used).
- head:
  - int: Creates a LinearHead with that output dimension
  - BaseHead: Uses the provided head instance. Automatically validates that head.input_dim matches the backbone's embedding dimension. Users can subclass BaseHead to create custom heads (e.g., MLP, UNET decoder, attention-based). All custom heads must implement the input_dim property.
  - None: No head (embedding extraction mode)
- freeze_backbone: Freeze all backbone parameters
- load_pretrained: Whether to load pretrained weights (see the combined sketch below)
- backbone_dropout: Dropout probability for the backbone
- config_kwargs: Extra config options passed to the model config or from_pretrained(). Can include 'attn_implementation' to control the attention mechanism (e.g., 'eager' for attention weights, 'flash_attention_2', 'sdpa').
Returns: ViTModel instance
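A sketch exercising the less-commonly shown parameters together; the values are illustrative, not recommendations:

    from vit_zoo.factory import build_model

    # Train from random initialization with extra backbone regularization
    model = build_model(
        "vanilla_vit",
        head=10,
        load_pretrained=False,   # skip pretrained weights
        backbone_dropout=0.1,    # dropout probability inside the backbone
        config_kwargs={"attn_implementation": "sdpa"},
    )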
Usage patterns:
- Registry shortcut (recommended): build_model("vanilla_vit", head=10)
- Override default: build_model("vanilla_vit", model_name="google/vit-large-patch16-224", head=10)
- Direct usage: build_model(model_name="custom/model", backbone_cls=CustomModel, head=10)
ViTModel.forward()
    forward(
        pixel_values: torch.Tensor,
        output_attentions: bool = False,
        output_embeddings: bool = False,
    ) -> Union[torch.Tensor, Dict[str, Any]]

Returns:
- If output_attentions=False and output_embeddings=False: predictions tensor
- If output_attentions=True or output_embeddings=True: dict with keys:
  - "predictions": Model predictions
  - "attentions": Optional tuple of attention tensors
  - "embeddings": Optional embeddings tensor
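A quick check of both return forms (the random tensor is a stand-in for a preprocessed image batch; eager attention is set as in Example 4):

    import torch
    from vit_zoo.factory import build_model

    model = build_model("vanilla_vit", head=10, config_kwargs={"attn_implementation": "eager"})
    images = torch.randn(2, 3, 224, 224)

    preds = model(images)  # both flags False: plain predictions tensor
    print(preds.shape)     # torch.Size([2, 10])

    out = model(images, output_attentions=True, output_embeddings=True)
    print(sorted(out.keys()))  # ['attentions', 'embeddings', 'predictions']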
Supported Models
The following models are available in the registry with sensible defaults:
- vanilla_vit: Google ViT (default: google/vit-base-patch16-224)
- deit_vit: Facebook DeiT (default: facebook/deit-base-distilled-patch16-224)
- dino_vit: Facebook DINO (default: facebook/dino-vitb16)
- dino_v2_vit: Facebook DINOv2 without registers (default: facebook/dinov2-base)
- dinov2_reg_vit: Facebook DINOv2 with registers (default: facebook/dinov2-with-registers-base)
- dinov3_vit: Facebook DINOv3 (default: facebook/dinov3-vitb16-pretrain-lvd1689m)
- clip_vit: OpenAI CLIP Vision (default: openai/clip-vit-base-patch16)
You can override the default model name or use any HuggingFace model directly (see examples above).
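As a smoke test, every registered type can be built in one loop. This sketch assumes load_pretrained=False is enough to skip downloading checkpoint weights (configs may still be fetched from the Hub):

    from vit_zoo.factory import build_model

    for model_type in ["vanilla_vit", "deit_vit", "dino_vit", "dino_v2_vit",
                       "dinov2_reg_vit", "dinov3_vit", "clip_vit"]:
        model = build_model(model_type, head=2, load_pretrained=False)
        print(f"{model_type}: built OK")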
Advanced Usage
Using Any HuggingFace Model
You can use any HuggingFace Vision Transformer model directly without registering it:
    from vit_zoo.factory import build_model
    from transformers import ViTModel, DeiTModel

    # Use any ViT variant
    model = build_model(
        model_name="google/vit-large-patch16-224",
        backbone_cls=ViTModel,
        head=10
    )

    # Use any DeiT variant
    model = build_model(
        model_name="facebook/deit-small-distilled-patch16-224",
        backbone_cls=DeiTModel,
        head=10
    )
Adding Models to Registry
To add a model to the registry for convenience, you can modify the MODEL_REGISTRY in src/vit_zoo/factory.py:
    from transformers import YourCustomModel

    MODEL_REGISTRY.update({
        "your_model": (YourCustomModel, "your-org/your-model-name"),
    })
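Once registered, the new type behaves like the built-ins (the names here are the placeholders from the snippet above):

    from vit_zoo.factory import build_model

    model = build_model("your_model", head=10)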
Available Heads
The library provides several built-in head implementations:
- LinearHead: Simple linear transformation (created automatically when you pass an int)
- MLPHead: Multi-layer perceptron with configurable depth, activation, and dropout
- IdentityHead: Returns embeddings unchanged (for embedding extraction)
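The first two map directly onto build_model's head argument, and head=None covers the embedding-extraction case; a short sketch using only the entry points documented above:

    from vit_zoo.factory import build_model
    from vit_zoo.components import MLPHead

    model_linear = build_model("vanilla_vit", head=10)  # int -> LinearHead
    model_mlp = build_model(
        "vanilla_vit",
        head=MLPHead(input_dim=768, hidden_dims=[256], output_dim=10),
    )
    model_embed = build_model("vanilla_vit", head=None)  # embedding extraction mode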
Creating Custom Heads
You can create any custom head architecture by subclassing BaseHead. This is useful for:
- Complex MLP architectures
- UNET decoders
- Attention-based heads
- Multi-task heads
- Any other custom architecture
Example:
    from vit_zoo.factory import build_model
    from vit_zoo.components import BaseHead
    import torch.nn as nn
    import torch

    class UNETDecoderHead(BaseHead):
        """Example: UNET-style decoder head."""

        def __init__(self, input_dim: int, num_classes: int):
            super().__init__()
            self._input_dim = input_dim  # Store for input_dim property
            # Your custom architecture here
            self.decoder = nn.Sequential(
                nn.Linear(input_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, num_classes)
            )

        @property
        def input_dim(self) -> int:
            """Returns the input dimension of the head."""
            return self._input_dim

        def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
            return self.decoder(embeddings)

    # Use it - input_dim will be automatically validated
    head = UNETDecoderHead(input_dim=768, num_classes=10)
    model = build_model("vanilla_vit", head=head)  # Validates input_dim matches
Important: All custom heads must implement the input_dim property. The factory automatically validates that the head's input_dim matches the backbone's embedding dimension, helping catch dimension mismatches early.
License
GPL-3.0