
Vision Transformer Zoo

A clean, extensible factory for creating HuggingFace-based Vision Transformer models (ViT, DeiT, DINO, DINOv2, DINOv3, CLIP) with flexible heads and easy backbone freezing.


Installation (PyPI)

pip3 install vit_zoo

Installation (From source)

git clone git@github.com:jbindaAI/vit_zoo.git
cd vit_zoo
pip install -e .

For development: pip install -e ".[dev]"


Quick Start

Basic Usage

from vit_zoo.factory import build_model
import torch

# Simple classification with a frozen backbone and a 10-class linear head
model = build_model("vanilla_vit", head=10, freeze_backbone=True)

images = torch.randn(4, 3, 224, 224)  # preprocessed pixel values
predictions = model(images)  # Shape: (batch_size, 10)

Custom MLP Head

from vit_zoo.factory import build_model
from vit_zoo.components import MLPHead

mlp_head = MLPHead(
    input_dim=768,
    hidden_dims=[512, 256],
    output_dim=100,
    dropout=0.1,
    activation="gelu"  # or 'relu', 'tanh', or nn.Module
)

model = build_model("dinov2_vit", head=mlp_head)
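The `hidden_dims` list sets the widths of the intermediate layers, so the head above chains 768 → 512 → 256 → 100. A small illustrative sketch (not vit_zoo's actual code) of how those dimensions pair up into linear layers:

```python
def mlp_layer_dims(input_dim, hidden_dims, output_dim):
    """Chain input -> hidden layers -> output into (in, out) pairs,
    one pair per linear layer."""
    dims = [input_dim, *hidden_dims, output_dim]
    return list(zip(dims[:-1], dims[1:]))

print(mlp_layer_dims(768, [512, 256], 100))
# [(768, 512), (512, 256), (256, 100)]
```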

Embedding Extraction

model = build_model("clip_vit", head=None)
outputs = model(images, output_embeddings=True)
embeddings = outputs["embeddings"]  # Shape: (batch_size, embedding_dim)

Attention Weights

model = build_model(
    "vanilla_vit",
    head=10,
    config_kwargs={"attn_implementation": "eager"}
)
outputs = model(images, output_attentions=True)
attentions = outputs["attentions"]

Custom Head

from vit_zoo.components import BaseHead
import torch.nn as nn

class CustomHead(BaseHead):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self._input_dim = input_dim
        self.fc = nn.Linear(input_dim, num_classes)
    
    @property
    def input_dim(self) -> int:
        return self._input_dim
    
    def forward(self, embeddings):
        return self.fc(embeddings)

head = CustomHead(input_dim=768, num_classes=10)
model = build_model("vanilla_vit", head=head)

Direct Usage (Any HuggingFace Model)

from vit_zoo.factory import build_model
from transformers import ViTModel

model = build_model(
    model_name="google/vit-large-patch16-224",
    backbone_cls=ViTModel,
    head=10
)

API Reference

build_model()

build_model(
    model_type: Optional[str] = None,
    model_name: Optional[str] = None,
    backbone_cls: Optional[Type[ViTBackboneProtocol]] = None,
    head: Optional[Union[int, BaseHead]] = None,
    freeze_backbone: bool = False,
    load_pretrained: bool = True,
    backbone_dropout: float = 0.0,
    config_kwargs: Optional[Dict[str, Any]] = None,
) -> ViTModel

Parameters:

  • model_type: Registry key ("vanilla_vit", "deit_vit", "dinov2_vit", etc.)
  • head: int (creates LinearHead), BaseHead instance, or None (embedding extraction)
  • freeze_backbone: Freeze all backbone parameters
  • config_kwargs: Extra config options (e.g., {"attn_implementation": "eager"})

Usage:

  • Registry: build_model("vanilla_vit", head=10)
  • Override: build_model("vanilla_vit", model_name="google/vit-large-patch16-224", head=10)
  • Direct: build_model(model_name="...", backbone_cls=ViTModel, head=10)
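The precedence between these call styles can be sketched with a plain registry dict; the lookup below is illustrative only, not vit_zoo's actual implementation, and lists just two of the registered defaults:

```python
# Hypothetical registry mapping model_type -> default HuggingFace checkpoint.
REGISTRY = {
    "vanilla_vit": "google/vit-base-patch16-224",
    "dinov2_vit": "facebook/dinov2-base",
}

def resolve_model_name(model_type=None, model_name=None):
    # An explicit model_name always wins; otherwise fall back to the
    # registry default for the given model_type.
    if model_name is not None:
        return model_name
    if model_type is None:
        raise ValueError("Provide model_type or model_name")
    return REGISTRY[model_type]

print(resolve_model_name("vanilla_vit"))
# google/vit-base-patch16-224
print(resolve_model_name("vanilla_vit", "google/vit-large-patch16-224"))
# google/vit-large-patch16-224  (override wins)
```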

ViTModel.forward()

forward(
    pixel_values: torch.Tensor,
    output_attentions: bool = False,
    output_embeddings: bool = False,
) -> Union[torch.Tensor, Dict[str, Any]]

Returns predictions tensor, or dict with "predictions", "attentions", "embeddings" keys.
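A minimal sketch of that return-type branching, assuming the dict is only built when extra outputs are requested (illustrative, not vit_zoo's actual code):

```python
def forward_result(predictions, attentions=None, embeddings=None,
                   output_attentions=False, output_embeddings=False):
    """Return the raw predictions tensor when no extras are requested;
    otherwise bundle everything into a dict."""
    if not (output_attentions or output_embeddings):
        return predictions
    out = {"predictions": predictions}
    if output_attentions:
        out["attentions"] = attentions
    if output_embeddings:
        out["embeddings"] = embeddings
    return out

print(forward_result("preds"))
# preds
print(forward_result("preds", embeddings="emb", output_embeddings=True))
# {'predictions': 'preds', 'embeddings': 'emb'}
```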

ViTModel.freeze_backbone()

model.freeze_backbone(freeze: bool = True)  # Freeze/unfreeze backbone
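Conceptually, freezing just disables gradient tracking on every backbone parameter so the optimizer leaves them untouched. A stand-in sketch using a dummy parameter class in place of torch's:

```python
class FakeParam:
    """Stand-in for a torch parameter; only models the requires_grad flag."""
    def __init__(self):
        self.requires_grad = True

def freeze_backbone(params, freeze=True):
    # Flip gradient tracking for every backbone parameter.
    for p in params:
        p.requires_grad = not freeze

params = [FakeParam() for _ in range(3)]
freeze_backbone(params)
print(all(not p.requires_grad for p in params))  # True
freeze_backbone(params, freeze=False)
print(all(p.requires_grad for p in params))      # True
```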

list_models()

from vit_zoo.factory import list_models
available = list_models()  # Returns list of registered model types

Supported Models

  • vanilla_vit: Google ViT (google/vit-base-patch16-224)
  • deit_vit: Facebook DeiT (facebook/deit-base-distilled-patch16-224)
  • dino_vit: Facebook DINO (facebook/dino-vitb16)
  • dinov2_vit: Facebook DINOv2 (facebook/dinov2-base)
  • dinov2_reg_vit: DINOv2 with registers (facebook/dinov2-with-registers-base)
  • dinov3_vit: Facebook DINOv3 (facebook/dinov3-vitb16-pretrain-lvd1689m)
  • clip_vit: OpenAI CLIP Vision (openai/clip-vit-base-patch16)

Import Patterns

from vit_zoo import ViTModel
from vit_zoo.factory import build_model, list_models
from vit_zoo.components import ViTBackbone, BaseHead, LinearHead, MLPHead, IdentityHead

Available Heads

  • LinearHead: Simple linear layer (auto-created when head=int)
  • MLPHead: Multi-layer perceptron with configurable depth, activation, dropout
  • IdentityHead: Returns embeddings unchanged

All heads must implement the input_dim property. Create custom heads by subclassing BaseHead.
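The head argument accepted by build_model could be normalized along these lines; the class names mirror vit_zoo's, but the stub classes and logic below are illustrative only:

```python
class LinearHead:
    """Stub: a linear classifier head (auto-created when head=int)."""
    def __init__(self, input_dim, num_classes):
        self.input_dim, self.num_classes = input_dim, num_classes

class IdentityHead:
    """Stub: passes embeddings through unchanged."""
    def __init__(self, input_dim):
        self.input_dim = input_dim

def normalize_head(head, embed_dim):
    if head is None:                  # embedding extraction
        return IdentityHead(embed_dim)
    if isinstance(head, int):         # auto-create a linear classifier
        return LinearHead(embed_dim, head)
    return head                       # assume a BaseHead instance

print(type(normalize_head(10, 768)).__name__)    # LinearHead
print(type(normalize_head(None, 768)).__name__)  # IdentityHead
```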


License

GPL-3.0

