Attention-Mechanism-Pytorch
This repository contains an implementation of many attention mechanism models.
Change Log
- Published Initial Attention Models, 2024-8-12.
Contents
1. External Attention Usage
1.1. Paper
"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks"
1.2. Overview
1.3. Usage Code
from AttentionMechanism.model.attention.ExternalAttention import ExternalAttention
import torch
input=torch.randn(50,49,512)
ea = ExternalAttention(d_model=512,S=8)
output=ea(input)
print(output.shape)
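As a rough illustration of the "two linear layers" idea named in the paper title, the following is a minimal sketch, not the packaged `ExternalAttention` class. The class name `ExternalAttentionSketch`, the attribute names `mk` and `mv`, and the small epsilon are assumptions made for this example. It projects tokens onto `S` learned memory slots, normalizes twice, and projects back.

```python
import torch
from torch import nn


class ExternalAttentionSketch(nn.Module):
    """Hypothetical minimal external attention: two linear layers act as shared memories."""

    def __init__(self, d_model, S=64):
        super().__init__()
        self.mk = nn.Linear(d_model, S, bias=False)  # key memory (assumed name)
        self.mv = nn.Linear(S, d_model, bias=False)  # value memory (assumed name)

    def forward(self, x):
        # x: (batch, tokens, d_model)
        attn = self.mk(x)                              # (batch, tokens, S)
        attn = torch.softmax(attn, dim=1)              # softmax over the token axis
        attn = attn / (attn.sum(dim=2, keepdim=True) + 1e-9)  # L1-normalize over memory slots
        return self.mv(attn)                           # (batch, tokens, d_model)


x = torch.randn(50, 49, 512)
sketch = ExternalAttentionSketch(d_model=512, S=8)
y = sketch(x)
print(y.shape)
```

The output keeps the input shape, so a module like this can be dropped in wherever a self-attention layer over `(batch, tokens, d_model)` tensors is used.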
2. Self Attention Usage
2.1. Paper
2.2. Overview
2.3. Usage Code
from AttentionMechanism.model.attention.SelfAttention import ScaledDotProductAttention
import torch
input=torch.randn(50,49,512)
sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
output=sa(input,input,input)
print(output.shape)
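For orientation, the core computation behind scaled dot-product attention can be sketched in a few lines. This is a single-head illustration without the learned projections, masking, or dropout that a full module would normally carry; the function name `scaled_dot_product_attention` is an assumption for this example and is not part of the package.

```python
import math

import torch


def scaled_dot_product_attention(q, k, v):
    """Compute softmax(Q K^T / sqrt(d_k)) V for tensors shaped (batch, seq_len, d_k)."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (batch, seq_q, seq_k)
    weights = torch.softmax(scores, dim=-1)                         # each row sums to 1
    return torch.matmul(weights, v), weights


x = torch.randn(50, 49, 512)
out, weights = scaled_dot_product_attention(x, x, x)
print(out.shape)      # torch.Size([50, 49, 512])
print(weights.shape)  # torch.Size([50, 49, 49])
```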
3. Simplified Self Attention Usage
3.1. Paper
SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021)
3.2. Overview
3.3. Usage Code
from AttentionMechanism.model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
import torch
input=torch.randn(50,49,512)
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
output=ssa(input,input,input)
print(output.shape)
4. Squeeze-and-Excitation Attention Usage
4.1. Paper
"Squeeze-and-Excitation Networks"
4.2. Overview
4.3. Usage Code
from AttentionMechanism.model.attention.SEAttention import SEAttention
import torch
input=torch.randn(50,512,7,7)
se = SEAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)
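The squeeze-and-excitation idea can be sketched as follows: global average pooling "squeezes" each channel to a scalar, a small bottleneck MLP produces per-channel gates, and the input is rescaled by those gates. This is an illustrative sketch only; the class name `SqueezeExciteSketch` is hypothetical and the packaged `SEAttention` may differ in details.

```python
import torch
from torch import nn


class SqueezeExciteSketch(nn.Module):
    """Hypothetical minimal SE block: squeeze (pool), excite (bottleneck MLP), rescale."""

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        gates = self.pool(x).view(b, c)        # squeeze: (b, c)
        gates = self.fc(gates).view(b, c, 1, 1)  # excite: per-channel gate in (0, 1)
        return x * gates                       # rescale each channel


x = torch.randn(50, 512, 7, 7)
se_sketch = SqueezeExciteSketch(channel=512, reduction=8)
y = se_sketch(x)
print(y.shape)
```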
5. SK Attention Usage
5.1. Paper
5.2. Overview
5.3. Usage Code
from AttentionMechanism.model.attention.SKAttention import SKAttention
import torch
input=torch.randn(50,512,7,7)
se = SKAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)
6. CBAM Attention Usage
6.1. Paper
"CBAM: Convolutional Block Attention Module"
6.2. Overview
6.3. Usage Code
from AttentionMechanism.model.attention.CBAM import CBAMBlock
import torch
input=torch.randn(50,512,7,7)
kernel_size=input.shape[2]
cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)
output=cbam(input)
print(output.shape)
7. BAM Attention Usage
7.1. Paper
"BAM: Bottleneck Attention Module"
7.2. Overview
7.3. Usage Code
from AttentionMechanism.model.attention.BAM import BAMBlock
import torch
input=torch.randn(50,512,7,7)
bam = BAMBlock(channel=512,reduction=16,dia_val=2)
output=bam(input)
print(output.shape)
8. ECA Attention Usage
8.1. Paper
"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks"
8.2. Overview
8.3. Usage Code
from AttentionMechanism.model.attention.ECAAttention import ECAAttention
import torch
input=torch.randn(50,512,7,7)
eca = ECAAttention(kernel_size=3)
output=eca(input)
print(output.shape)
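ECA replaces the bottleneck MLP of SE-style blocks with a small 1D convolution across the channel axis. The sketch below illustrates that idea under stated assumptions: the class name `ECASketch` is hypothetical, the kernel size is fixed rather than derived from the channel count, and the packaged `ECAAttention` may be implemented differently.

```python
import torch
from torch import nn


class ECASketch(nn.Module):
    """Hypothetical minimal ECA block: pooled channel descriptor -> 1D conv -> sigmoid gate."""

    def __init__(self, kernel_size=3):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        # An odd kernel with symmetric padding keeps the channel length unchanged.
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False)

    def forward(self, x):
        y = self.pool(x)                       # (b, c, 1, 1)
        y = y.squeeze(-1).transpose(1, 2)      # (b, 1, c): channels become the sequence axis
        y = torch.sigmoid(self.conv(y))        # (b, 1, c): local cross-channel interaction
        y = y.transpose(1, 2).unsqueeze(-1)    # (b, c, 1, 1)
        return x * y


x = torch.randn(50, 512, 7, 7)
eca_sketch = ECASketch(kernel_size=3)
y = eca_sketch(x)
print(y.shape)
```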
9. DANet Attention Usage
9.1. Paper
"Dual Attention Network for Scene Segmentation"
9.2. Overview
9.3. Usage Code
from AttentionMechanism.model.attention.DANet import DAModule
import torch
input=torch.randn(50,512,7,7)
danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
print(danet(input).shape)
10. Pyramid Split Attention Usage
10.1. Paper
"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network"
10.2. Overview
10.3. Usage Code
from AttentionMechanism.model.attention.PSA import PSA
import torch
input=torch.randn(50,512,7,7)
psa = PSA(channel=512,reduction=8)
output=psa(input)
print(output.shape)
11. Efficient Multi-Head Self-Attention Usage
11.1. Paper
"ResT: An Efficient Transformer for Visual Recognition"
11.2. Overview
11.3. Usage Code
from AttentionMechanism.model.attention.EMSA import EMSA
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,64,512)
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)
output=emsa(input,input,input)
print(output.shape)
12. Shuffle Attention Usage
12.1. Paper
"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS"
12.2. Overview
12.3. Usage Code
from AttentionMechanism.model.attention.ShuffleAttention import ShuffleAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,512,7,7)
se = ShuffleAttention(channel=512,G=8)
output=se(input)
print(output.shape)
13. MUSE Attention Usage
13.1. Paper
"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning"
13.2. Overview
13.3. Usage Code
from AttentionMechanism.model.attention.MUSEAttention import MUSEAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,49,512)
sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
output=sa(input,input,input)
print(output.shape)
14. SGE Attention Usage
14.1. Paper
Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks
14.2. Overview
14.3. Usage Code
from AttentionMechanism.model.attention.SGE import SpatialGroupEnhance
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,512,7,7)
sge = SpatialGroupEnhance(groups=8)
output=sge(input)
print(output.shape)
15. A2 Attention Usage
15.1. Paper
A2-Nets: Double Attention Networks
15.2. Overview
15.3. Usage Code
from AttentionMechanism.model.attention.A2Atttention import DoubleAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,512,7,7)
a2 = DoubleAttention(512,128,128,True)
output=a2(input)
print(output.shape)
16. AFT Attention Usage
16.1. Paper
16.2. Overview
16.3. Usage Code
from AttentionMechanism.model.attention.AFT import AFT_FULL
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,49,512)
aft_full = AFT_FULL(d_model=512, n=49)
output=aft_full(input)
print(output.shape)
17. Outlook Attention Usage
17.1. Paper
"VOLO: Vision Outlooker for Visual Recognition"
17.2. Overview
17.3. Usage Code
from AttentionMechanism.model.attention.OutlookAttention import OutlookAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,28,28,512)
outlook = OutlookAttention(dim=512)
output=outlook(input)
print(output.shape)
18. ViP Attention Usage
18.1. Paper
"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition"
18.2. Overview
18.3. Usage Code
from AttentionMechanism.model.attention.ViP import WeightedPermuteMLP
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(64,8,8,512)
seg_dim=8
vip=WeightedPermuteMLP(512,seg_dim)
out=vip(input)
print(out.shape)
19. CoAtNet Attention Usage
19.1. Paper
"CoAtNet: Marrying Convolution and Attention for All Data Sizes"
19.2. Overview
None
19.3. Usage Code
from AttentionMechanism.model.attention.CoAtNet import CoAtNet
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(1,3,224,224)
mbconv=CoAtNet(in_ch=3,image_size=224)
out=mbconv(input)
print(out.shape)
20. HaloNet Attention Usage
20.1. Paper
"Scaling Local Self-Attention for Parameter Efficient Visual Backbones"
20.2. Overview
20.3. Usage Code
from AttentionMechanism.model.attention.HaloAttention import HaloAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(1,512,8,8)
halo = HaloAttention(dim=512,
block_size=2,
halo_size=1,)
output=halo(input)
print(output.shape)
21. Polarized Self-Attention Usage
21.1. Paper
"Polarized Self-Attention: Towards High-quality Pixel-wise Regression"
21.2. Overview
21.3. Usage Code
from AttentionMechanism.model.attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention,SequentialPolarizedSelfAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(1,512,7,7)
psa = SequentialPolarizedSelfAttention(channel=512)
output=psa(input)
print(output.shape)
22. CoTAttention Usage
22.1. Paper
Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26
22.2. Overview
22.3. Usage Code
from AttentionMechanism.model.attention.CoTAttention import CoTAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,512,7,7)
cot = CoTAttention(dim=512,kernel_size=3)
output=cot(input)
print(output.shape)
23. Residual Attention Usage
23.1. Paper
Residual Attention: A Simple but Effective Method for Multi-Label Recognition---ICCV2021
23.2. Overview
23.3. Usage Code
from AttentionMechanism.model.attention.ResidualAttention import ResidualAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,512,7,7)
resatt = ResidualAttention(channel=512,num_class=1000,la=0.2)
output=resatt(input)
print(output.shape)
24. S2 Attention Usage
24.1. Paper
S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision---arXiv 2021.08.02
24.2. Overview
24.3. Usage Code
from AttentionMechanism.model.attention.S2Attention import S2Attention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,512,7,7)
s2att = S2Attention(channels=512)
output=s2att(input)
print(output.shape)
25. GFNet Attention Usage
25.1. Paper
Global Filter Networks for Image Classification---arXiv 2021.07.01
25.2. Overview
25.3. Usage Code - Implemented by Wenliang Zhao (Author)
from AttentionMechanism.model.attention.gfnet import GFNet
import torch
from torch import nn
from torch.nn import functional as F
x = torch.randn(1, 3, 224, 224)
gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
out = gfnet(x)
print(out.shape)
26. TripletAttention Usage
26.1. Paper
Rotate to Attend: Convolutional Triplet Attention Module---CVPR 2021
26.2. Overview
26.3. Usage Code - Implemented by digantamisra98
from AttentionMechanism.model.attention.TripletAttention import TripletAttention
import torch
from torch import nn
from torch.nn import functional as F
input=torch.randn(50,512,7,7)
triplet = TripletAttention()
output=triplet(input)
print(output.shape)
27. Coordinate Attention Usage
27.1. Paper
Coordinate Attention for Efficient Mobile Network Design---CVPR 2021
27.2. Overview
27.3. Usage Code - Implemented by Andrew-Qibin
from AttentionMechanism.model.attention.CoordAttention import CoordAtt
import torch
from torch import nn
from torch.nn import functional as F
inp=torch.rand([2, 96, 56, 56])
inp_dim, oup_dim = 96, 96
reduction=32
coord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)
output=coord_attention(inp)
print(output.shape)
28. MobileViT Attention Usage
28.1. Paper
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05
28.2. Overview
28.3. Usage Code
from AttentionMechanism.model.attention.MobileViTAttention import MobileViTAttention
import torch
from torch import nn
from torch.nn import functional as F
if __name__ == '__main__':
    m=MobileViTAttention()
    input=torch.randn(1,3,49,49)
    output=m(input)
    print(output.shape) #output:(1,3,49,49)
29. ParNet Attention Usage
29.1. Paper
Non-deep Networks---ArXiv 2021.10.20
29.2. Overview
29.3. Usage Code
from AttentionMechanism.model.attention.ParNetAttention import *
import torch
from torch import nn
from torch.nn import functional as F
if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    pna = ParNetAttention(channel=512)
    output=pna(input)
    print(output.shape) #50,512,7,7
30. UFO Attention Usage
30.1. Paper
UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29
30.2. Overview
30.3. Usage Code
from AttentionMechanism.model.attention.UFOAttention import *
import torch
from torch import nn
from torch.nn import functional as F
if __name__ == '__main__':
    input=torch.randn(50,49,512)
    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    output=ufo(input,input,input)
    print(output.shape) #[50, 49, 512]
31. ACmix Attention Usage
31.1. Paper
On the Integration of Self-Attention and Convolution
31.2. Usage Code
from AttentionMechanism.model.attention.ACmix import ACmix
import torch
if __name__ == '__main__':
    input=torch.randn(50,256,7,7)
    acmix = ACmix(in_planes=256, out_planes=256)
    output=acmix(input)
    print(output.shape)
32. MobileViTv2 Attention Usage
32.1. Paper
Separable Self-attention for Mobile Vision Transformers---ArXiv 2022.06.06
32.2. Overview
32.3. Usage Code
from AttentionMechanism.model.attention.MobileViTv2Attention import MobileViTv2Attention
import torch
from torch import nn
from torch.nn import functional as F
if __name__ == '__main__':
    input=torch.randn(50,49,512)
    sa = MobileViTv2Attention(d_model=512)
    output=sa(input)
    print(output.shape)
33. DAT Attention Usage
33.1. Paper
Vision Transformer with Deformable Attention---CVPR2022
33.2. Usage Code
from AttentionMechanism.model.attention.DAT import DAT
import torch
if __name__ == '__main__':
    input=torch.randn(1,3,224,224)
    model = DAT(
        img_size=224,
        patch_size=4,
        num_classes=1000,
        expansion=4,
        dim_stem=96,
        dims=[96, 192, 384, 768],
        depths=[2, 2, 6, 2],
        stage_spec=[['L', 'S'], ['L', 'S'], ['L', 'D', 'L', 'D', 'L', 'D'], ['L', 'D']],
        heads=[3, 6, 12, 24],
        window_sizes=[7, 7, 7, 7],
        groups=[-1, -1, 3, 6],
        use_pes=[False, False, True, True],
        dwc_pes=[False, False, False, False],
        strides=[-1, -1, 1, 1],
        sr_ratios=[-1, -1, -1, -1],
        offset_range_factor=[-1, -1, 2, 2],
        no_offs=[False, False, False, False],
        fixed_pes=[False, False, False, False],
        use_dwc_mlps=[False, False, False, False],
        use_conv_patches=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
    )
    output=model(input)
    print(output[0].shape)
34. CrossFormer Attention Usage
34.1. Paper
CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION---ICLR 2022
34.2. Usage Code
from AttentionMechanism.model.attention.Crossformer import CrossFormer
import torch
if __name__ == '__main__':
    input=torch.randn(1,3,224,224)
    model = CrossFormer(
        img_size=224,
        patch_size=[4, 8, 16, 32],
        in_chans=3,
        num_classes=1000,
        embed_dim=48,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        group_size=[7, 7, 7, 7],
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        drop_path_rate=0.1,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        merge_size=[[2, 4], [2, 4], [2, 4]]
    )
    output=model(input)
    print(output.shape)
35. MOATransformer Attention Usage
35.1. Paper
Aggregating Global Features into Local Vision Transformer
35.2. Usage Code
from AttentionMechanism.model.attention.MOATransformer import MOATransformer
import torch
if __name__ == '__main__':
    input=torch.randn(1,3,224,224)
    model = MOATransformer(
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6],
        num_heads=[3, 6, 12],
        window_size=14,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        drop_path_rate=0.1,
        ape=False,
        patch_norm=True,
        use_checkpoint=False
    )
    output=model(input)
    print(output.shape)
36. CrissCrossAttention Usage
36.1. Paper
CCNet: Criss-Cross Attention for Semantic Segmentation
36.2. Usage Code
from AttentionMechanism.model.attention.CrissCrossAttention import CrissCrossAttention
import torch
if __name__ == '__main__':
    input=torch.randn(3, 64, 7, 7)
    model = CrissCrossAttention(64)
    outputs = model(input)
    print(outputs.shape)
37. Axial Attention Usage
37.1. Paper
Axial Attention in Multidimensional Transformers
37.2. Usage Code
from AttentionMechanism.model.attention.Axial_attention import AxialImageTransformer
import torch
if __name__ == '__main__':
    input=torch.randn(3, 128, 7, 7)
    model = AxialImageTransformer(
        dim = 128,
        depth = 12,
        reversible = True
    )
    outputs = model(input)
    print(outputs.shape)
38. Frequency Channel Attention Usage
38.1. Paper
FcaNet: Frequency Channel Attention Networks (ICCV 2021)
38.2. Overview
38.3. Usage Code
from AttentionMechanism.model.attention.FCA import MultiSpectralAttentionLayer
import torch
if __name__ == "__main__":
    input = torch.randn(32, 128, 64, 64) # (b, c, h, w)
    fca_layer = MultiSpectralAttentionLayer(channel = 128, dct_h = 64, dct_w = 64, reduction = 16, freq_sel_method = 'top16')
    output = fca_layer(input)
    print(output.shape)
39. Attention Augmented Convolutional Networks Usage
39.1. Paper
Attention Augmented Convolutional Networks (ICCV 2019)
39.2. Overview
39.3. Usage Code
from AttentionMechanism.model.attention.AAAttention import AugmentedConv
import torch
if __name__ == "__main__":
    input = torch.randn((16, 3, 32, 32))
    augmented_conv = AugmentedConv(in_channels=3, out_channels=64, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16)
    output = augmented_conv(input)
    print(output.shape)
40. Global Context Attention Usage
40.1. Paper
GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond (ICCVW 2019 Best Paper)
Global Context Networks (TPAMI 2020)
40.2. Overview
40.3. Usage Code
from AttentionMechanism.model.attention.GCAttention import GCModule
import torch
if __name__ == "__main__":
    input = torch.randn(16, 64, 32, 32)
    gc_layer = GCModule(64)
    output = gc_layer(input)
    print(output.shape)
41. Linear Context Transform Attention Usage
41.1. Paper
Linear Context Transform Block (AAAI 2020)
41.2. Overview
41.3. Usage Code
from AttentionMechanism.model.attention.LCTAttention import LCT
import torch
if __name__ == "__main__":
    x = torch.randn(16, 64, 32, 32)
    attn = LCT(64, 8)
    y = attn(x)
    print(y.shape)
42. Gated Channel Transformation Usage
42.1. Paper
Gated Channel Transformation for Visual Recognition (CVPR 2020)
42.2. Overview
42.3. Usage Code
from AttentionMechanism.model.attention.GCTAttention import GCT
import torch
if __name__ == "__main__":
    input = torch.randn(16, 64, 32, 32)
    gct_layer = GCT(64)
    output = gct_layer(input)
    print(output.shape)
43. Gaussian Context Attention Usage
43.1. Paper
Gaussian Context Transformer (CVPR 2021)
43.2. Overview
43.3. Usage Code
from AttentionMechanism.model.attention.GaussianAttention import GCA
import torch
if __name__ == "__main__":
    input = torch.randn(16, 64, 32, 32)
    gca_layer = GCA(64)
    output = gca_layer(input)
    print(output.shape)
Acknowledgements
During the development of this project, the following open-source projects provided significant help and support. We hereby express our sincere gratitude: