
Vision Transformers Zoo


A clean, extensible factory for creating HuggingFace-based Vision Transformer models (ViT, DeiT, DINO, DINOv2, DINOv3, CLIP) with flexible heads and easy backbone freezing.

Installation

pip install vit_zoo

From source:

git clone https://github.com/jbindaAI/vit_zoo.git
cd vit_zoo
pip install -e .

For development: pip install -e ".[dev]"

Quick start

from vit_zoo.factory import build_model

model = build_model("facebook/dinov2-base", head=10, freeze_backbone=True)
outputs = model(images)
logits = outputs["predictions"]  # (batch_size, 10)

Basic usage

from vit_zoo.factory import build_model

# Simple classification - pass any HuggingFace model ID
model = build_model("google/vit-base-patch16-224", head=10, freeze_backbone=True)
outputs = model(images)
predictions = outputs["predictions"]  # Shape: (batch_size, 10)

Custom MLP Head

from vit_zoo.factory import build_model
from vit_zoo.components import MLPHead

mlp_head = MLPHead(
    input_dim=768,
    hidden_dims=[512, 256],
    output_dim=100,
    dropout=0.1,
    activation="gelu"  # or 'relu', 'tanh', or nn.Module
)

model = build_model("facebook/dinov2-base", head=mlp_head)

Embedding Extraction

from transformers import CLIPVisionModel
model = build_model("openai/clip-vit-base-patch16", backbone_cls=CLIPVisionModel, head=None)
outputs = model(images, output_embeddings=True)
hidden_states = outputs["last_hidden_state"]  # (batch_size, seq_len, embedding_dim)
cls_embedding = hidden_states[:, 0, :]  # (batch_size, embedding_dim)
predictions = outputs["predictions"]  # same as cls_embedding when head=None (IdentityHead)

Attention Weights

model = build_model(
    "google/vit-base-patch16-224",
    head=10,
    config_kwargs={"attn_implementation": "eager"}
)
outputs = model(images, output_attentions=True)
attentions = outputs["attentions"]

Custom Head

from vit_zoo.components import BaseHead
import torch.nn as nn

class CustomHead(BaseHead):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self._input_dim = input_dim
        self.fc = nn.Linear(input_dim, num_classes)
    
    @property
    def input_dim(self) -> int:
        return self._input_dim
    
    def forward(self, embeddings):
        return self.fc(embeddings)

head = CustomHead(input_dim=768, num_classes=10)
model = build_model("google/vit-base-patch16-224", head=head)

Multi-modal Models (CLIP)

For CLIP and other multi-modal models, pass backbone_cls to load only the vision encoder (AutoModel would load the full model):

from transformers import CLIPVisionModel
model = build_model("openai/clip-vit-base-patch16", backbone_cls=CLIPVisionModel, head=10)

Any HuggingFace Model

build_model uses AutoModel to auto-detect the model type from the HuggingFace Hub. Any ViT-compatible model works:

model = build_model("google/vit-large-patch16-224", head=10)
model = build_model("facebook/deit-base-distilled-patch16-224", head=10)
model = build_model("facebook/dinov2-with-registers-base", head=10)

API Reference

build_model()

build_model(
    model_name: str,
    head: Optional[Union[int, BaseHead]] = None,
    backbone_cls: Optional[Type] = None,
    freeze_backbone: bool = False,
    load_pretrained: bool = True,
    backbone_dropout: float = 0.0,
    config_kwargs: Optional[Dict[str, Any]] = None,
) -> ViTModel

Parameters:

  • model_name: HuggingFace model identifier (e.g., "google/vit-base-patch16-224", "facebook/dinov2-base", "openai/clip-vit-base-patch16"). Uses AutoModel to auto-detect the model type.
  • head: int (creates a LinearHead), a BaseHead instance, or None (embedding extraction)
  • backbone_cls: Optional HuggingFace model class (e.g., CLIPVisionModel). Use for multi-modal models where AutoModel would load the full model.
  • freeze_backbone: Freeze all backbone parameters
  • load_pretrained: Load pretrained weights from the Hub; set to False for random initialization
  • backbone_dropout: Dropout probability applied in the backbone
  • config_kwargs: Extra config options (e.g., {"attn_implementation": "eager"})

Usage:

  • build_model("google/vit-base-patch16-224", head=10)
  • build_model("facebook/dinov2-base", head=None) (embedding extraction)

ViTModel.forward()

forward(
    pixel_values: torch.Tensor,
    output_attentions: bool = False,
    output_embeddings: bool = False,
) -> Dict[str, Any]

Always returns a dict. Keys:

  • "predictions": head output tensor (always present)
  • "attentions": optional, when output_attentions=True
  • "last_hidden_state": optional, when output_embeddings=True; shape (batch_size, seq_len, embedding_dim)

ViTModel.freeze_backbone()

model.freeze_backbone(freeze: bool = True)  # Freeze/unfreeze backbone
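
A typical use is linear probing: freeze the backbone and optimize only the parameters that still require gradients (a minimal sketch):

import torch

model.freeze_backbone(True)
trainable = [p for p in model.parameters() if p.requires_grad]  # head parameters only
optimizer = torch.optim.AdamW(trainable, lr=1e-3)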

Supported Models

Any ViT-compatible model on the HuggingFace Hub works. Examples:

  • google/vit-base-patch16-224, google/vit-large-patch16-224 (ViT)
  • facebook/deit-base-distilled-patch16-224 (DeiT)
  • facebook/dino-vitb16 (DINO)
  • facebook/dinov2-base, facebook/dinov2-with-registers-base (DINOv2)
  • facebook/dinov3-vitb16-pretrain-lvd1689m (DINOv3)
  • openai/clip-vit-base-patch16 (CLIP Vision; pass backbone_cls=CLIPVisionModel)

Browse the HuggingFace Hub for more models.

Import Patterns

from vit_zoo import ViTModel
from vit_zoo.factory import build_model
from vit_zoo.components import ViTBackbone, BaseHead, LinearHead, MLPHead, IdentityHead

Available Heads

  • LinearHead: Simple linear layer (auto-created when head=int)
  • MLPHead: Multi-layer perceptron with configurable depth, activation, dropout
  • IdentityHead: Returns embeddings unchanged

All heads must implement the input_dim property. Create custom heads by subclassing BaseHead (see Custom Head above).

License

GPL-3.0
