
Vision Transformers Zoo



A clean, extensible factory for creating HuggingFace-based Vision Transformer models (ViT, DeiT, DINO, DINOv2, DINOv3, CLIP) with flexible heads and easy backbone freezing.

Installation

pip install vit_zoo

From source:

git clone https://github.com/jbindaAI/vit_zoo.git
cd vit_zoo
pip install -e .

For development: pip install -e ".[dev]"

Quick start

from vit_zoo.factory import build_model

model = build_model("dinov2_vit", head=10, freeze_backbone=True)
outputs = model(images)
logits = outputs["predictions"]  # (batch_size, 10)

Basic usage

from vit_zoo.factory import build_model

# Simple classification
model = build_model("vanilla_vit", head=10, freeze_backbone=True)
outputs = model(images)
predictions = outputs["predictions"]  # Shape: (batch_size, 10)

Custom MLP Head

from vit_zoo.factory import build_model
from vit_zoo.components import MLPHead

mlp_head = MLPHead(
    input_dim=768,
    hidden_dims=[512, 256],
    output_dim=100,
    dropout=0.1,
    activation="gelu"  # or 'relu', 'tanh', or nn.Module
)

model = build_model("dinov2_vit", head=mlp_head)

Embedding Extraction

from vit_zoo.factory import build_model

model = build_model("clip_vit", head=None)
outputs = model(images, output_embeddings=True)
hidden_states = outputs["last_hidden_state"]  # (batch_size, seq_len, embedding_dim)
cls_embedding = hidden_states[:, 0, :]  # (batch_size, embedding_dim)
predictions = outputs["predictions"]  # same as cls_embedding when head=None (IdentityHead)
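The CLS-pooling step above is just a slice over the token axis. With a dummy tensor it can be checked in isolation (shapes here assume a ViT-Base-like model with 196 patch tokens plus one CLS token; they are illustrative, not guaranteed by vit_zoo):

```python
import torch

# Dummy hidden states: (batch, seq_len, embed_dim) for a ViT-Base-like model.
hidden_states = torch.randn(2, 197, 768)
cls_embedding = hidden_states[:, 0, :]  # token 0 is the CLS token
assert cls_embedding.shape == (2, 768)
```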

Attention Weights

from vit_zoo.factory import build_model

model = build_model(
    "vanilla_vit",
    head=10,
    config_kwargs={"attn_implementation": "eager"}
)
outputs = model(images, output_attentions=True)
attentions = outputs["attentions"]
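For a ViT-Base-like backbone, `attentions` is typically a tuple with one tensor per layer, each shaped `(batch, num_heads, seq_len, seq_len)`. A dummy illustration of that layout (shapes assumed, not guaranteed by vit_zoo):

```python
import torch

# One attention tensor per layer: (batch, num_heads, seq_len, seq_len).
attentions = tuple(torch.randn(2, 12, 197, 197) for _ in range(12))
assert len(attentions) == 12
assert attentions[0].shape == (2, 12, 197, 197)
```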

Custom Head

from vit_zoo.components import BaseHead
import torch.nn as nn

class CustomHead(BaseHead):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self._input_dim = input_dim
        self.fc = nn.Linear(input_dim, num_classes)
    
    @property
    def input_dim(self) -> int:
        return self._input_dim
    
    def forward(self, embeddings):
        return self.fc(embeddings)

head = CustomHead(input_dim=768, num_classes=10)
model = build_model("vanilla_vit", head=head)

Direct Usage (Any HuggingFace Model)

from vit_zoo.factory import build_model
from transformers import ViTModel

model = build_model(
    model_name="google/vit-large-patch16-224",
    backbone_cls=ViTModel,
    head=10
)

API Reference

build_model()

build_model(
    model_type: Optional[str] = None,
    model_name: Optional[str] = None,
    backbone_cls: Optional[Type[ViTBackboneProtocol]] = None,
    head: Optional[Union[int, BaseHead]] = None,
    freeze_backbone: bool = False,
    load_pretrained: bool = True,
    backbone_dropout: float = 0.0,
    config_kwargs: Optional[Dict[str, Any]] = None,
) -> ViTModel

Parameters:

  • model_type: Registry key ("vanilla_vit", "deit_vit", "dinov2_vit", etc.)
  • head: int (creates LinearHead), BaseHead instance, or None (embedding extraction)
  • freeze_backbone: Freeze all backbone parameters
  • config_kwargs: Extra config options (e.g., {"attn_implementation": "eager"})

Usage:

  • Registry: build_model("vanilla_vit", head=10)
  • Override: build_model("vanilla_vit", model_name="google/vit-large-patch16-224", head=10)
  • Direct: build_model(model_name="...", backbone_cls=ViTModel, head=10)
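For intuition, the registry-plus-override resolution above can be sketched in a few lines. This is a hypothetical illustration, not the library's actual source; the dict entries mirror the documented defaults:

```python
# Hypothetical sketch of a model registry and checkpoint resolution;
# not the actual vit_zoo implementation.
REGISTRY = {
    "vanilla_vit": "google/vit-base-patch16-224",
    "dinov2_vit": "facebook/dinov2-base",
    "clip_vit": "openai/clip-vit-base-patch16",
}

def list_models_sketch():
    return sorted(REGISTRY)

def resolve_checkpoint(model_type=None, model_name=None):
    # An explicit model_name overrides the registry default.
    if model_name is not None:
        return model_name
    if model_type not in REGISTRY:
        raise KeyError(f"Unknown model_type: {model_type!r}")
    return REGISTRY[model_type]

assert resolve_checkpoint("vanilla_vit") == "google/vit-base-patch16-224"
assert resolve_checkpoint("vanilla_vit", "google/vit-large-patch16-224") == "google/vit-large-patch16-224"
```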

ViTModel.forward()

forward(
    pixel_values: torch.Tensor,
    output_attentions: bool = False,
    output_embeddings: bool = False,
) -> Dict[str, Any]

Always returns a dict. Keys:

  • "predictions": head output tensor (always present)
  • "attentions": optional, when output_attentions=True
  • "last_hidden_state": optional, when output_embeddings=True; shape (batch_size, seq_len, embedding_dim)

ViTModel.freeze_backbone()

model.freeze_backbone(freeze: bool = True)  # Freeze/unfreeze backbone
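Freezing a backbone generally amounts to toggling requires_grad on its parameters, which excludes them from gradient computation. A minimal sketch of that mechanism (illustrative, not the library's source; the Linear stand-in plays the role of a real ViT backbone):

```python
import torch.nn as nn

# Stand-in backbone; a real ViT backbone's parameters work the same way.
backbone = nn.Linear(768, 768)

def set_frozen(module: nn.Module, freeze: bool = True) -> None:
    # Frozen parameters receive no gradients and are skipped by the optimizer.
    for p in module.parameters():
        p.requires_grad = not freeze

set_frozen(backbone, True)
assert not any(p.requires_grad for p in backbone.parameters())
set_frozen(backbone, False)
assert all(p.requires_grad for p in backbone.parameters())
```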

list_models()

from vit_zoo.factory import list_models
available = list_models()  # Returns list of registered model types

Default Models

  • vanilla_vit: Google ViT (google/vit-base-patch16-224)
  • deit_vit: Facebook DeiT (facebook/deit-base-distilled-patch16-224)
  • dino_vit: Facebook DINO (facebook/dino-vitb16)
  • dinov2_vit: Facebook DINOv2 (facebook/dinov2-base)
  • dinov2_reg_vit: DINOv2 with registers (facebook/dinov2-with-registers-base)
  • dinov3_vit: Facebook DINOv3 (facebook/dinov3-vitb16-pretrain-lvd1689m)
  • clip_vit: OpenAI CLIP Vision (openai/clip-vit-base-patch16)

Import Patterns

from vit_zoo import ViTModel
from vit_zoo.factory import build_model, list_models
from vit_zoo.components import ViTBackbone, BaseHead, LinearHead, MLPHead, IdentityHead

Available Heads

  • LinearHead: Simple linear layer (auto-created when head=int)
  • MLPHead: Multi-layer perceptron with configurable depth, activation, dropout
  • IdentityHead: Returns embeddings unchanged

All heads must implement the input_dim property. Custom heads are created by subclassing BaseHead.
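The behaviour of the three built-in heads can be sketched standalone with plain torch modules (these are illustrative stand-ins, not the library's classes, and the constructor arguments are assumptions):

```python
import torch
import torch.nn as nn

# Standalone sketches of the three head behaviours (not vit_zoo source).
linear = nn.Linear(768, 10)                       # what head=10 roughly builds
mlp = nn.Sequential(                              # MLPHead-style stack
    nn.Linear(768, 512), nn.GELU(), nn.Dropout(0.1),
    nn.Linear(512, 10),
)
identity = nn.Identity()                          # IdentityHead: pass-through

x = torch.randn(4, 768)
assert linear(x).shape == (4, 10)
assert mlp(x).shape == (4, 10)
assert torch.equal(identity(x), x)
```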

License

GPL-3.0

