Torch-Modules-Compilation

A compilation of PyTorch module implementations from various ML papers, especially in computer vision. It contains self-implementations as well as unofficial and official implementations. More to be added.

Install

$ pip install torch-modules-compilation

Modules/Blocks

Bottleneck Residual Block

Your basic bottleneck residual block, as used in ResNets. From the paper "Deep Residual Learning for Image Recognition".

Parameters

in_channels (int): number of input channels

bottleneck_channels (int): number of bottleneck channels; usually less than the number of input channels

dropout (float): dropout rate; performed after every convolution

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 256, 16, 16) # (batch_size, channels, height, width)
block = modules.BottleneckResBlock(in_channels=256, bottleneck_channels=64)

block(x).shape # (32, 256, 16, 16)
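
For reference, the bottleneck design itself is only a few lines: a 1x1 convolution reduces the channels, a 3x3 convolution operates at the reduced width, and a 1x1 convolution expands back before the residual addition. The following is a minimal sketch under those assumptions; the exact layer order, normalization, and dropout placement in modules.BottleneckResBlock may differ.

import torch
import torch.nn as nn

class BottleneckSketch(nn.Module):
    """Illustrative bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand."""
    def __init__(self, in_channels, bottleneck_channels, dropout=0.0):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck_channels, 1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout), # dropout after every convolution, as above
            nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(bottleneck_channels, in_channels, 1),
            nn.Dropout2d(dropout),
        )

    def forward(self, x):
        return torch.relu(x + self.body(x)) # identity shortcut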

Depthwise Separable Convolution

A depthwise separable convolution: a depthwise convolution followed by a pointwise convolution. Used in MobileNets, from the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications".

Parameters:

in_channels (int): Number of input channels

out_channels (int): Number of output channels

kernel_size (int): Size of depthwise convolution kernel

stride (int): Stride of depthwise convolution

padding (int): Padding of depthwise convolution

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.DepthwiseSepConv(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)

block(x).shape # (32, 128, 16, 16)
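
The two-step structure is easy to express directly in PyTorch: a depthwise convolution is an ordinary Conv2d with groups=in_channels, and the pointwise step is a 1x1 convolution. A minimal sketch, not necessarily the library's exact implementation:

import torch
import torch.nn as nn

x = torch.randn(32, 64, 16, 16)

# Depthwise: one 3x3 filter per input channel (groups=in_channels).
depthwise = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, groups=64)
# Pointwise: a 1x1 convolution that mixes channels and changes their count.
pointwise = nn.Conv2d(64, 128, kernel_size=1)

pointwise(depthwise(x)).shape # (32, 128, 16, 16)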

SAGAN self-attention module

A feature-map self-attention module used in SAGAN ("Self-Attention Generative Adversarial Networks"). This implementation was copied and modified from https://github.com/rosinality/sagan-pytorch/blob/master/model.py#L82 under the Apache 2.0 license; the modification removes spectral initialization.

Parameters

in_channels (int): Number of input channels

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.FeatureMapSelfAttention(in_channels=64)

block(x).shape # (32, 64, 16, 16)
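
The underlying mechanism, following the SAGAN paper, is attention over spatial positions: 1x1 convolutions produce query, key, and value maps, attention weights are computed between all pairs of positions, and the attended output is added back through a learned scale that starts at zero. A minimal sketch under those assumptions (names illustrative; details may differ from the library's version):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureMapSelfAttentionSketch(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1)) # starts as the identity map

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query(x).flatten(2) # (B, C//8, H*W)
        k = self.key(x).flatten(2)   # (B, C//8, H*W)
        v = self.value(x).flatten(2) # (B, C, H*W)
        attn = F.softmax(q.transpose(1, 2) @ k, dim=-1) # (B, H*W, H*W)
        out = (v @ attn.transpose(1, 2)).view(b, c, h, w)
        return self.gamma * out + x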

Global-Local Attention Module

A convolutional attention module introduced in the paper "All the attention you need: Global-local, spatial-channel attention for image retrieval".

Parameters

in_channels (int): number of channels of the input feature map

num_reduced_channels (int): number of channels that the local and global spatial attention modules reduce the input feature map to. Refer to figures 3 and 5 in the paper.

feature_map_size (int): height/width of the feature map. The height/width of the input feature maps must be at least 7, due to the 7x7 convolution (3x3 dilated conv) in the module.

kernel_size (int): scope of the inter-channel attention

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)

block = modules.GLAM(in_channels=64, num_reduced_channels=48, feature_map_size=16, kernel_size=5)
# height and width are equal to feature_map_size

block(x).shape # (32, 64, 16, 16)

Global Context Module

A self-attention-like (non-local) block on feature maps. Implementation of "GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond".

Parameters

input_channels (int): Number of input channels

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)

block = modules.GlobalContextModule(input_channels=64)

block(x).shape # (32, 64, 16, 16)
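
Per the GCNet paper, the block has two stages: context modeling (a 1x1 convolution scores every position, a softmax over positions turns the scores into a pooling distribution, and the features are pooled into a single context vector) and a bottleneck transform whose output is broadcast-added to every position. A minimal sketch under those assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalContextSketch(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.attn = nn.Conv2d(channels, 1, 1) # per-position attention logits
        self.transform = nn.Sequential(       # bottleneck transform
            nn.Conv2d(channels, channels // reduction, 1),
            nn.LayerNorm([channels // reduction, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        weights = F.softmax(self.attn(x).flatten(2), dim=-1)                 # (B, 1, H*W)
        context = (x.flatten(2) @ weights.transpose(1, 2)).view(b, c, 1, 1) # pooled context vector
        return x + self.transform(context) # broadcast-added to every position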

LFSA Tokenizer and Refinement Block

Implementation of the tokenizer in "Learning Token-Based Representation for Image Retrieval". These are two modules: a tokenizer that takes feature maps from a CNN (in the paper's case, feature maps from a local-feature self-attention module) and tokenizes them "into L visual tokens", used prior to the refinement block as described in the paper; and a refinement block that "enhance[s] the obtained visual tokens with self-attention and cross-attention."

Parameters

LFSA Tokenizer

in_channels (int): number of input channels

num_att_maps (int): number of tokens to tokenize the input into; also the number of channels used by the spatial attention

Refinement Block

d_model (int): dimensionality/channels of input

nhead (int): number of attention heads in the transformer

dim_feedforward (int): number of hidden dimensions in the feedforward layers

dropout (float): dropout rate

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)

tokenizer = modules.LFSATokenizer(in_channels=64, num_att_maps=48)
refinement_block = modules.RefinementBlock(d_model=64, nhead=2, dim_feedforward=48*4, dropout=0.1)

visual_tokens, cnn_output = tokenizer(x)
print(visual_tokens.shape) # (32, 48, 64)
print(cnn_output.shape) # (32, 16*16, 64)

output = refinement_block(visual_tokens, cnn_output)
print(output.shape) # (32, 48, 64)

Parameter-Free Channel Attention (PFCA)

A channel attention module for convolutional feature maps without any trainable parameters. Used in the paper "Parameter-Free Channel Attention for Image Classification and Super-Resolution".

Parameters

feature_map_size (int): Height/width of the input feature map

_lambda (float): A hyperparameter that is added to the variance (default: 1e-4)

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ParameterFreeChannelAttention(feature_map_size=16)

block(x).shape # (32, 64, 16, 16)

Patch Merger

Merges N tokens into M tokens in transformer models; typically added between transformer layers. Introduced in the paper "Learning to Merge Tokens in Vision Transformers". Copied from lucidrains' repo under the MIT license.

Parameters

dim (int): dimensionality/channels of the tokens

output_tokens (int): number of output merged tokens

norm (bool): normalize the input before merging

scale (bool): scale the attention matrix by the square root of dim (for numerical stability)

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16) # (batch_size, seq_length, channels)
block = modules.PatchMerger(dim=16, output_tokens=48, scale=True)

block(x).shape # (32, 48, 16)
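
The merger, as described in the paper and in lucidrains' implementation, is a single attention step with a learned query matrix: output_tokens learned queries attend over the N input tokens, and the softmax-weighted averages become the merged tokens. A minimal sketch under those assumptions:

import torch
import torch.nn as nn

class PatchMergerSketch(nn.Module):
    def __init__(self, dim, output_tokens, norm=True, scale=True):
        super().__init__()
        self.scale = dim ** -0.5 if scale else 1.0
        self.norm = nn.LayerNorm(dim) if norm else nn.Identity()
        self.queries = nn.Parameter(torch.randn(output_tokens, dim)) # learned queries

    def forward(self, x):                                      # x: (B, N, dim)
        x = self.norm(x)
        attn = (self.queries @ x.transpose(1, 2)) * self.scale # (B, M, N)
        return attn.softmax(dim=-1) @ x                        # (B, M, dim)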

ResBlock

Your basic residual block, as used in ResNets. From the paper "Deep Residual Learning for Image Recognition".

Parameters

in_channels (int): number of input channels

kernel_size (int): kernel size

dropout (float): dropout rate

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ResBlock(in_channels=64, kernel_size=3, dropout=0.2)

block(x).shape # (32, 64, 16, 16)

Up/Down sample ResBlock

Composed of several residual blocks and a down/up sampling at the end; adapted from Stable Diffusion's ResnetBlock.

Parameters

in_channels (int): number of input channels

out_channels (int): number of output channels

num_groups (int): number of groups for Group Normalization

num_layers (int): number of residual blocks

dropout (float): dropout rate

sample (str): One of "down", "up", or "none". For downsampling 2x, use "down". For upsampling 2x, use "up". Use "none" for no down/up sampling.

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 96, 96) # (batch_size, channels, height, width)
block = modules.ResBlockUpDownSample(
    in_channels=64, 
    out_channels=128, 
    num_groups=8, 
    num_layers=2, 
    dropout=0.1, 
    sample='down'
)

block(x).shape # (32, 128, 48, 48)
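
For intuition, a Stable-Diffusion-style residual layer is GroupNorm -> SiLU -> Conv applied twice, with a 1x1 shortcut when the channel count changes; the 2x resampling at the end is typically a stride-2 convolution (down) or interpolation followed by a convolution (up). A rough sketch of one such layer under those assumptions, not the library's exact code:

import torch
import torch.nn as nn

class ResnetLayerSketch(nn.Module):
    def __init__(self, in_ch, out_ch, num_groups=8, dropout=0.0):
        super().__init__()
        self.body = nn.Sequential(
            nn.GroupNorm(num_groups, in_ch), nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(num_groups, out_ch), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
        )
        # 1x1 shortcut so the residual addition works when channels change
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        return self.skip(x) + self.body(x)

downsample = nn.Conv2d(128, 128, 3, stride=2, padding=1) # one way to downsample 2x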

Residual MLP Block

An improvement over standard MLPs using residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP block (eq. 5 in the paper).

Parameters

dim (int): number of input dimensions

ic_first (bool): normalize and dropout at the start

dropout (float): dropout rate

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_block(dim=96, ic_first=True, dropout=0.1)

block(x).shape # (32, 96)
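
One plausible reading of the block (the paper's "IC" layer is batch normalization followed by dropout, placed before each weight layer, with a skip connection around the whole block) is sketched below; treat the exact ordering as an assumption rather than the library's definitive implementation:

import torch
import torch.nn as nn

class ResidualMLPBlockSketch(nn.Module):
    def __init__(self, dim, ic_first=True, dropout=0.1):
        super().__init__()
        def ic(): # "independent component" layer: normalization + dropout
            return nn.Sequential(nn.BatchNorm1d(dim), nn.Dropout(dropout))
        layers = [ic()] if ic_first else [] # normalize and dropout at the start
        layers += [nn.Linear(dim, dim), nn.ReLU(), ic(), nn.Linear(dim, dim)]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.body(x) # skip connection around the whole block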

Residual MLP Downsampling Block

An improvement over standard MLPs using residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP downsampling block (eq. 6 in the paper).

Parameters

dim (int): number of input dimensions

downsample_dim (int): number of output dimensions

dropout (float): dropout rate

Usage

import torch
from torch_modules_compilation import modules

x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_downsample(dim=96, downsample_dim=48, dropout=0.1)

block(x).shape # (32, 48)

Transformer Encoder Layer

Standard transformer encoder layer with queries, keys, and values as inputs.

Parameters

d_model (int): model dimensionality

nhead (int): number of attention heads

dim_feedforward (int): number of hidden dimensions in the feedforward layers

dropout (float): dropout rate

kdim (int, optional): dimensions of the keys

vdim (int, optional): dimensions of the values

Usage

import torch
from torch_modules_compilation import modules

queries = torch.randn(32, 20, 64) # (batch_size, seq_length, dim)
keys = torch.randn(32, 19, 48) # (batch_size, seq_length, dim)
values = torch.randn(32, 19, 96) # (batch_size, seq_length, dim)

block = modules.TransformerEncoderLayer(
    d_model=64,
    nhead=8, 
    dim_feedforward=256,
    dropout=0.2,
    kdim=48,
    vdim=96
)

block(queries, keys, values).shape # (32, 20, 64)
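
Note that torch.nn.MultiheadAttention already accepts kdim and vdim, so a layer like this can be sketched from standard components. A rough equivalent (post-norm with a ReLU feedforward; these details are assumptions, not necessarily the library's choices):

import torch
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, kdim=None, vdim=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                          kdim=kdim, vdim=vdim, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model),
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        attended, _ = self.attn(queries, keys, values) # keys/values projected to d_model internally
        x = self.norm1(queries + self.drop(attended))  # residual + norm
        return self.norm2(x + self.drop(self.ff(x)))   # feedforward sublayer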

UNet Encoder and Decoder

Standard UNet implementation. From the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation".

Parameters

UNet Encoder

channels (list of ints): A list containing the number of channels in the encoder, e.g. [3, 64, 128, 256]

dropout (float): dropout rate

UNet Decoder

channels (list of ints): A list containing the number of channels in the decoder, e.g. [256, 128, 64, 3]

dropout (float): dropout rate

Usage

import torch
from torch_modules_compilation import modules

images = torch.randn(16, 3, 224, 224) # (batch_size, channels, height, width)

unet_encoder = modules.UnetEncoder(channels=[3,64,128,256], dropout=0.1)
unet_decoder = modules.UnetDecoder(channels=[256,128,64,3], dropout=0.1)

encoder_features = unet_encoder(images)

output = unet_decoder(encoder_features)
print(output.shape) # (16, 64, 224, 224)

Squeeze-Excitation Module

Module that computes channel-wise interactions in a feature map. From "Squeeze-and-Excitation Networks".

Parameters

in_channels (int): Number of input channels

reduced_channels (int): Number of channels to reduce to in the "squeeze" part of the module

feature_map_size (int): height/width of the feature map

Usage

import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(16, 128, 64, 64) # (batch_size, channels, height, width)
se_module = modules.SEModule(in_channels=128, reduced_channels=32, feature_map_size=64)

se_module(feature_maps) # shape (16, 128, 64, 64); same as input
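
The squeeze-and-excitation recipe is: global-average-pool each channel to a single number ("squeeze"), pass the channel vector through a two-layer bottleneck MLP ending in a sigmoid ("excitation"), and scale each channel by its resulting weight. A minimal sketch of the canonical design, which pools adaptively and so would not strictly need feature_map_size; this library's SEModule may differ:

import torch
import torch.nn as nn

class SESketch(nn.Module):
    def __init__(self, in_channels, reduced_channels):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1) # global average pool per channel
        self.excite = nn.Sequential(
            nn.Linear(in_channels, reduced_channels), nn.ReLU(),
            nn.Linear(reduced_channels, in_channels), nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        weights = self.excite(self.squeeze(x).view(b, c)).view(b, c, 1, 1)
        return x * weights # channel-wise rescaling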

Token Learner

Module designed for reducing and generating visual tokens given a feature map. From "TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?".

Parameters

in_channels (int): Number of input channels

num_tokens (int): Number of tokens to reduce to

Usage

import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
token_learner = modules.TokenLearner(in_channels=16, num_tokens=50) # reduce tokens from 10*10 to 50

token_learner(feature_maps) # shape (2, 50, 16)
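
Conceptually, TokenLearner predicts one spatial attention map per output token and pools the feature map under each map, so H*W positions collapse into num_tokens vectors. The paper uses a small stack of convolutions for the maps; the sketch below compresses that into a single 1x1 convolution with a softmax, purely for illustration:

import torch
import torch.nn as nn

class TokenLearnerSketch(nn.Module):
    def __init__(self, in_channels, num_tokens):
        super().__init__()
        self.attn = nn.Conv2d(in_channels, num_tokens, 1) # one attention map per token

    def forward(self, x):                              # x: (B, C, H, W)
        maps = self.attn(x).flatten(2).softmax(dim=-1) # (B, T, H*W)
        return maps @ x.flatten(2).transpose(1, 2)     # (B, T, C)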

Triplet Attention

Computes attention in a feature map across all three dimensions (channel and both spatial dims). From "Rotate to Attend: Convolutional Triplet Attention Module".

Parameters

in_channels (int): Number of input channels

height (int): height of feature map

width (int): width of feature map

kernel_size (int): kernel size of the convolutions. Default: 7

Usage

import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
triplet_attention = modules.TripletAttention(in_channels=16, height=10, width=10)

triplet_attention(feature_maps) # shape (2, 16, 10, 10); same as input
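
The three branches each apply the same pattern along a different pair of dimensions: rotate the tensor so the target dimension sits in the channel slot, "Z-pool" it (concatenate max and mean over that slot), run a kernel_size convolution, and gate with a sigmoid; the three gated outputs are averaged. A rough sketch under those assumptions:

import torch
import torch.nn as nn

class GateSketch(nn.Module):
    """Z-pool (max + mean over dim 1) -> conv -> sigmoid gating."""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        pooled = torch.cat([x.max(dim=1, keepdim=True).values,
                            x.mean(dim=1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.conv(pooled))

class TripletAttentionSketch(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.ch = GateSketch(kernel_size)
        self.cw = GateSketch(kernel_size)
        self.hw = GateSketch(kernel_size)

    def forward(self, x): # (B, C, H, W)
        # Rotate H (then W) into the channel slot, gate, rotate back.
        x_ch = self.ch(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
        x_cw = self.cw(x.permute(0, 3, 2, 1)).permute(0, 3, 2, 1)
        return (x_ch + x_cw + self.hw(x)) / 3.0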

License

Some of these modules are copied or adapted from other repositories and are covered by their original licenses (e.g., MIT, Apache 2.0); take note of these licenses when using this code in your work. The rest are my own implementations, released under the MIT license. See this repo's license file.
