# Vision Transformers Zoo

A clean, extensible factory for creating HuggingFace-based Vision Transformer models (ViT, DeiT, DINO, DINOv2, DINOv3, CLIP) with flexible heads and easy backbone freezing.
## Installation

```bash
pip install vit_zoo
```

From source:

```bash
git clone https://github.com/jbindaAI/vit_zoo.git
cd vit_zoo
pip install -e .
```

For development:

```bash
pip install -e ".[dev]"
```
## Quick start

```python
from vit_zoo.factory import build_model

model = build_model("facebook/dinov2-base", head=10, freeze_backbone=True)
outputs = model(images)
logits = outputs["predictions"]  # (batch_size, 10)
```
## Basic usage

```python
from vit_zoo.factory import build_model

# Simple classification - pass any HuggingFace model ID
model = build_model("google/vit-base-patch16-224", head=10, freeze_backbone=True)
outputs = model(images)
predictions = outputs["predictions"]  # Shape: (batch_size, 10)
```
## Custom MLP Head

```python
from vit_zoo.factory import build_model
from vit_zoo.components import MLPHead

mlp_head = MLPHead(
    input_dim=768,
    hidden_dims=[512, 256],
    output_dim=100,
    dropout=0.1,
    activation="gelu",  # or "relu", "tanh", or an nn.Module
)

model = build_model("facebook/dinov2-base", head=mlp_head)
```
## Embedding Extraction

```python
from transformers import CLIPVisionModel
from vit_zoo.factory import build_model

model = build_model("openai/clip-vit-base-patch16", backbone_cls=CLIPVisionModel, head=None)
outputs = model(images, output_embeddings=True)

hidden_states = outputs["last_hidden_state"]  # (batch_size, seq_len, embedding_dim)
cls_embedding = hidden_states[:, 0, :]        # (batch_size, embedding_dim)
predictions = outputs["predictions"]          # same as cls_embedding when head=None (IdentityHead)
```
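A common alternative to the CLS token is mean-pooling the patch tokens. The indexing convention above (token 0 is CLS, the rest are patch tokens) can be sketched with plain Python lists standing in for one sequence of embeddings:

```python
from statistics import fmean

# Toy stand-in for one sequence of token embeddings:
# row 0 is the CLS token, rows 1..n are patch tokens (embedding_dim=3 here).
hidden_states = [
    [1.0, 2.0, 3.0],   # CLS token
    [2.0, 4.0, 6.0],   # patch token 1
    [4.0, 8.0, 12.0],  # patch token 2
]

cls_embedding = hidden_states[0]
# Mean-pool the patch tokens (skip index 0), dimension by dimension.
patch_mean = [fmean(dim) for dim in zip(*hidden_states[1:])]
print(patch_mean)  # [3.0, 6.0, 9.0]
```

With real tensors the same pooling is `hidden_states[:, 1:, :].mean(dim=1)`.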
## Attention Weights

```python
from vit_zoo.factory import build_model

model = build_model(
    "google/vit-base-patch16-224",
    head=10,
    config_kwargs={"attn_implementation": "eager"},
)
outputs = model(images, output_attentions=True)
attentions = outputs["attentions"]
```
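For a typical HuggingFace ViT backbone, `attentions` is one tensor per layer with shape `(batch_size, num_heads, seq_len, seq_len)`, where row 0 of the last two axes is the CLS query; that layout is the standard `transformers` convention, assumed here rather than taken from the vit_zoo source. Extracting the CLS row from one layer and averaging it over heads can be sketched with nested lists:

```python
from statistics import fmean

# Toy attention weights for one layer of one image:
# (num_heads=2, seq_len=3, seq_len=3); row 0 of each head is the CLS query.
layer_attn = [
    [[0.5, 0.25, 0.25],   # head 0: how CLS attends to [CLS, patch1, patch2]
     [0.25, 0.5, 0.25],
     [0.25, 0.25, 0.5]],
    [[0.25, 0.5, 0.25],   # head 1
     [0.25, 0.5, 0.25],
     [0.25, 0.25, 0.5]],
]

# Take the CLS row from each head, then average element-wise across heads.
cls_rows = [head[0] for head in layer_attn]
cls_attention = [fmean(col) for col in zip(*cls_rows)]
print(cls_attention)  # [0.375, 0.375, 0.25]
```

With real tensors the equivalent is `attentions[-1][:, :, 0, :].mean(dim=1)`, the usual starting point for CLS attention maps.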
## Custom Head

```python
import torch.nn as nn

from vit_zoo.components import BaseHead
from vit_zoo.factory import build_model

class CustomHead(BaseHead):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self._input_dim = input_dim
        self.fc = nn.Linear(input_dim, num_classes)

    @property
    def input_dim(self) -> int:
        return self._input_dim

    def forward(self, embeddings):
        return self.fc(embeddings)

head = CustomHead(input_dim=768, num_classes=10)
model = build_model("google/vit-base-patch16-224", head=head)
```
## Multi-modal Models (CLIP)

For CLIP and other multi-modal models, pass `backbone_cls` to load only the vision encoder (`AutoModel` would load the full model):

```python
from transformers import CLIPVisionModel
from vit_zoo.factory import build_model

model = build_model("openai/clip-vit-base-patch16", backbone_cls=CLIPVisionModel, head=10)
```
## Any HuggingFace Model

`build_model` uses `AutoModel` to auto-detect the model type from the HuggingFace Hub. Any ViT-compatible model works:

```python
model = build_model("google/vit-large-patch16-224", head=10)
model = build_model("facebook/deit-base-distilled-patch16-224", head=10)
model = build_model("facebook/dinov2-with-registers-base", head=10)
```
## API Reference

### build_model()

```python
build_model(
    model_name: str,
    head: Optional[Union[int, BaseHead]] = None,
    backbone_cls: Optional[Type] = None,
    freeze_backbone: bool = False,
    load_pretrained: bool = True,
    backbone_dropout: float = 0.0,
    config_kwargs: Optional[Dict[str, Any]] = None,
) -> ViTModel
```

Parameters:

- `model_name`: HuggingFace model identifier (e.g., `"google/vit-base-patch16-224"`, `"facebook/dinov2-base"`, `"openai/clip-vit-base-patch16"`). Uses `AutoModel` to auto-detect the model type.
- `head`: `int` (creates a `LinearHead`), a `BaseHead` instance, or `None` (embedding extraction).
- `backbone_cls`: Optional HuggingFace model class (e.g., `CLIPVisionModel`). Use for multi-modal models where `AutoModel` would load the full model.
- `freeze_backbone`: Freeze all backbone parameters.
- `config_kwargs`: Extra config options (e.g., `{"attn_implementation": "eager"}`).
Usage:

```python
build_model("google/vit-base-patch16-224", head=10)
build_model("facebook/dinov2-base", head=None)  # embedding extraction
```
### ViTModel.forward()

```python
forward(
    pixel_values: torch.Tensor,
    output_attentions: bool = False,
    output_embeddings: bool = False,
) -> Dict[str, Any]
```

Always returns a dict. Keys:

- `"predictions"`: head output tensor (always present)
- `"attentions"`: optional, present when `output_attentions=True`
- `"last_hidden_state"`: optional, present when `output_embeddings=True`; shape `(batch_size, seq_len, embedding_dim)`
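Since the optional keys only appear when the corresponding flag is set, downstream code should use `dict.get` rather than indexing for them. A plain-dict sketch of the contract (placeholder strings stand in for tensors):

```python
# Hypothetical outputs mirroring the documented keys, as returned after
# model(images, output_attentions=True) with output_embeddings left False.
outputs = {
    "predictions": "<tensor (batch_size, num_classes)>",
    "attentions": "<tuple of per-layer attention tensors>",
}

logits = outputs["predictions"]            # always present, safe to index
attn = outputs.get("attentions")           # present here
hidden = outputs.get("last_hidden_state")  # None: embeddings were not requested
print(attn is not None, hidden is None)    # True True
```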
### ViTModel.freeze_backbone()

```python
model.freeze_backbone(freeze: bool = True)  # Freeze/unfreeze backbone
```
## Supported Models

Any ViT-compatible model on the HuggingFace Hub works. Examples:

- `google/vit-base-patch16-224`, `google/vit-large-patch16-224` (ViT)
- `facebook/deit-base-distilled-patch16-224` (DeiT)
- `facebook/dino-vitb16` (DINO)
- `facebook/dinov2-base`, `facebook/dinov2-with-registers-base` (DINOv2)
- `facebook/dinov3-vitb16-pretrain-lvd1689m` (DINOv3)
- `openai/clip-vit-base-patch16` (CLIP Vision; pass `backbone_cls=CLIPVisionModel`)

Browse the HuggingFace Hub for more models.
## Import Patterns

```python
from vit_zoo import ViTModel
from vit_zoo.factory import build_model
from vit_zoo.components import ViTBackbone, BaseHead, LinearHead, MLPHead, IdentityHead
```
## Available Heads

- `LinearHead`: Simple linear layer (auto-created when `head` is an `int`)
- `MLPHead`: Multi-layer perceptron with configurable depth, activation, and dropout
- `IdentityHead`: Returns embeddings unchanged

All heads must implement the `input_dim` property. Create custom heads by subclassing `BaseHead`.
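The three accepted forms of the `head` argument suggest a dispatch along these lines; the stub classes below are illustrative stand-ins, not the actual vit_zoo implementation:

```python
class StubIdentityHead:
    """Stand-in for IdentityHead: passes embeddings through unchanged."""
    def __call__(self, embeddings):
        return embeddings

class StubLinearHead:
    """Stand-in for LinearHead: records the shape a real linear layer would get."""
    def __init__(self, input_dim, output_dim):
        self.input_dim = input_dim
        self.output_dim = output_dim

def resolve_head(head, embedding_dim):
    """int -> linear head, None -> identity head, otherwise use the instance as-is."""
    if head is None:
        return StubIdentityHead()
    if isinstance(head, int):
        return StubLinearHead(embedding_dim, head)
    return head

h = resolve_head(10, embedding_dim=768)
print(type(h).__name__, h.input_dim, h.output_dim)  # StubLinearHead 768 10
```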
## License

GPL-3.0
## File details

Details for the file `vit_zoo-0.2.0.tar.gz`.

File metadata:

- Download URL: vit_zoo-0.2.0.tar.gz
- Size: 12.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `8647cbea8f48965505064ecfb35da4f19f752cbc8150efa8b49fc452eee85b43` |
| MD5 | `ba525c5a939a60fe92a7b2bcc884f3e2` |
| BLAKE2b-256 | `7c10862e54250d2bde63815d688be64bdb98f4536ba9230986743e34ea8e5958` |
## Provenance

The following attestation bundle was made for `vit_zoo-0.2.0.tar.gz`:

Publisher: `release.yml` on jbindaAI/vit_zoo

Statement:

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: vit_zoo-0.2.0.tar.gz
- Subject digest: 8647cbea8f48965505064ecfb35da4f19f752cbc8150efa8b49fc452eee85b43
- Sigstore transparency entry: 938437526
- Permalink: jbindaAI/vit_zoo@4c5d4c75b099661fe3a771f23bd10cecc880879f
- Branch / Tag: refs/tags/0.2.0
- Owner: https://github.com/jbindaAI
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@4c5d4c75b099661fe3a771f23bd10cecc880879f
- Trigger Event: push
## File details

Details for the file `vit_zoo-0.2.0-py3-none-any.whl`.

File metadata:

- Download URL: vit_zoo-0.2.0-py3-none-any.whl
- Size: 12.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `7f5583eeea554b847ff80fea53212ccf2aedce522e8fae774cf27eb394248f40` |
| MD5 | `25a8b826d5f1ce3642467b3c030de370` |
| BLAKE2b-256 | `f5410c68e8c8be2dad73dc5be209f433b876542df32dc4fcef53516b0d7b1673` |
## Provenance

The following attestation bundle was made for `vit_zoo-0.2.0-py3-none-any.whl`:

Publisher: `release.yml` on jbindaAI/vit_zoo

Statement:

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: vit_zoo-0.2.0-py3-none-any.whl
- Subject digest: 7f5583eeea554b847ff80fea53212ccf2aedce522e8fae774cf27eb394248f40
- Sigstore transparency entry: 938437538
- Permalink: jbindaAI/vit_zoo@4c5d4c75b099661fe3a771f23bd10cecc880879f
- Branch / Tag: refs/tags/0.2.0
- Owner: https://github.com/jbindaAI
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@4c5d4c75b099661fe3a771f23bd10cecc880879f
- Trigger Event: push