Vision Transformers Zoo

A clean, extensible factory for creating HuggingFace-based Vision Transformer models (ViT, DeiT, DINO, DINOv2, DINOv3, CLIP) with flexible heads and easy backbone freezing.

Installation

pip install vit-zoo

From source:

git clone https://github.com/jbindaAI/vit_zoo.git
cd vit_zoo
pip install -e .

For development: pip install -e ".[dev]"

Quick start

from vit_zoo import ViTModel

model = ViTModel("facebook/dinov2-base", head=10, freeze_backbone=True)
outputs = model(images)
logits = outputs["predictions"]  # (batch_size, 10)

Basic usage

from vit_zoo import ViTModel

# Simple classification - pass any HuggingFace model ID
model = ViTModel("google/vit-base-patch16-224", head=10, freeze_backbone=True)
outputs = model(images)
predictions = outputs["predictions"]  # Shape: (batch_size, 10)

Custom MLP Head

from vit_zoo import ViTModel, MLPHead

mlp_head = MLPHead(
    input_dim=768,
    hidden_dims=[512, 256],
    output_dim=100,
    dropout=0.1,
    activation="gelu"  # or 'relu', 'tanh', or nn.Module
)

model = ViTModel("facebook/dinov2-base", head=mlp_head)

Embedding Extraction

from vit_zoo import ViTModel
from transformers import CLIPVisionModel

model = ViTModel("openai/clip-vit-base-patch16", backbone_cls=CLIPVisionModel, head=None)
outputs = model(images, output_hidden_states=True)
hidden_states = outputs["last_hidden_state"]  # (batch_size, seq_len, embedding_dim)
cls_embedding = hidden_states[:, 0, :]  # (batch_size, embedding_dim)
predictions = outputs["predictions"]  # same as cls_embedding when head=None (IdentityHead)

Attention Weights

from vit_zoo import ViTModel

model = ViTModel(
    "google/vit-base-patch16-224",
    head=10,
    config_kwargs={"attn_implementation": "eager"}
)
outputs = model(images, output_attentions=True)
attentions = outputs["attentions"]
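
If the "attentions" key passes the HuggingFace output through unchanged (an assumption, not stated above), it is a tuple with one tensor per transformer layer, following the usual HF convention:

# Continuing from the snippet above (assumed pass-through of HF attentions)
num_layers = len(attentions)             # one entry per transformer layer
last_attn = attentions[-1]               # (batch_size, num_heads, seq_len, seq_len)
cls_to_patches = last_attn[:, :, 0, 1:]  # CLS-to-patch attention, common for visualization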

Custom Head

from vit_zoo import ViTModel, BaseHead
import torch.nn as nn

class CustomHead(BaseHead):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self._input_dim = input_dim
        self.fc = nn.Linear(input_dim, num_classes)

    @property
    def input_dim(self) -> int:
        return self._input_dim

    def forward(self, embeddings):
        return self.fc(embeddings)

head = CustomHead(input_dim=768, num_classes=10)
model = ViTModel("google/vit-base-patch16-224", head=head)

Multi-modal Models (CLIP)

For CLIP and other multi-modal models, pass backbone_cls to load only the vision encoder (AutoModel would load the full model):

from vit_zoo import ViTModel
from transformers import CLIPVisionModel

model = ViTModel("openai/clip-vit-base-patch16", backbone_cls=CLIPVisionModel, head=10)

Any HuggingFace Model

ViTModel uses AutoModel to auto-detect the model type from the HuggingFace Hub. Any ViT-compatible model works:

from vit_zoo import ViTModel

model = ViTModel("google/vit-large-patch16-224", head=10)
model = ViTModel("facebook/deit-base-distilled-patch16-224", head=10)
model = ViTModel("facebook/dinov2-with-registers-base", head=10)

API Reference

ViTModel

Single entry point: construct from a HuggingFace model name or from a pre-built backbone.

ViTModel(
    model_name: Optional[str] = None,
    head: Optional[Union[int, BaseHead]] = None,
    backbone: Optional[nn.Module] = None,
    backbone_cls: Optional[Type] = None,
    freeze_backbone: bool = False,
    load_pretrained: bool = True,
    backbone_dropout: float = 0.0,
    config_kwargs: Optional[Dict[str, Any]] = None,
)

Parameters:

  • model_name: HuggingFace model identifier (e.g. "google/vit-base-patch16-224"). Required unless backbone is provided.
  • head: int (creates LinearHead), BaseHead instance, or None (embedding extraction).
  • backbone: Optional pre-built backbone; if set, model_name and backbone-loading args are ignored.
  • backbone_cls: Optional HuggingFace model class (e.g. CLIPVisionModel). Use for multi-modal models.
  • freeze_backbone: Freeze all backbone parameters.
  • load_pretrained: Load pretrained weights when using model_name.
  • backbone_dropout: Dropout probability in backbone.
  • config_kwargs: Extra config options (e.g. {"attn_implementation": "eager"}).

Usage:

  • ViTModel("google/vit-base-patch16-224", head=10)
  • ViTModel("facebook/dinov2-base", head=None) (embedding extraction)
  • Custom backbone: backbone = vit_zoo.utils._load_backbone(...); ViTModel(backbone=backbone, head=10) (see the sketch below)
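
A minimal sketch of the custom-backbone path. _load_backbone is private and its signature is not documented on this page, so the call below is an assumption modeled on the constructor arguments above:

from vit_zoo import ViTModel
from vit_zoo.utils import _load_backbone

# Hypothetical call; only the model name is confirmed by the docs above.
backbone = _load_backbone("google/vit-base-patch16-224")
model = ViTModel(backbone=backbone, head=10)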

ViTModel.forward()

forward(
    pixel_values: torch.Tensor,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
) -> Dict[str, Any]

Always returns a dict (combined example after the list). Keys:

  • "predictions": head output tensor (always present)
  • "attentions": optional, when output_attentions=True
  • "last_hidden_state": optional, when output_hidden_states=True; shape (batch_size, seq_len, embedding_dim)

Freezing the backbone

model.freeze_backbone(freeze: bool = True)  # Freeze/unfreeze backbone

The backbone is the raw HuggingFace model (e.g., model.backbone.encoder.layer.11 for ViT), so you can register hooks and access layers directly without an extra wrapper.
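
For example, a forward hook on the last encoder layer (a sketch; the layer path is specific to google/vit-base-patch16-224 and differs for other backbones):

import torch
from vit_zoo import ViTModel

model = ViTModel("google/vit-base-patch16-224", head=10, freeze_backbone=True)

captured = {}
def save_hidden(module, inputs, output):
    captured["layer11"] = output[0]  # ViT layers return a tuple; hidden states come first

# model.backbone is the raw HF model, so standard PyTorch hook APIs apply
handle = model.backbone.encoder.layer[11].register_forward_hook(save_hidden)
model(torch.randn(1, 3, 224, 224))
handle.remove()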

Supported Models

Any ViT-compatible model on the HuggingFace Hub works. Examples:

  • google/vit-base-patch16-224, google/vit-large-patch16-224 (ViT)
  • facebook/deit-base-distilled-patch16-224 (DeiT)
  • facebook/dino-vitb16 (DINO)
  • facebook/dinov2-base, facebook/dinov2-with-registers-base (DINOv2)
  • facebook/dinov3-vitb16-pretrain-lvd1689m (DINOv3)
  • openai/clip-vit-base-patch16 (CLIP Vision; pass backbone_cls=CLIPVisionModel)

Browse the HuggingFace Hub for more models.

Import Patterns

You can import the public API from the root package or from submodules:

# One-line style (recommended)
from vit_zoo import ViTModel, BaseHead, LinearHead, MLPHead, IdentityHead

# Submodule style (explicit namespaces)
from vit_zoo import ViTModel
from vit_zoo.components import BaseHead, LinearHead, MLPHead, IdentityHead
from vit_zoo.utils import _load_backbone  # for custom backbone path (private)

Architecture

  • Public API (vit_zoo.__all__): ViTModel, BaseHead, LinearHead, MLPHead, IdentityHead.
  • Layout: vit_zoo/model.py (ViT model), vit_zoo/utils/backbone.py (_load_backbone, _get_embedding_dim, _get_cls_token_embedding), vit_zoo/components/ (heads).
  • Extending: Add new heads in components; use vit_zoo.utils._load_backbone for custom backbones, then ViTModel(backbone=..., head=...).

Available Heads

  • LinearHead: Simple linear layer (auto-created when head=int)
  • MLPHead: Multi-layer perceptron with configurable depth, activation, dropout
  • IdentityHead: Returns embeddings unchanged

All heads must implement the input_dim property. Custom heads are created by subclassing BaseHead (see the Custom Head example above); the sketch below shows the explicit LinearHead form of head=int.
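
A sketch of the equivalence between head=10 and an explicit LinearHead. The LinearHead constructor arguments are assumptions mirroring MLPHead's, not confirmed on this page:

from vit_zoo import ViTModel, LinearHead

head = LinearHead(input_dim=768, output_dim=10)  # assumed signature, mirroring MLPHead
assert head.input_dim == 768                     # every head exposes input_dim
model = ViTModel("google/vit-base-patch16-224", head=head)  # equivalent to head=10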

License

GPL-3.0
