Compilation of Torch Modules from various ML papers
Project description
Torch-Modules-Compilation
A compilation of PyTorch module implementations from various ML papers, especially in computer vision. It contains a mix of self-implementations as well as unofficial and official implementations. More to be added.
Install
$ pip install torch-modules-compilation
Table of Contents
- Bottleneck Residual Block
- Depthwise Separable Convolution
- SAGAN self-attention module
- Global-Local Attention Module
- Global Context Module
- LFSA Tokenizer and Refinement Block
- Parameter-Free Channel Attention (PFCA)
- Patch Merger
- ResBlock
- Up/Down sample ResBlock
- Residual MLP Block
- Residual MLP Downsampling Block
- Transformer Encoder Layer
- UNet Encoder and Decoder
- Squeeze-Excitation Module
- Token Learner
- Triplet Attention
Modules/Blocks
Bottleneck Residual Block
Your basic bottleneck residual block used in ResNets. Image from the paper "Deep Residual Learning for Image Recognition".
Parameters
in_channels
(int): number of input channels
bottleneck_channels
(int): number of bottleneck channels; usually less than the number of input channels
dropout
(float): dropout rate; performed after every convolution
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 256, 16, 16) # (batch_size, channels, height, width)
block = modules.BottleneckResBlock(in_channels=256, bottleneck_channels=64)
block(x).shape # (32, 256, 16, 16)
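For reference, here is a minimal, illustrative sketch of what a bottleneck residual block computes: a 1x1 channel reduction, a 3x3 convolution, a 1x1 expansion, and then the skip connection. The class name, activation choices, and dropout placement below are assumptions for illustration, not this package's exact internals.

import torch
import torch.nn as nn

class BottleneckSketch(nn.Module):
    # Illustrative only: 1x1 reduce -> 3x3 -> 1x1 expand, with a residual connection
    def __init__(self, in_channels, bottleneck_channels, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(bottleneck_channels, in_channels, kernel_size=1),
        )

    def forward(self, x):
        return torch.relu(self.net(x) + x)  # the skip connection keeps the input shape

BottleneckSketch(256, 64)(torch.randn(32, 256, 16, 16)).shape  # torch.Size([32, 256, 16, 16])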
Depthwise Separable Convolution
A depthwise separable convolution, consisting of a depthwise convolution followed by a pointwise convolution. Used in MobileNets, as described in the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications". Image also from this paper.
Parameters:
in_channels
(int): Number of input channels
out_channels
(int): Number of output channels
kernel_size
(int): Size of depthwise convolution kernel
stride
(int): Stride of depthwise convolution
padding
(int): Padding of the depthwise convolution
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.DepthwiseSepConv(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
block(x).shape # (32, 128, 16, 16)
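The idea itself fits in a few lines of PyTorch: a depthwise convolution (groups=in_channels) filters each channel independently, and a 1x1 pointwise convolution then mixes channels. This is only a sketch of the concept; the package's DepthwiseSepConv may differ in details such as normalization and activations.

import torch
import torch.nn as nn

# Sketch: depthwise conv (one filter per input channel) followed by a 1x1 pointwise conv
depthwise = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, groups=64)
pointwise = nn.Conv2d(64, 128, kernel_size=1)

x = torch.randn(32, 64, 16, 16)
pointwise(depthwise(x)).shape  # torch.Size([32, 128, 16, 16])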
SAGAN self-attention module
A feature-map self-attention module used in SAGAN ("Self-Attention Generative Adversarial Networks"). Image also from this paper. This implementation was copied and modified from https://github.com/rosinality/sagan-pytorch/blob/master/model.py#L82 under the Apache 2.0 License; the modification removes spectral initialization.
Parameters
in_channels
(int): Number of input channels
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.FeatureMapSelfAttention(in_channels=64)
block(x).shape # (32, 64, 16, 16)
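Conceptually, the module projects the feature map to queries, keys, and values with 1x1 convolutions, computes attention over all spatial positions, and adds the result back to the input scaled by a learned gamma. The sketch below follows that general SAGAN structure but is not a copy of this package's code; the channel reduction factor and initialization are assumptions.

import torch
import torch.nn as nn

class FeatureMapSelfAttentionSketch(nn.Module):
    # Illustrative SAGAN-style self-attention over the spatial positions of a feature map
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # starts as an identity mapping

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query(x).flatten(2).transpose(1, 2)   # (b, h*w, c//8)
        k = self.key(x).flatten(2)                      # (b, c//8, h*w)
        v = self.value(x).flatten(2)                    # (b, c, h*w)
        attn = torch.softmax(q @ k, dim=-1)             # attention between all spatial positions
        out = (v @ attn.transpose(1, 2)).view(b, c, h, w)
        return self.gamma * out + x

FeatureMapSelfAttentionSketch(64)(torch.randn(32, 64, 16, 16)).shape  # torch.Size([32, 64, 16, 16])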
Global-Local Attention Module
A convolutional attention module introduced in the paper "All the attention you need: Global-local, spatial-channel attention for image retrieval". Image also from this paper.
Parameters
in_channels
(int): number of channels of the input feature map
num_reduced_channels
(int): number of channels that the local and global spatial attention modules reduce the input feature map to. Refer to figures 3 and 5 in the paper.
feature_map_size
(int): height/width of the feature map. The height/width of the input feature maps must be at least 7, due to the 7x7 convolution (3x3 dilated conv) in the module.
kernel_size
(int): scope of the inter-channel attention
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.GLAM(in_channels=64, num_reduced_channels=48, feature_map_size=16, kernel_size=5)
# height and width are equal to feature_map_size
block(x).shape # (32, 64, 16, 16)
Global Context Module
A sort of self-attention (non-local) block on feature maps. Implementation of "GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond".
Parameters
input_channels
(int): Number of input channels
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.GlobalContextModule(input_channels=64)
block(x).shape # (32, 64, 16, 16)
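As a rough sketch of the GCNet idea: a 1x1 convolution produces a single spatial attention map, the feature map is pooled into a per-channel global context vector using that map, and a small bottleneck transform of the context is added back to every position. The reduction ratio and layer choices below are illustrative assumptions, not this package's exact code.

import torch
import torch.nn as nn

class GlobalContextSketch(nn.Module):
    # Illustrative global context (GCNet-style) block
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.context_mask = nn.Conv2d(channels, 1, kernel_size=1)
        self.transform = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1),
            nn.LayerNorm([channels // reduction, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        mask = torch.softmax(self.context_mask(x).view(b, 1, h * w), dim=-1)      # attention over positions
        context = (x.view(b, c, h * w) @ mask.transpose(1, 2)).view(b, c, 1, 1)   # per-channel context vector
        return x + self.transform(context)                                        # broadcast-add to all positions

GlobalContextSketch(64)(torch.randn(32, 64, 16, 16)).shape  # torch.Size([32, 64, 16, 16])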
LFSA Tokenizer and Refinement Block
Implementation of the tokenizer in "Learning Token-Based Representation for Image Retrieval". These are two modules: the tokenizer module takes feature maps from a CNN (in the paper's case, feature maps from a local-feature-self-attention module) and tokenizes them "into L visual tokens"; it is used prior to the refinement block as described in the paper. The refinement block "enhance[s] the obtained visual tokens with self-attention and cross-attention."
Parameters
LFSA Tokenizer
in_channels
(int): number of input channels
num_att_maps
(int): number of tokens to tokenize the input into; also the number of channels used by the spatial attention
Refinement Block
d_model
(int): dimensionality/channels of input
nhead
(int): number of attention heads in the transformer
dim_feedforward
(int): number of hidden dimensions in the feedforward layers
dropout
(float): dropout rate
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
tokenizer = modules.LFSATokenizer(in_channels=64, num_att_maps=48)
refinement_block = modules.RefinementBlock(d_model=64, nhead=2, dim_feedforward=48*4, dropout=0.1)
visual_tokens, cnn_output = tokenizer(x)
print(visual_tokens.shape) # (32, 48, 64)
print(cnn_output.shape) # (32, 16*16, 64)
output = refinement_block(visual_tokens, cnn_output)
print(output.shape) # (32, 48, 64)
Parameter-Free Channel Attention (PFCA)
A channel attention module for convolutional feature maps without any trainable parameters. Used in, and image from, the paper "Parameter-Free Channel Attention for Image Classification and Super-Resolution".
Parameters
feature_map_size
(int): Height/width of the input feature map
_lambda
(float): A hyperparameter that is added to the variance (default: 1e-4)
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ParameterFreeChannelAttention(feature_map_size=16)
block(x).shape # (32, 64, 16, 16)
Patch Merger
Merges N tokens into M tokens in transformer models; typically added in between transformer layers. Introduced in the paper "Learning to Merge Tokens in Vision Transformers". Image from this paper. Copied from lucidrains' repo under the MIT license.
Parameters
dim
(int): dimensionality/channels of the tokens
output_tokens
(int): number of output merged tokens
norm
(bool): normalize the input before merging
scale
(bool): scale the attention matrix by the square root of dim (for numerical stability)
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16) # (batch_size, seq_length, channels)
block = modules.PatchMerger(dim=16, output_tokens=48, scale=True)
block(x).shape # (32, 48, 16)
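The mechanism is small enough to sketch: a learned matrix of M query vectors attends over the N input tokens, and the resulting attention weights produce M merged tokens as weighted combinations of the inputs. This follows the general structure of the paper and lucidrains' implementation, but treat it as a sketch rather than the exact code in this package.

import torch
import torch.nn as nn

class PatchMergerSketch(nn.Module):
    # Illustrative patch merger: M learned queries attend over N tokens to output M tokens
    def __init__(self, dim, output_tokens, norm=True, scale=True):
        super().__init__()
        self.scale = dim ** -0.5 if scale else 1.0
        self.norm = nn.LayerNorm(dim) if norm else nn.Identity()
        self.queries = nn.Parameter(torch.randn(output_tokens, dim))

    def forward(self, x):                       # x: (batch_size, seq_length, dim)
        x = self.norm(x)
        attn = torch.softmax(self.queries @ x.transpose(-2, -1) * self.scale, dim=-1)
        return attn @ x                         # (batch_size, output_tokens, dim)

PatchMergerSketch(dim=16, output_tokens=48)(torch.randn(32, 64, 16)).shape  # torch.Size([32, 48, 16])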
ResBlock
Your basic residual block, used in ResNets. Image from the original paper "Deep Residual Learning for Image Recognition".
Parameters
in_channels
(int): number of input channels
kernel_size
(int): kernel size
dropout
(float): dropout rate
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ResBlock(in_channels=64, kernel_size=3, dropout=0.2)
block(x).shape # (32, 64, 16, 16)
Up/Down sample ResBlock
Composed of several residual blocks and a down/up sampling at the end; adapted from Stable Diffusion's ResnetBlock.
Parameters
in_channels
(int): number of input channels
out_channels
(int): number of output channels
num_groups
(int): number of groups for Group Normalization
num_layers
(int): number of residual blocks
dropout
(float): dropout rate
sample
(str): One of "down", "up", or "none". For downsampling 2x, use "down". For upsampling 2x, use "up". Use "none" for no down/up sampling.
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 64, 96, 96) # (batch_size, channels, height, width)
block = modules.ResBlockUpDownSample(
in_channels=64,
out_channels=128,
num_groups=8,
num_layers=2,
dropout=0.1,
sample='down'
)
block(x).shape # (32, 128, 48, 48)
Residual MLP Block
An improved version of the standard MLP that adds normalization, dropout, and residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP block (eq. 5 in the paper).
Parameters
dim
(int): number of input dimensions
ic_first
(bool): apply normalization and dropout at the start of the block
dropout
(float): dropout rate
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_block(dim=96, ic_first=True, dropout=0.1)
block(x).shape # (32, 96)
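As a rough, hedged sketch of the idea: the block wraps linear layers between normalization/dropout ("IC") stages and adds the result back to the input. The exact ordering in eq. 5 of the paper may differ from this sketch; the layer sizes and activation below are assumptions, not the package's implementation.

import torch
import torch.nn as nn

class ResidualMLPBlockSketch(nn.Module):
    # Illustrative residual MLP block: (optional) norm + dropout, linear layers, and a skip connection
    def __init__(self, dim, ic_first=True, dropout=0.1):
        super().__init__()
        ic = [nn.BatchNorm1d(dim), nn.Dropout(dropout)] if ic_first else []
        self.net = nn.Sequential(*ic, nn.Linear(dim, dim), nn.ReLU(inplace=True),
                                 nn.BatchNorm1d(dim), nn.Dropout(dropout), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.net(x)   # residual connection around the MLP

ResidualMLPBlockSketch(96)(torch.randn(32, 96)).shape  # torch.Size([32, 96])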
Residual MLP Downsampling Block
An improved version of the standard MLP that adds normalization, dropout, and residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP downsampling block (eq. 6 in the paper).
Parameters
dim
(int): number of input dimensions
downsample_dim
(int): number of output dimensions
dropout
(float): dropout rate
Usage
import torch
from torch_modules_compilation import modules
x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_downsample(dim=96, downsample_dim=48, dropout=0.1)
block(x).shape # (32, 48)
Transformer Encoder Layer
Standard transformer encoder layer with queries, keys, and values as inputs.
Parameters
d_model
(int): model dimensionality
nhead
(int): number of attention heads
dim_feedforward
(int): number of hidden dimensions in the feedforward layers
dropout
(float): dropout rate
kdim
(int, optional): dimensions of the keys
vdim
(int, optional): dimensions of the values
Usage
import torch
from torch_modules_compilation import modules
queries = torch.randn(32, 20, 64) # (batch_size, seq_length, dim)
keys = torch.randn(32, 19, 48) # (batch_size, seq_length, dim)
values = torch.randn(32, 19, 96) # (batch_size, seq_length, dim)
block = modules.TransformerEncoderLayer(
d_model=64,
nhead=8,
dim_feedforward=256,
dropout=0.2,
kdim=48,
vdim=96
)
block(queries, keys, values).shape # (32, 20, 64)
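A layer of this kind can be sketched with PyTorch's built-in nn.MultiheadAttention, which already supports separate key/value dimensionalities via kdim and vdim. The post-norm ordering and feed-forward details below are assumptions and are not necessarily identical to this package's layer.

import torch
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    # Illustrative encoder layer: attention with separate query/key/value inputs + feed-forward, post-norm
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, kdim=None, vdim=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                          kdim=kdim, vdim=vdim, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, dim_feedforward), nn.ReLU(inplace=True),
                                nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        attn_out, _ = self.attn(queries, keys, values)
        queries = self.norm1(queries + self.dropout(attn_out))        # attention residual + norm
        return self.norm2(queries + self.dropout(self.ff(queries)))   # feed-forward residual + norm

layer = EncoderLayerSketch(d_model=64, nhead=8, dim_feedforward=256, kdim=48, vdim=96)
layer(torch.randn(32, 20, 64), torch.randn(32, 19, 48), torch.randn(32, 19, 96)).shape  # torch.Size([32, 20, 64])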
UNet Encoder and Decoder
Standard UNet implementation. From the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation".
Parameters
UNet Encoder
channels
(list of ints): A list containing the number of channels in the encoder. E.g. [3, 64, 128, 256]
dropout
(float): dropout rate
UNet Decoder
channels
(list of ints): A list containing the number of channels in the decoder. E.g. [256, 128, 64, 3]
dropout
(float): dropout rate
Usage
import torch
from torch_modules_compilation import modules
images = torch.randn(16, 3, 224, 224) # (batch_size, channels, height, width)
unet_encoder = modules.UnetEncoder(channels=[3,64,128,256], dropout=0.1)
unet_decoder = modules.UnetDecoder(channels=[256,128,64,3], dropout=0.1)
encoder_features = unet_encoder(images)
output = unet_decoder(encoder_features)
print(output.shape) # (16, 64, 224, 224)
Squeeze-Excitation Module
Module that computes channel-wise interactions in a feature map. From "Squeeze-and-Excitation Networks".
Parameters
in_channels
(int): Number of input channels
reduced_channels
(int): Number of channels to reduce to in the "squeeze" part of the module
feature_map_size
(int): height/width of the feature map
Usage
import torch
from torch_modules_compilation import modules
feature_maps = torch.randn(16, 128, 64, 64) # (batch_size, channels, height, width)
se_module = modules.SEModule(in_channels=128, reduced_channels=32, feature_map_size=64)
se_module(feature_maps) # shape (16, 128, 64, 64); same as input
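The squeeze-and-excitation mechanism itself is compact: global-average-pool each channel, pass the pooled vector through a bottleneck MLP with a sigmoid, and rescale the channels with the resulting gates. The sketch below uses adaptive pooling and omits the feature_map_size argument this package's SEModule takes, so it is an illustration rather than the package code.

import torch
import torch.nn as nn

class SESketch(nn.Module):
    # Illustrative squeeze-excitation: pool -> reduce -> expand -> sigmoid gate -> rescale channels
    def __init__(self, in_channels, reduced_channels):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excite = nn.Sequential(nn.Linear(in_channels, reduced_channels), nn.ReLU(inplace=True),
                                    nn.Linear(reduced_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.shape
        gates = self.excite(self.squeeze(x).view(b, c))   # one gate in [0, 1] per channel
        return x * gates.view(b, c, 1, 1)

SESketch(128, 32)(torch.randn(16, 128, 64, 64)).shape  # torch.Size([16, 128, 64, 64])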
Token Learner
Module designed for reducing and generating visual tokens given a feature map. From "TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?".
Parameters
in_channels
(int): Number of input channels
num_tokens
(int): Number of tokens to reduce to
Usage
import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
token_learner = modules.TokenLearner(in_channels=16, num_tokens=50) # reduce tokens from 10*10 to 50
token_learner(feature_maps) # shape (2, 50, 16)
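The core idea can be sketched as follows: a convolution predicts one spatial attention map per output token, each map is normalized over spatial positions, and each token is the attention-weighted average of the feature map. The actual TokenLearner uses a small conv/MLP stack and gating details not shown here, so treat this as a simplified illustration.

import torch
import torch.nn as nn

class TokenLearnerSketch(nn.Module):
    # Illustrative TokenLearner: one spatial attention map per output token
    def __init__(self, in_channels, num_tokens):
        super().__init__()
        self.attention_maps = nn.Conv2d(in_channels, num_tokens, kernel_size=1)

    def forward(self, x):                                             # x: (b, c, h, w)
        b, c, h, w = x.shape
        attn = torch.sigmoid(self.attention_maps(x)).view(b, -1, h * w)
        attn = attn / attn.sum(dim=-1, keepdim=True).clamp(min=1e-6)  # normalize over spatial positions
        return attn @ x.view(b, c, h * w).transpose(1, 2)             # (b, num_tokens, c)

TokenLearnerSketch(16, 50)(torch.randn(2, 16, 10, 10)).shape  # torch.Size([2, 50, 16])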
Triplet Attention
Computes attention in a feature map across all three dimensions (channel and both spatial dimensions). From "Rotate to Attend: Convolutional Triplet Attention Module".
Parameters
in_channels
(int): Number of input channels
height
(int): height of feature map
width
(int): width of feature map
kernel_size
(int): kernel size of the convolutions. Default: 7
Usage
import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
triplet_attention = modules.TripletAttention(in_channels=16, height=10, width=10)
triplet_attention(feature_maps) # shape (2, 16, 10, 10); same as input
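Roughly, triplet attention applies the same gating operation along three rotations of the feature map: each branch pools one dimension away (channel max + mean, the "Z-pool"), convolves over the remaining two dimensions, gates the input with a sigmoid, and the three branch outputs are averaged. The sketch below follows that general structure; it ignores the in_channels/height/width arguments of this package's module and omits batch norm, so it is an illustration only.

import torch
import torch.nn as nn

class AttentionGateSketch(nn.Module):
    # Z-pool (channel max + mean) -> conv -> sigmoid gate
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        pooled = torch.cat([x.max(dim=1, keepdim=True).values, x.mean(dim=1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.conv(pooled))

class TripletAttentionSketch(nn.Module):
    # Illustrative triplet attention: three gated branches over rotated views, averaged
    def __init__(self, kernel_size=7):
        super().__init__()
        self.gate_cw = AttentionGateSketch(kernel_size)
        self.gate_ch = AttentionGateSketch(kernel_size)
        self.gate_hw = AttentionGateSketch(kernel_size)

    def forward(self, x):                                                   # x: (b, c, h, w)
        out1 = self.gate_cw(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)      # channel-width interaction
        out2 = self.gate_ch(x.permute(0, 3, 2, 1)).permute(0, 3, 2, 1)      # channel-height interaction
        out3 = self.gate_hw(x)                                              # standard spatial attention
        return (out1 + out2 + out3) / 3

TripletAttentionSketch()(torch.randn(2, 16, 10, 10)).shape  # torch.Size([2, 16, 10, 10])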
License
Some of these modules are copied from other repositories and are covered by their respective licenses, such as MIT and Apache; take note of these licenses when using this code in your work. The rest are my own implementations, released under the MIT license. See this repo's license file.
File details
Details for the file torch_modules_compilation-0.0.2.tar.gz.
File metadata
- Download URL: torch_modules_compilation-0.0.2.tar.gz
- Upload date:
- Size: 19.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5f809c2055803e48a4e29276250c2c9d9eef750f76bd18db67d01a0ae2799aef
MD5 | e7512f58e4e6949cc57ac94393462067
BLAKE2b-256 | 4245f5807baba20394dc1f9231aa7639bde81472553ecafaf1521d26c362ea1f
File details
Details for the file torch_modules_compilation-0.0.2-py3-none-any.whl.
File metadata
- Download URL: torch_modules_compilation-0.0.2-py3-none-any.whl
- Upload date:
- Size: 25.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.16
File hashes
Algorithm | Hash digest
---|---
SHA256 | edde92048c4ec37d54584f422478821402e334435dd286857c71a21385ec1f44
MD5 | c490248a8d98a41dfda73167e9584c12
BLAKE2b-256 | 01a49b6b6a925db11b8467de577502bb56df96add63fe8f7b87cd1cbf228e811