
Akriel's vision models playground


Vision Models - Playground


This playground is a collection of vision models that I implemented from scratch in PyTorch, to get a better understanding of the papers and techniques behind them.


Installation

To install this package, run the following command:

$ pip install vision-models-playground


YoloV1

A detector based on the YoloV1 architecture, whose convolutional backbone is commonly referred to as Darknet.


Models can be initialized with pre-built or custom configurations.

Code example to initialize and use the pre-built YoloV1

import torch
from vision_models_playground.models.segmentation import build_yolo_v1

model = build_yolo_v1(num_classes=20, in_channels=3, grid_size=7, num_bounding_boxes=2)

img = torch.randn(1, 3, 448, 448)  # <batch_size, in_channels, height, width>
preds = model(img)  # (1, 7, 7, 30) <batch_size, grid_size, grid_size, num_classes + 5 * num_bounding_boxes>
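The last dimension of the prediction packs, for every grid cell, the class scores plus five values per bounding box (x, y, w, h and a confidence score, as in the YOLOv1 paper). A minimal sketch of that shape arithmetic in plain Python; the helper name `yolo_v1_output_shape` is hypothetical and not part of the package, and the order of the values inside the last dimension is not assumed:

```python
def yolo_v1_output_shape(batch_size, grid_size, num_classes, num_bounding_boxes):
    """Shape of the YoloV1 prediction tensor, as given in the example comments."""
    # Each bounding box contributes 5 values: x, y, w, h and a confidence score.
    per_cell = num_classes + 5 * num_bounding_boxes
    return (batch_size, grid_size, grid_size, per_cell)


# Same setup as the example: 20 classes, a 7x7 grid and 2 boxes per cell
print(yolo_v1_output_shape(1, 7, 20, 2))  # (1, 7, 7, 30)
```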

Or, to build a custom YoloV1 configuration:

import torch
from vision_models_playground.models.segmentation import YoloV1

dims = [[64], [192], [128, 256, 256, 512], [256, 512, 256, 512, 256, 512, 256, 512, 512, 1024], [512, 1024, 512, 1024]]
kernel_size = [[7], [3], [1, 3, 1, 3], [1, 3, 1, 3, 1, 3, 1, 3, 1, 3], [1, 3, 1, 3]]
stride = [[2], [1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]]
max_pools = [True, True, True, True, False]

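model = YoloV1(
    dims=dims,
    kernel_size=kernel_size,
    stride=stride,
    max_pools=max_pools,
    in_channels=3,
    num_classes=20,
    num_bounding_boxes=2,
    grid_size=7,
    mlp_size=4096,  # illustrative value (4096 is the hidden size used in the YOLOv1 paper)
)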

img = torch.randn(1, 3, 448, 448)  # <batch_size, in_channels, height, width>
preds = model(img)  # (1, 7, 7, 30) <batch_size, grid_size, grid_size, num_classes + 5 * num_bounding_boxes>


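  • dims: List[List[int]].
    The number of output channels of each convolutional layer, grouped by stage.

  • kernel_size: List[List[int]].
    The kernel size of each convolutional layer, grouped by stage.

  • stride: List[List[int]].
    The stride of each convolutional layer, grouped by stage.

  • max_pools: List[bool].
    One flag per stage. If enabled, a max pool layer is applied after the convolutional layers of that stage.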

  • in_channels: int.
    The number of channels in the input image.

  • num_classes: int.
    The number of classes to classify.

  • num_bounding_boxes: int.
    The number of bounding boxes to predict per grid cell.

  • grid_size: int.
    The size of the grid that the image is split into.

  • mlp_size: int.
    The size of the MLP layer that is applied after the convolutional layers.
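The nested lists are read stage by stage: the outer index selects a stage and the inner index a convolutional layer within it, while max_pools holds one flag per stage. This reading is inferred from the parameter descriptions above. The hypothetical helper below (not part of the package) only checks that the lists line up and counts the convolutional layers of the custom configuration shown earlier:

```python
def check_yolo_v1_config(dims, kernel_size, stride, max_pools):
    """Return the number of convolutional layers, raising if the nested lists disagree."""
    num_stages = len(dims)
    # One inner list per stage, and one max-pool flag per stage.
    if not (len(kernel_size) == len(stride) == len(max_pools) == num_stages):
        raise ValueError("dims, kernel_size, stride and max_pools need one entry per stage")
    for i, (d, k, s) in enumerate(zip(dims, kernel_size, stride)):
        # Inside a stage, every layer needs a channel count, a kernel size and a stride.
        if not (len(d) == len(k) == len(s)):
            raise ValueError(f"stage {i}: dims, kernel_size and stride lengths differ")
    return sum(len(d) for d in dims)


dims = [[64], [192], [128, 256, 256, 512], [256, 512, 256, 512, 256, 512, 256, 512, 512, 1024], [512, 1024, 512, 1024]]
kernel_size = [[7], [3], [1, 3, 1, 3], [1, 3, 1, 3, 1, 3, 1, 3, 1, 3], [1, 3, 1, 3]]
stride = [[2], [1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]]
max_pools = [True, True, True, True, False]
print(check_yolo_v1_config(dims, kernel_size, stride, max_pools))  # 20
```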

Pretrained Weights

To use the pretrained weights, run the following code:

import torch
from vision_models_playground.utility.hub import load_vmp_model_from_hub

model = load_vmp_model_from_hub("Akriel/ResNetYoloV1")

img = torch.randn(1, 3, 448, 448)  # <batch_size, in_channels, height, width>
preds = model(img)  # (1, 7, 7, 30) <batch_size, grid_size, grid_size, num_classes + 5 * num_bounding_boxes>

If you want to use the pretrained model within a pipeline on raw images, you can use the following code:

from PIL import Image

from vision_models_playground.utility.hub import load_vmp_pipeline_from_hub

pipeline = load_vmp_pipeline_from_hub("Akriel/ResNetYoloV1")

# Load image
x = Image.open("path/to/image.jpg")
y = pipeline(x)  # A dictionary containing the predictions

For more information about the pretrained models, check the hub page.
You can also see the demo results in explore_models/yolo_trained_model.ipynb.


ResNet

A classifier based on the ResNet architecture.


Models can be initialized with pre-built or custom configurations.

Pre-built models:

  • ResNet18
  • ResNet34
  • ResNet50
  • ResNet101
  • ResNet152

Code example to initialize and use the pre-built ResNet34

import torch
from vision_models_playground.models.classifiers import build_resnet_34

model = build_resnet_34(num_classes=10, in_channels=3)

img = torch.randn(1, 3, 256, 256)  # <batch_size, in_channels, height, width>
preds = model(img)  # (1, 10) <batch_size, num_classes>

Code example to initialize ResNet34 using the custom ResNet

import torch
from vision_models_playground.models.classifiers import ResNet
from vision_models_playground.components.convolutions import ResidualBlock

model = ResNet(
    in_channels=3,
    num_classes=10,
    num_layers=[3, 4, 6, 3],
    num_channels=[64, 128, 256, 512],
    block=ResidualBlock,
)

img = torch.randn(1, 3, 256, 256)  # <batch_size, in_channels, height, width>
preds = model(img)  # (1, 10) <batch_size, num_classes>


  • in_channels: int.
    The number of channels in the input image.

  • num_classes: int.
    The number of predicted classes

  • num_layers: List[int]
    The number of block layers in each stage.

  • num_channels: List[int]
    The number of channels in each stage.
    Each stage starts with a stride of 2, mapping the channels of the previous stage to the channels of the current stage.

  • block: Union[ResidualBlock, BottleneckBlock]
    The block type used.
    There are two pre-implemented block types: ResidualBlock and BottleneckBlock.
    Can be replaced with any custom block whose constructor takes the following parameters: in_channels, out_channels, stride.
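The names of the pre-built variants count weighted layers, as in the original ResNet paper: two convolutions per basic residual block, three per bottleneck block, plus the stem convolution and the final fully connected layer. A short sketch of that arithmetic; the helper is hypothetical, and it assumes ResidualBlock and BottleneckBlock contain 2 and 3 convolutions respectively, as in the paper:

```python
def resnet_depth(num_layers, convs_per_block):
    """Count the weighted layers of a ResNet, following the ResNet paper's convention."""
    # convs_per_block is 2 for basic residual blocks and 3 for bottleneck blocks;
    # the stem convolution and the final fully connected layer add 2 more.
    # The 1x1 shortcut projections are not counted.
    return convs_per_block * sum(num_layers) + 2


print(resnet_depth([3, 4, 6, 3], convs_per_block=2))  # 34, the configuration of the ResNet34 example
print(resnet_depth([3, 4, 6, 3], convs_per_block=3))  # 50
```

The same stage counts give 34 layers with basic blocks and 50 with bottleneck blocks, which is why the block type matters as much as num_layers.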

Vision Transformer (ViT)

A classifier based on the Vision Transformer architecture.


Code example to initialize and use Vision Transformer

import torch
from vision_models_playground.models.classifiers import VisionTransformer

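model = VisionTransformer(  # illustrative hyperparameter values
    image_size=256, patch_size=32, num_classes=1000,  # (256 // 32) ** 2 = 64 patches
    projection_dim=1024, depth=6, heads=16, mlp_dim=2048,
    dropout=0.1, emb_dropout=0.1,
)
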
img = torch.randn(1, 3, 256, 256)
preds = model(img)  # (1, 1000)


  • image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height

  • patch_size: int.
    Size of each square patch. image_size must be divisible by patch_size.
    The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16.

  • num_classes: int.
    Number of classes to classify.

  • projection_dim: int.
    Last dimension of the output tensor after the linear projection nn.Linear(..., projection_dim).

  • depth: int.
    Number of Transformer blocks.

  • heads: int.
    Number of heads in Multi-head Attention layer.

  • mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.

  • channels: int, default 3.
    Number of image channels.

  • dropout: float between [0, 1], default 0.
    Dropout rate.

  • emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.

  • dim_head: int, default 64.
    The dimension of each head in Multi-Head Attention.

  • pool: string, either cls or mean, default mean.
    Determines whether CLS-token pooling or mean pooling is applied.

  • apply_rotary_emb: bool, default False.
    If enabled, applies rotary_embedding in Attention blocks.
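The constraint on patch_size can be checked with plain arithmetic, and the two pool options differ only in how the output tokens are reduced to a single vector before classification. A minimal sketch with hypothetical helpers (not part of the package); the 256x256 image and 32x32 patches are illustrative, and the cls branch assumes the class token sits at position 0:

```python
def vit_num_patches(image_size, patch_size):
    """Number of patch tokens for a square image, enforcing the documented constraints."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    n = (image_size // patch_size) ** 2
    if n <= 16:
        raise ValueError("the number of patches must be greater than 16")
    return n


def pool_tokens(tokens, pool="mean"):
    """Reduce token vectors to one vector: 'cls' keeps token 0, 'mean' averages all tokens."""
    if pool == "cls":
        return tokens[0]
    return [sum(values) / len(tokens) for values in zip(*tokens)]


n = vit_num_patches(image_size=256, patch_size=32)
patch_dim = 3 * 32 * 32  # channels * patch_size ** 2 values per flattened patch
print(n, patch_dim)  # 64 3072
print(pool_tokens([[1.0, 2.0], [3.0, 4.0]], pool="mean"))  # [2.0, 3.0]
```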

Generative Adversarial Networks (GAN)

A generative model based on the GAN architecture.


Since the generated images must have a problem-specific shape, the GAN model receives both the Generator and the Discriminator as input.

The GAN takes care of the training process: it computes the losses and updates the weights of both the Generator and the Discriminator.
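The exact loss used by GAN is not stated here, so the following is an assumption: a sketch of the standard binary cross-entropy formulation from the original GAN paper (with the common non-saturating generator loss), computed for a single Discriminator output. It illustrates the quantities a training step minimises, not this package's code; the helper names are hypothetical.

```python
import math


def discriminator_loss(p_real, p_fake):
    """Binary cross-entropy with real images labelled 1 and generated images labelled 0."""
    # p_real = D(x) for a real image, p_fake = D(G(z)) for a generated one.
    return -(math.log(p_real) + math.log(1.0 - p_fake))


def generator_loss(p_fake):
    """Non-saturating generator loss: pushes D(G(z)) towards 1."""
    return -math.log(p_fake)


# A Discriminator that cannot tell real from fake outputs 0.5 for both.
print(round(discriminator_loss(0.5, 0.5), 4))  # 1.3863
print(round(generator_loss(0.5), 4))  # 0.6931
```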

Here is a code example that shows how to use the GAN interface to train on the MNIST dataset.

import torch

from torchvision import datasets, transforms

from vision_models_playground.models.generative import GAN

# Import custom Generator and Discriminator adequate to the problem
from vision_models_playground.models.generative.adverserial.gan import Discriminator
from vision_models_playground.models.generative.adverserial.gan import Generator

# Create GAN
generator = Generator()
discriminator = Discriminator()
gan = GAN(generator, discriminator)

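# Put the models on CUDA (the GAN currently runs only on CUDA devices, see Known issues)
generator.cuda()
discriminator.cuda()

# Create the data loader
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # assumed: scale pixels to [-1, 1]
    ])),
    batch_size=64,  # illustrative batch size
    shuffle=True,
)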
# Train the GAN
gan.train_epochs(train_loader, epochs=100, print_every=100)


  • generator: nn.Module
    The Generator model.
    Must have self.noise_dim set to the dimension of the noise vector used by the Generator in the forward step.

  • discriminator: nn.Module
    The Discriminator model. Its output must have shape (batch_size, 1) and contain the probability that each image is real.


This is a sample of the results of the GAN on MNIST.

For reference, this is a sample of the original MNIST dataset.

Known issues

At the moment, the GAN is coded to run only on CUDA devices. In the future, the code will be refactored to support CPU devices as well.


Perceiver

A classifier based on the Perceiver architecture.


Code example to initialize and use Perceiver

import torch
from vision_models_playground.models.classifiers import Perceiver

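model = Perceiver(  # illustrative hyperparameter values
    input_dim=3, input_axis=2, final_classifier_head=True, num_classes=1000,
    apply_rotary_emb=True, apply_fourier_encoding=True, max_freq=10, num_freq_bands=6,
    constant_mapping=False, max_position=1600,
    num_layers=2, num_latents=128, latent_dim=128, transformer_depth=2,
    cross_num_heads=8, cross_head_dim=64, self_attend_heads=8, self_attend_dim=64,
    attention_dropout=0.0, ff_hidden_dim=64, ff_dropout=0.0, activation=None,
)
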
img = torch.randn(1, 256, 256, 3)
preds = model(img)  # (1, 1000)


  • input_dim: int.
    Number of channels of the input.

  • input_axis: int.
    Number of axes of the input.
    If the input is a sequence, input_axis is 1.
    If the input is an image, input_axis is 2.
    If the input is a video, input_axis is 3.

  • final_classifier_head: bool.
    If enabled, the final classifier head is applied, and logits are returned.
    If disabled, the final classifier head is not applied, and latents are returned.

  • num_classes: int.
    Number of classes to classify.

  • apply_rotary_emb: bool.
    If enabled, applies rotary_embedding in Attention blocks.

  • apply_fourier_encoding: bool.
    If enabled, applies fourier_encoding over the input

  • max_freq: int.
    Maximum frequency to be used in fourier_encoding.

  • num_freq_bands: int.
    Number of frequency bands to be used in fourier_encoding.

  • constant_mapping: bool.
    If enabled, uses a constant mapping for the axis of the fourier_encoding.

  • max_position: int.
    Maximum position to be used in the positional fourier encoding.
    Works only if constant_mapping is enabled.

  • num_layers: int.
    Number of layers

  • num_latents: int.
    Number of latents

  • latent_dim: int.
    Dimension of the latent vector

  • cross_num_heads: int.
    Number of heads in the cross attention blocks

  • cross_head_dim: int.
    Dimension of the heads in the cross attention blocks

  • self_attend_heads: int.
    Number of heads in the self attention blocks

  • self_attend_dim: int.
    Dimension of the heads in the self attention blocks

  • transformer_depth: int.
    Number of layers in the transformer

  • attention_dropout: float.
    Dropout probability for the attention layers

  • ff_hidden_dim: int.
    Dimension of the hidden layers in the feed forward blocks

  • ff_dropout: float.
    Dropout probability for the feed forward layers

  • activation: Callable.
    Activation function to be used in the feed forward blocks.
    If left as None, GEGLU is used.
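Assuming the package follows the Fourier feature scheme of the Perceiver paper, each position along an axis is scaled to [-1, 1] and encoded with sines and cosines at num_freq_bands frequencies spaced evenly between 1 and max_freq / 2, with the raw coordinate appended; the encodings of all axes are then concatenated to the input channels. The sketch below illustrates that scheme in plain Python; the helper names are hypothetical and this package's exact layout may differ.

```python
import math


def linspace(start, stop, num):
    """Evenly spaced values from start to stop, inclusive."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def fourier_encode(x, max_freq, num_freq_bands):
    """Encode one coordinate x in [-1, 1] as [sin(pi*f*x)..., cos(pi*f*x)..., x]."""
    freqs = linspace(1.0, max_freq / 2, num_freq_bands)
    sines = [math.sin(math.pi * f * x) for f in freqs]
    cosines = [math.cos(math.pi * f * x) for f in freqs]
    return sines + cosines + [x]


def encoded_input_dim(input_dim, input_axis, num_freq_bands):
    """Channels per input element once every axis' encoding is concatenated to the input."""
    return input_dim + input_axis * (2 * num_freq_bands + 1)


print(len(fourier_encode(0.25, max_freq=10, num_freq_bands=6)))  # 13
print(encoded_input_dim(input_dim=3, input_axis=2, num_freq_bands=6))  # 29
```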

Vision Perceiver (ViP)

A classifier based on the Perceiver architecture, adapted to use the Vision Transformer technique of splitting the image into patches and projecting them into a sequence.


Code example to initialize and use Vision Perceiver

import torch
from vision_models_playground.models.classifiers import VisionPerceiver

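model = VisionPerceiver(  # illustrative hyperparameter values
    patch_size=16, projection_dim=128, num_classes=1000,
    apply_rotary_emb=True, apply_fourier_encoding=True, max_freq=10, num_freq_bands=6,
    constant_mapping=False, max_position=1600,
    num_layers=2, num_latents=128, latent_dim=128, transformer_depth=2,
    cross_num_heads=8, cross_head_dim=64, self_attend_heads=8, self_attend_dim=64,
    attention_dropout=0.0, ff_hidden_dim=64, ff_dropout=0.0, activation=None,
)
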
img = torch.randn(1, 256, 256, 3)
preds = model(img)  # (1, 1000)


  • patch_size: int.
    Size of the patches the image is split into.

  • projection_dim: int.
    Dimension of the projection layer.

  • num_classes: int.
    Number of classes to classify.

  • apply_rotary_emb: bool.
    If enabled, applies rotary_embedding in Attention blocks.

  • apply_fourier_encoding: bool.
    If enabled, applies fourier_encoding over the input

  • max_freq: int.
    Maximum frequency to be used in fourier_encoding.

  • num_freq_bands: int.
    Number of frequency bands to be used in fourier_encoding.

  • constant_mapping: bool, default False.
    If enabled, uses a constant mapping for the axis of the fourier_encoding.

  • max_position: int.
    Maximum position to be used in the positional fourier encoding.
    Works only if constant_mapping is enabled.

  • num_layers: int.
    Number of layers

  • num_latents: int.
    Number of latents

  • latent_dim: int.
    Dimension of the latent vector

  • cross_num_heads: int.
    Number of heads in the cross attention blocks

  • cross_head_dim: int.
    Dimension of the heads in the cross attention blocks

  • self_attend_heads: int.
    Number of heads in the self attention blocks

  • self_attend_dim: int.
    Dimension of the heads in the self attention blocks

  • transformer_depth: int.
    Number of layers in the transformer

  • attention_dropout: float.
    Dropout probability for the attention layers

  • ff_hidden_dim: int.
    Dimension of the hidden layers in the feed forward blocks

  • ff_dropout: float.
    Dropout probability for the feed forward layers

  • activation: callable.
    Activation function to be used in the feed forward blocks.
    If left as None, GEGLU is used.
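Splitting the image into patches shortens the sequence that the latents cross-attend to. A rough sketch of the effect on the size of a single cross-attention matrix, assuming a 256x256 image, 16x16 patches and 128 latents (illustrative values; the helpers are hypothetical and not part of the package):

```python
def num_tokens(image_size, patch_size=None):
    """Sequence length seen by the cross-attention: pixels, or patches when patch_size is set."""
    if patch_size is None:
        return image_size * image_size
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


def cross_attention_scores(num_latents, seq_len):
    """Entries in one latent-to-input attention matrix (per head)."""
    return num_latents * seq_len


num_latents = 128  # illustrative value
print(cross_attention_scores(num_latents, num_tokens(256)))  # 8388608 (one token per pixel)
print(cross_attention_scores(num_latents, num_tokens(256, patch_size=16)))  # 32768 (one token per patch)
```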

Convolutional Vision Transformer (CvT)

A classifier based on the Convolutional Vision Transformer architecture.


Models can be initialized with pre-built or custom configurations.

Pre-built models:

  • CvT13
  • CvT21
  • CvTW24

Code example to initialize and use the pre-built CvT13

import torch
from vision_models_playground.models.classifiers import build_cvt_13

model = build_cvt_13(num_classes=1000, in_channels=3)

img = torch.randn(1, 256, 256, 3)
preds = model(img)  # (1, 1000)

Code example to initialize CvT13 using the custom Convolutional Vision Transformer

import torch
from vision_models_playground.models.classifiers import ConvVisionTransformer

model = ConvVisionTransformer(
    patch_size=[7, 3, 3],
    patch_stride=[4, 2, 2],
    patch_padding=[2, 1, 1],
    embedding_dim=[64, 192, 384],
    depth=[1, 2, 10],
    num_heads=[1, 3, 6],
    ff_hidden_dim=[256, 768, 1536],
    qkv_bias=[True, True, True],
    drop_rate=[0.0, 0.0, 0.0],
    attn_drop_rate=[0.0, 0.0, 0.0],
    drop_path_rate=[0.0, 0.0, 0.1],
    kernel_size=[3, 3, 3],
    stride_kv=[2, 2, 2],
    stride_q=[1, 1, 1],
    padding_kv=[1, 1, 1],
    padding_q=[1, 1, 1],
    method=['conv', 'conv', 'conv'],
    in_channels=3,
    num_classes=1000,
    final_classifier_head=True,
)

img = torch.randn(1, 256, 256, 3)
preds = model(img)  # (1, 1000)


  • in_channels: int.
    Number of input channels.

  • num_classes: int.
    Number of classes to classify.

  • activation: callable.
    Activation function to be used in the feed forward blocks.
    If left as None, QuickGELU is used.

  • final_classifier_head: bool.
    If enabled, uses a final classifier head.
    If disabled, returns the image features.

  • patch_size: List[int].
    Size of the patches the image is split into, for each Vision Transformer layer. The patches can overlap.

  • patch_stride: List[int].
    Stride of the patches, for each Vision Transformer layer.

  • patch_padding: List[int].
    Padding of the patches, for each Vision Transformer layer.

  • embedding_dim: List[int].
    Dimension of the embedding layers, for each Vision Transformer layer.

  • depth: List[int].
    The depth of each Vision Transformer layer.

  • num_heads: List[int].
    The number of heads in each Attention block, for each Vision Transformer layer.

  • ff_hidden_dim: List[int].
    Dimension of the hidden layers in the feed forward blocks, for each Vision Transformer layer.

  • qkv_bias: List[bool].
    If enabled, adds a bias to the query, key and value vectors, for each Vision Transformer layer.

  • drop_rate: List[float].
    The dropout rate for the dropout layers in the Vision Transformer, in the Feed Forward blocks and on the output of the Attention layers.

  • attn_drop_rate: List[float].
    The dropout rate for the dropout layers inside the Attention layers.

  • drop_path_rate: List[float].
    The DropPath rate for the DropPath layers, for each Vision Transformer layer.
    The DropPath rate applied to the residual connections inside each Attend block is computed dynamically from the depth of the Vision Transformer.

  • kernel_size: List[int].
    The kernel size of the convolutional layers, for each Vision Transformer layer.

  • stride_kv: List[int].
    The stride of the convolutional layers used to project the Keys and Values.

  • stride_q: List[int].
    The stride of the convolutional layers used to project the Queries.

  • padding_kv: List[int].
    The padding of the convolutional layers used to project the Keys and Values.

  • padding_q: List[int].
    The padding of the convolutional layers used to project the Queries.

  • method: List[Literal['conv', 'avg', 'linear']].
    The method used to compute the projections of the Keys, Values and Queries.
    conv stands for a convolutional layer followed by normalization, then a linear projection.
    avg stands for an average pooling layer, followed by a linear projection.
    linear stands for a linear projection only.
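Each stage's patch embedding is a strided convolution, so the side length of its token grid follows the usual convolution output-size formula; the same formula with stride_kv shows how the Key/Value projection reduces the number of tokens attended to. A sketch under the assumption that each patch embedding is a single convolution with these parameters, as in the CvT paper (the helper names are hypothetical):

```python
def conv_output_size(size, kernel, stride, padding):
    """Spatial size after a convolution: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def cvt_token_grids(image_size, patch_size, patch_stride, patch_padding):
    """Side length of the token grid produced by each stage's convolutional patch embedding."""
    sizes = []
    size = image_size
    for k, s, p in zip(patch_size, patch_stride, patch_padding):
        size = conv_output_size(size, k, s, p)
        sizes.append(size)
    return sizes


# Values from the CvT13 example, with a 256x256 input
grids = cvt_token_grids(256, [7, 3, 3], [4, 2, 2], [2, 1, 1])
print(grids)  # [64, 32, 16]
# kernel_size=3, stride_kv=2, padding_kv=1: Keys/Values use a coarser grid than the Queries
print(conv_output_size(grids[0], kernel=3, stride=2, padding=1))  # 32
print(sum([1, 2, 10]))  # 13 transformer blocks in total, matching the CvT13 name
```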


UNet

A pixel-level classifier based on the UNet architecture.


Code example to initialize and use UNet

import torch
from vision_models_playground.models.segmentation import UNet

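model = UNet(
    in_channels=1, out_channels=2, channels=[64, 128, 256, 512, 1024],
    pooling_type='max', scale=2, conv_kernel_size=3, conv_padding=0,  # unpadded convolutions, as in the UNet paper
    method='conv', crop=True,
)
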
x = torch.randn(1, 1, 572, 572)
y = model(x)  # (1, 2, 388, 388)


  • in_channels: int.
    Number of input channels.

  • out_channels: int.
    Number of output channels.
    Can be used as number of classes per pixel.
    In case of segmentation, the number of classes can be 2 for example.

  • channels: List[int].
    List of the number of channels in each layer.

  • pooling_type: Literal['max', 'avg'].
    Type of pooling to be used for the DownScale layers

  • scale: int.
    Scaling factor between consecutive stages (for example, 2 halves the resolution at each DownScale layer and doubles it at each UpScale layer).

  • conv_kernel_size: int.
    Kernel size of the convolutional layers.

  • conv_padding: int.
    Padding of the convolutional layers.

  • method: Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "conv"].
    Method used to compute the initial upscale of the image. If conv is selected, a transposed convolution layer is used. Otherwise, nn.functional.upsample is used with the corresponding mode.

  • crop: bool.
    If enabled, the output of each upscale layer is cropped to the native size of the UpScaled image.
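The shapes in the UNet example come from unpadded 3x3 convolutions: each level's two convolutions trim 4 pixels, pooling halves the size and upscaling doubles it. The sketch below reproduces 572 to 388, assuming two convolutions per level and five entries in channels (four down and four up steps), as in the original UNet paper; the helper is hypothetical and not part of the package.

```python
def unet_output_size(size, depth=5, kernel_size=3, padding=0, scale=2):
    """Output side length of a UNet with two convolutions per level."""
    # Two convolutions per level; each removes (kernel_size - 1 - 2 * padding) pixels.
    shrink = 2 * (kernel_size - 1 - 2 * padding)
    for _ in range(depth - 1):  # contracting path: convolutions, then pooling
        size -= shrink
        if size % scale != 0:
            raise ValueError(f"size {size} is not divisible by scale {scale}")
        size //= scale
    size -= shrink  # bottom level (e.g. 1024 channels)
    for _ in range(depth - 1):  # expansive path: upscale, then convolutions
        size = size * scale - shrink
    return size


print(unet_output_size(572))  # 388, as in the example
print(unet_output_size(256, padding=1))  # 256: padded convolutions keep the size
```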


Source Distribution

vision_models_playground-0.2.3.tar.gz (68.3 kB, Source)

  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/45.2.0 requests-toolbelt/0.8.0 tqdm/4.30.0 CPython/3.8.10

Hashes for vision_models_playground-0.2.3.tar.gz

  • SHA256: 255bafb517927d375e68aea7a574bad516d7aca87e0435bda7bcb5167e3db4bf
  • MD5: 90af8741741137a853c387b78cada90a
  • BLAKE2b-256: 72cd2a301c45123e87334e73d262d98e07628ac1dada4528e25af44430330a10
