A PyTorch implementation of "MetaFormer Baselines" with optional extensions.
x - Metaformer
We support various self-supervised pretraining approaches such as BarlowTwins,
MoCoV3 or VICReg (see x_metaformer.pretraining).
Setup
Simply run:
pip install x-metaformer
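To confirm the install worked, a quick check from Python can help. This is only a sketch using the standard library; `is_installed` is a hypothetical helper written for this example, not part of x-metaformer, and the module name `x_metaformer` is taken from the import in the example below.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be found on the current Python path."""
    # find_spec returns None when no importer can locate the module.
    return importlib.util.find_spec(module_name) is not None


# Hypothetical helper for illustration; prints True once the package is installed.
print("x_metaformer installed:", is_installed("x_metaformer"))
```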
Example
import torch
from x_metaformer import CAFormer, ConvFormer

my_metaformer = CAFormer(
    in_channels=3,
    depths=(3, 3, 9, 3),
    dims=(64, 128, 320, 512),
    init_kernel_size=3,
    init_stride=2,
    drop_path_rate=0.5,
    norm='ln',               # ln, bn or rms (layernorm, batchnorm or rmsnorm)
    use_pos_emb=True,        # use 2d sinusoidal positional embeddings
    head_dim=32,
    num_heads=4,
    attn_dropout=0.1,
    proj_dropout=0.1,
    scale_value=1.0,         # scale attention logits by this value
    trainable_scale=False,   # if scale can be trained
    num_mem_vecs=0,          # additional memory vectors (in the attention layers)
    sparse_topk=0,           # sparsify - keep only top k values (in the attention layers)
    l2=False,                # l2 norm on tokens (in the attention layers)
    improve_locality=False,  # remove attention on own token
    use_starreglu=False      # use gated StarReLU
)

x = torch.randn(64, 3, 64, 64)                   # B C H W
out = my_metaformer(x, return_embeddings=False)  # returns average pooled tokens
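To give a feel for what the `sparse_topk` option describes ("keep only top k values"), here is a minimal pure-Python sketch of top-k sparsified softmax over a single row of attention logits. It is not the library's implementation: the function name `topk_sparse_softmax` is hypothetical, the masking-to-negative-infinity approach is one common way to do this, and treating `k=0` as "no sparsification" is an assumption based on the default shown above.

```python
import math


def topk_sparse_softmax(logits, k):
    """Softmax over one row of logits, keeping only the k largest entries.

    Illustrative sketch: entries below the k-th largest logit are masked to
    -inf before the softmax, so they receive exactly zero weight. Ties at
    the threshold are all kept, so more than k entries can survive.
    """
    if k <= 0 or k >= len(logits):
        kept = list(logits)  # assumption: k=0 means "do not sparsify"
    else:
        threshold = sorted(logits, reverse=True)[k - 1]
        kept = [v if v >= threshold else -math.inf for v in logits]
    peak = max(kept)                           # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in kept]  # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Only the two largest logits (2.0 and 1.0) keep non-zero attention weight.
weights = topk_sparse_softmax([2.0, 0.5, 1.0, -1.0], k=2)
print(weights)
```

In a real attention layer the same idea would be applied per query row on a tensor of logits rather than on a Python list.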
🚧 Repo is under active development ...