Vision Transformer Zoo

A clean, extensible factory for creating HuggingFace-based Vision Transformer models (ViT, DeiT, DINO, DINOv2, DINOv3, CLIP) with flexible heads and easy backbone freezing.


Installation (PyPI)

pip3 install vit_zoo

Installation (From source)

git clone git@github.com:jbindaAI/vit_zoo.git
cd vit_zoo
pip install -e .

For development: pip install -e ".[dev]"


Quick Start

Basic Usage

from vit_zoo.factory import build_model

# Simple classification
model = build_model("vanilla_vit", head=10, freeze_backbone=True)
predictions = model(images)  # Shape: (batch_size, 10)

Custom MLP Head

from vit_zoo.factory import build_model
from vit_zoo.components import MLPHead

mlp_head = MLPHead(
    input_dim=768,
    hidden_dims=[512, 256],
    output_dim=100,
    dropout=0.1,
    activation="gelu"  # "relu", "tanh", or an nn.Module instance also accepted
)

model = build_model("dinov2_vit", head=mlp_head)

Embedding Extraction

model = build_model("clip_vit", head=None)
outputs = model(images, output_embeddings=True)
embeddings = outputs["embeddings"]  # Shape: (batch_size, embedding_dim)
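The extracted `(batch_size, embedding_dim)` tensor can feed any downstream task. As an illustrative sketch (random stand-in embeddings, not model output), pairwise cosine similarity between images:

```python
import torch
import torch.nn.functional as F

# Random stand-in for extracted embeddings; shape matches the
# (batch_size, embedding_dim) output above, values are NOT model output.
emb = torch.randn(4, 768)
emb = F.normalize(emb, dim=-1)  # unit-normalize each embedding
sim = emb @ emb.T               # (4, 4) cosine-similarity matrix
```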

Attention Weights

model = build_model(
    "vanilla_vit",
    head=10,
    config_kwargs={"attn_implementation": "eager"}
)
outputs = model(images, output_attentions=True)
attentions = outputs["attentions"]
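By the usual HuggingFace convention, `attentions` is a tuple with one tensor per layer, each shaped `(batch, num_heads, num_tokens, num_tokens)`. A hedged sketch of averaging over heads for visualization, using a dummy tensor with assumed ViT-Base-like shapes:

```python
import torch

# Dummy stand-in for one layer's attention tensor; shapes assume the usual
# HuggingFace convention (batch, num_heads, num_tokens, num_tokens), here
# 12 heads and 197 tokens (196 patches + [CLS]) as in ViT-Base/16 at 224px.
attn_last = torch.rand(2, 12, 197, 197)
attn_map = attn_last.mean(dim=1)  # average over heads -> (batch, tokens, tokens)
```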

Custom Head

from vit_zoo.components import BaseHead
import torch.nn as nn

class CustomHead(BaseHead):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self._input_dim = input_dim
        self.fc = nn.Linear(input_dim, num_classes)
    
    @property
    def input_dim(self) -> int:
        return self._input_dim
    
    def forward(self, embeddings):
        return self.fc(embeddings)

head = CustomHead(input_dim=768, num_classes=10)
model = build_model("vanilla_vit", head=head)

Direct Usage (Any HuggingFace Model)

from vit_zoo.factory import build_model
from transformers import ViTModel

model = build_model(
    model_name="google/vit-large-patch16-224",
    backbone_cls=ViTModel,
    head=10
)

API Reference

build_model()

build_model(
    model_type: Optional[str] = None,
    model_name: Optional[str] = None,
    backbone_cls: Optional[Type[ViTBackboneProtocol]] = None,
    head: Optional[Union[int, BaseHead]] = None,
    freeze_backbone: bool = False,
    load_pretrained: bool = True,
    backbone_dropout: float = 0.0,
    config_kwargs: Optional[Dict[str, Any]] = None,
) -> ViTModel

Parameters:

  • model_type: Registry key ("vanilla_vit", "deit_vit", "dinov2_vit", etc.)
  • head: int (creates LinearHead), BaseHead instance, or None (embedding extraction)
  • freeze_backbone: Freeze all backbone parameters
  • config_kwargs: Extra config options (e.g., {"attn_implementation": "eager"})

Usage:

  • Registry: build_model("vanilla_vit", head=10)
  • Override: build_model("vanilla_vit", model_name="google/vit-large-patch16-224", head=10)
  • Direct: build_model(model_name="...", backbone_cls=ViTModel, head=10)

ViTModel.forward()

forward(
    pixel_values: torch.Tensor,
    output_attentions: bool = False,
    output_embeddings: bool = False,
) -> Union[torch.Tensor, Dict[str, Any]]

Returns the predictions tensor, or a dict with "predictions", "attentions", and "embeddings" keys when extra outputs are requested.
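Because the return type is a Union, calling code may want to normalize it. A small hypothetical helper (not part of the vit_zoo API) could look like:

```python
from typing import Any, Dict, Union

def get_predictions(output: Union[Any, Dict[str, Any]]) -> Any:
    """Return predictions whether forward() gave a bare tensor or a dict.

    Hypothetical convenience helper, not part of vit_zoo itself.
    """
    if isinstance(output, dict):
        return output["predictions"]
    return output
```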

ViTModel.freeze_backbone()

freeze_backbone(freeze: bool = True)

model.freeze_backbone()       # freeze backbone parameters
model.freeze_backbone(False)  # unfreeze
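Freezing presumably amounts to disabling gradients on the backbone's parameters. A minimal stand-alone sketch of that mechanic with a stand-in module (plain torch, no vit_zoo):

```python
import torch.nn as nn

# Stand-in for a ViT backbone; the real freeze_backbone presumably toggles
# requires_grad on the backbone's parameters in this way (assumption).
backbone = nn.Linear(768, 768)

def set_frozen(module: nn.Module, freeze: bool = True) -> None:
    for p in module.parameters():
        p.requires_grad = not freeze

set_frozen(backbone, True)   # frozen: no gradients accumulate
frozen = all(not p.requires_grad for p in backbone.parameters())
set_frozen(backbone, False)  # unfrozen again
```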

list_models()

from vit_zoo.factory import list_models
available = list_models()  # Returns list of registered model types

Supported Models

  • vanilla_vit: Google ViT (google/vit-base-patch16-224)
  • deit_vit: Facebook DeiT (facebook/deit-base-distilled-patch16-224)
  • dino_vit: Facebook DINO (facebook/dino-vitb16)
  • dinov2_vit: Facebook DINOv2 (facebook/dinov2-base)
  • dinov2_reg_vit: DINOv2 with registers (facebook/dinov2-with-registers-base)
  • dinov3_vit: Facebook DINOv3 (facebook/dinov3-vitb16-pretrain-lvd1689m)
  • clip_vit: OpenAI CLIP Vision (openai/clip-vit-base-patch16)

Import Patterns

from vit_zoo import ViTModel
from vit_zoo.factory import build_model, list_models
from vit_zoo.components import ViTBackbone, BaseHead, LinearHead, MLPHead, IdentityHead

Available Heads

  • LinearHead: Simple linear layer (auto-created when head=int)
  • MLPHead: Multi-layer perceptron with configurable depth, activation, dropout
  • IdentityHead: Returns embeddings unchanged

All heads must implement the input_dim property. Create custom heads by subclassing BaseHead.


License

GPL-3.0



Download files


Source Distribution

vit_zoo-0.1.1.tar.gz (12.2 kB)

Uploaded Source

Built Distribution


vit_zoo-0.1.1-py3-none-any.whl (12.8 kB)

Uploaded Python 3

File details

Details for the file vit_zoo-0.1.1.tar.gz.

File metadata

  • Download URL: vit_zoo-0.1.1.tar.gz
  • Upload date:
  • Size: 12.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for vit_zoo-0.1.1.tar.gz

  • SHA256: f4ce1773c2432d57c8e9fb32d48d588a85b6d60a52e30622f0a65629088711f7
  • MD5: 35fa81cdd5d3cd9f9db7a2d5d0af5081
  • BLAKE2b-256: d17e9849a077a5012f028e66a8ea901353f9199c42174e0ac562649c403ad926


Provenance

The following attestation bundles were made for vit_zoo-0.1.1.tar.gz:

Publisher: release.yml on jbindaAI/vit_zoo

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file vit_zoo-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: vit_zoo-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 12.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for vit_zoo-0.1.1-py3-none-any.whl

  • SHA256: 2cc0406b5bbd662fc683d68a61a9c2b8124bb71347ace7adea0f7906a344493f
  • MD5: 1d3888cfbc573f8be4f5e278b0ba6628
  • BLAKE2b-256: e090c254ada3b498a85678812a1f61b7f007289c7ccc3d3caa1933f3083ebb5a


Provenance

The following attestation bundles were made for vit_zoo-0.1.1-py3-none-any.whl:

Publisher: release.yml on jbindaAI/vit_zoo

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
